// R3BROOT - R3B analysis software
// R3BNeulandNeutronsScikit.cxx
#include "FairLogger.h"
#include "FairRootManager.h"
#include "TPython.h"
#include <TClonesArray.h>

#include <algorithm>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

10 TString inputMult,
11 TString inputCluster,
12 TString output)
13 : FairTask("R3BNeulandNeutronsScikit")
14 , fInputMult(std::move(inputMult))
15 , fMultiplicity(nullptr)
16 , fInputCluster(std::move(inputCluster))
17 , fClusters(nullptr)
18 , fNeutrons(std::move(output))
19 , fMinProb(0.1)
20{
21 TPython::Exec("import sys; v = sys.version.replace('\\n', '')");
22 TPython::Exec("print(f'R3BNeulandNeutronsScikit running TPython with Python version {v}')");
23
24 TPython::Exec("import numpy as np; import joblib; model = joblib.load('" + model + "')");
25 // Note: ModuleNotFoundError: No module named 'sklearn.ensemble._forest' -> That means wrong python version!
26
27 TPython::Exec("import ROOT; tca = ROOT.TClonesArray(\"R3BNeulandCluster\")");
28
29 // Example to test model:
30 // double prob = (Double_t)TPython::Eval("model.predict_proba([[63.421486, 1.582491, 1.0, 621.388550, 8.881784e-16,
31 // 0.000000, 1.582491, -18.501257, -7.500000, 1522.5]])[0][1]");
32
33 // Share single cluster between python and C++ **SLOW**
34 // TPython::Exec("import ROOT");
35 // TPython::Exec("cluster = ROOT.R3BNeulandCluster()");
36 // R3BNeulandCluster* cluster = ...
37 // TPython::Bind(cluster, "cluster");
38 // TString fPredictor = "model.predict_proba([["
39 // "cluster.GetT(),\n"
40 // "cluster.GetE(),\n"
41 // "cluster.GetSize(),\n"
42 // "cluster.GetEToF(),\n"
43 // "cluster.GetEnergyMoment(),\n"
44 // "cluster.GetLastHit().GetT() - cluster.GetFirstHit().GetT(),\n"
45 // "cluster.GetMaxEnergyHit().GetE(),\n"
46 // "cluster.GetPosition().X(),\n"
47 // "cluster.GetPosition().Y(),\n"
48 // "cluster.GetPosition().Z(),\n"
49 // "]])[0][1]"
50 // double prob = (Double_t)TPython::Eval(fPredictor);
51}
52
54{
55 auto ioman = FairRootManager::Instance();
56 if (ioman == nullptr)
57 {
58 LOG(fatal) << "R3BNeulandNeutronsScikit: No FairRootManager";
59 return kFATAL;
60 }
61
62 fMultiplicity = ioman->InitObjectAs<const R3BNeulandMultiplicity*>(fInputMult);
63 if (fMultiplicity == nullptr)
64 {
65 throw std::runtime_error(("R3BNeulandNeutronsScikit: R3BNeulandMultiplicity " + fInputMult +
66 " could not be provided by the FairRootManager")
67 .Data());
68 }
69
70 fClusters = (TClonesArray*)ioman->GetObject(fInputCluster);
71 if (fClusters != nullptr && !TString(fClusters->GetClass()->GetName()).EqualTo("R3BNeulandCluster"))
72 {
73 throw std::runtime_error(("R3BNeulandNeutronsScikit: TClonesArray " + fInputCluster +
74 " does not contain elements of type R3BNeulandCluster")
75 .Data());
76 }
77
78 fNeutrons.Init();
79 return kSUCCESS;
80}
81
83{
84 fNeutrons.Reset();
85
86 if (fClusters->GetEntries() == 0)
87 {
88 return;
89 }
90
91 // Let Python know about the new clusters
92 TPython::Bind(fClusters, "tca");
93 TPython::Exec("data = np.array([[cluster.GetT(), cluster.GetE(), cluster.GetSize(), cluster.GetEToF(), "
94 "cluster.GetEnergyMoment(), (cluster.GetLastHit().GetT() - cluster.GetFirstHit().GetT()), "
95 "cluster.GetMaxEnergyHit().GetE(), cluster.GetPosition().X(), cluster.GetPosition().Y(), "
96 "cluster.GetPosition().Z()] for cluster in tca])");
97
98 // Use model to predict probabilities - **SLOW**
99 TPython::Exec("results = model.predict_proba(data)[:, 1]");
100
102 std::vector<ClusterWithProba> cwps;
103 const int nClusters = fClusters->GetEntries();
104 cwps.reserve(nClusters);
105 for (int i = 0; i < nClusters; i++)
106 {
107 cwps.emplace_back(ClusterWithProba{ (R3BNeulandCluster*)fClusters->At(i),
108 (Double_t)TPython::Eval(TString::Format("results[%d]", i)) });
109 }
110
111 // Sort scored clusters, high probability first
112 std::sort(cwps.begin(), cwps.end(), std::greater<ClusterWithProba>());
113
114 // Log sorted results
115 if (FairLogger::GetLogger()->IsLogNeeded(fair::Severity::debug))
116 {
117 LOG(debug) << "R3BNeulandNeutronsScikit::Exec";
118 for (const auto& cwp : cwps)
119 {
120 LOG(debug) << cwp.c->GetPosition().X() << "\t" << cwp.p << "\t";
121 }
122 }
123
124 // With the multiplicity from somewhere else, take the n clusters with the highest prob as neutrons
125 const auto mult = fMultiplicity->GetMultiplicity();
126 for (size_t n = 0; n < cwps.size() && n < mult; n++)
127 {
128 if (cwps.at(n).p > fMinProb)
129 {
130 fNeutrons.Insert(R3BNeulandNeutron(*(cwps.at(n).c)));
131 }
132 }
133}
134
// Interface of this task as declared in its class header (defaults as declared there):
//   R3BNeulandNeutronsScikit(TString model,
//                            TString inputMult = "NeulandMultiplicity",
//                            TString inputCluster = "NeulandClusters",
//                            TString output = "NeulandNeutrons");
//   void Exec(Option_t*) override;

// ROOT ClassImp bookkeeping must name the class implemented in this file, not an unrelated one.
ClassImp(R3BNeulandNeutronsScikit)