R3BROOT
R3B analysis software
Loading...
Searching...
No Matches
R3BNeulandNeutronsKeras.cxx
Go to the documentation of this file.
2#include "FairLogger.h"
3#include "FairRootManager.h"
4#include "TPython.h"
5#include <TClonesArray.h>
6#include <iostream>
7#include <utility>
8
10 TString scaler,
11 TString inputMult,
12 TString inputCluster,
13 TString output)
14 : FairTask("R3BNeulandNeutronsKeras")
15 , fInputMult(std::move(inputMult))
16 , fMultiplicity(nullptr)
17 , fInputCluster(std::move(inputCluster))
18 , fClusters(nullptr)
19 , fNeutrons(std::move(output))
20 , fMinProb(0.1)
21{
22 // Warning: The python instance is shared. Here use prefix keras_ for all variables
23 // Report python version
24 TPython::Exec("import sys; v = sys.version.replace('\\n', '')");
25 TPython::Exec("print(f'R3BNeulandNeutronsKeras running TPython with Python version {v}')");
26
27 // Load Keras model, report keras and tensorflow version
28 TPython::Exec(
29 "import numpy as np; import tensorflow; from tensorflow import keras; keras_model = keras.models.load_model('" +
30 model + "')");
31 TPython::Exec("print(f'... with tensorflow {tensorflow.__version__} and keras {keras.__version__}')");
32
33 // Load scaler
34 TPython::Exec("import joblib; keras_scaler = joblib.load('" + scaler + "')");
35
36 // Prepare cluster container exchange
37 TPython::Exec("import ROOT; keras_tca = ROOT.TClonesArray(\"R3BNeulandCluster\")");
38}
39
41{
42 auto ioman = FairRootManager::Instance();
43 if (ioman == nullptr)
44 {
45 LOG(fatal) << "R3BNeulandNeutronsKeras: No FairRootManager";
46 return kFATAL;
47 }
48
49 fMultiplicity = ioman->InitObjectAs<const R3BNeulandMultiplicity*>(fInputMult);
50 if (fMultiplicity == nullptr)
51 {
52 throw std::runtime_error(("R3BNeulandNeutronsKeras: R3BNeulandMultiplicity " + fInputMult +
53 " could not be provided by the FairRootManager")
54 .Data());
55 }
56
57 fClusters = (TClonesArray*)ioman->GetObject(fInputCluster);
58 if (fClusters != nullptr && !TString(fClusters->GetClass()->GetName()).EqualTo("R3BNeulandCluster"))
59 {
60 throw std::runtime_error(("R3BNeulandNeutronsKeras: TClonesArray " + fInputCluster +
61 " does not contain elements of type R3BNeulandCluster")
62 .Data());
63 }
64
65 fNeutrons.Init();
66 return kSUCCESS;
67}
68
70{
71 fNeutrons.Reset();
72
73 if (fClusters->GetEntries() == 0)
74 {
75 return;
76 }
77
78 // Let Python know about the new clusters
79 TPython::Bind(fClusters, "keras_tca");
80 TPython::Exec("keras_data = np.array([[cluster.GetT(), cluster.GetE(), cluster.GetSize(), cluster.GetEToF(), "
81 "cluster.GetEnergyMoment(), (cluster.GetLastHit().GetT() - cluster.GetFirstHit().GetT()), "
82 "cluster.GetMaxEnergyHit().GetE(), cluster.GetPosition().X(), cluster.GetPosition().Y(), "
83 "cluster.GetPosition().Z()] for cluster in keras_tca])");
84 // TPython::Exec("print(keras_data[0])");
85 TPython::Exec("keras_data_scaled = keras_scaler.transform(keras_data)");
86 // TPython::Exec("print(keras_data_scaled[0])");
87
88 // Use model to predict probabilities - **SLOW**
89 TPython::Exec("keras_results = keras_model.predict(keras_data_scaled)[:, 1]");
90 // TPython::Exec("print(keras_results)");
91
92 // Make a new container with scored clusters
93 std::vector<ClusterWithProba> cwps;
94 const int nClusters = fClusters->GetEntries();
95 cwps.reserve(nClusters);
96 for (int i = 0; i < nClusters; i++)
97 {
98 cwps.emplace_back(ClusterWithProba{ (R3BNeulandCluster*)fClusters->At(i),
99 (Double_t)TPython::Eval(TString::Format("keras_results[%d]", i)) });
100 }
101
102 // Sort scored clusters, high probability first
103 std::sort(cwps.begin(), cwps.end(), std::greater<ClusterWithProba>());
104
105 // Log sorted results
106 if (FairLogger::GetLogger()->IsLogNeeded(fair::Severity::debug))
107 {
108 LOG(debug) << "R3BNeulandNeutronsKeras::Exec";
109 for (const auto& cwp : cwps)
110 {
111 LOG(debug) << cwp.c->GetPosition().X() << "\t" << cwp.p << "\t";
112 }
113 }
114
115 // With the multiplicity from somewhere else, take the n clusters with the highest prob as neutrons
116 const auto mult = fMultiplicity->GetMultiplicity();
117 for (size_t n = 0; n < cwps.size() && n < mult; n++)
118 {
119 if (cwps.at(n).p > fMinProb)
120 {
121 fNeutrons.Insert(R3BNeulandNeutron(*(cwps.at(n).c)));
122 }
123 }
124}
125
ClassImp(R3B::Neuland::Cal2HitPar)
R3BNeulandNeutronsKeras(TString model, TString scaler, TString inputMult="NeulandMultiplicity", TString inputCluster="NeulandClusters", TString output="NeulandNeutrons")
void Exec(Option_t *) override