R3BROOT
R3B analysis software
Loading...
Searching...
No Matches
R3BNeulandMultiplicityScikit.cxx
Go to the documentation of this file.
2#include "FairLogger.h"
3#include "FairRootManager.h"
4#include "TPython.h"
5#include <numeric>
6#include <utility>
7
8R3BNeulandMultiplicityScikit::R3BNeulandMultiplicityScikit(TString model, TString input, TString output)
9 : FairTask("R3BNeulandMultiplicityScikit")
10 , fClusters(std::move(input))
11 , fMultiplicity(new R3BNeulandMultiplicity())
12 , fOutputName(std::move(output))
13{
14 TPython::Exec("import sys; v = sys.version.replace('\\n', '')");
15 TPython::Exec("print(f'R3BNeulandMultiplicityScikit running TPython with Python version {v}')");
16
17 TPython::Exec("import numpy as np; import joblib; multmodel = joblib.load('" + model + "')");
18 // Note: ModuleNotFoundError: No module named 'sklearn. ...' -> That means wrong python version!
19
20 TPython::Exec("import ROOT; mult = ROOT.R3BNeulandMultiplicity()");
21 TPython::Bind(fMultiplicity, "mult"); // Bind here to prevent 1st event missing
22}
23
// NOTE(review): the signature line of this function was lost when the source was extracted
// (only the brace survived). Given the kFATAL/kSUCCESS returns it is presumably
// `InitStatus R3BNeulandMultiplicityScikit::Init()` -- restore it from the original file.
//
// Connects the cluster input and registers the multiplicity output object under fOutputName
// with the FairRootManager. Returns kFATAL (after LOG(fatal)) when no FairRootManager
// instance exists, kSUCCESS otherwise.
25
27{
 28 // Input
 29 fClusters.Init();
 30
 31 // Output
 32 auto ioman = FairRootManager::Instance();
 33 if (ioman == nullptr)
 34 {
 35 LOG(fatal) << "R3BNeulandMultiplicityScikit: No FairRootManager";
 36 return kFATAL;
 37 }
 // The final argument is passed as true -- presumably the persistence flag, so the object is
 // written to the output tree; confirm against FairRootManager::RegisterAny.
 38 ioman->RegisterAny(fOutputName, fMultiplicity, true);
 39
 40 return kSUCCESS;
 41}
42
44{
45 fMultiplicity->m.fill(0.);
46 const auto clusters = fClusters.Retrieve();
47 const int nClusters = clusters.size();
48
49 if (nClusters == 0)
50 {
51 LOG(debug) << "R3BNeulandMultiplicityScikit::Exec 0 Clusters -> Mult 0";
52 fMultiplicity->m[0] = 1.;
53 return;
54 }
55
56 const int nHits = std::accumulate(
57 clusters.cbegin(), clusters.cend(), 0, [](size_t s, const R3BNeulandCluster* c) { return s + c->GetSize(); });
58 const int Edep = (int)std::accumulate(
59 clusters.cbegin(), clusters.cend(), 0., [](Double_t s, const R3BNeulandCluster* c) { return s + c->GetE(); });
60
61 // Use model to predict probabilities. Note: In the model tested here, no case "0" -> i + 1
62 TPython::Exec(
63 TString::Format("for i, p in enumerate(multmodel.predict_proba([[%d, %d, %d]])[0]):\n mult.m[i + 1] = p",
64 nHits,
65 nClusters,
66 Edep));
67
68 // Python -> C++ -> Autosave in branch
69 TPython::Bind(fMultiplicity, "mult");
70
71 // Log
72 if (FairLogger::GetLogger()->IsLogNeeded(fair::Severity::debug))
73 {
74 TPython::Exec("print([p for p in mult.m])");
75 LOG(debug) << "R3BNeulandMultiplicityScikit::Exec "
76 << std::accumulate(fMultiplicity->m.cbegin(),
77 fMultiplicity->m.cend(),
78 std::string(),
79 [](std::string a, double b) { return std::move(a) + ", " + std::to_string(b); });
80 }
81}
82
ClassImp(R3B::Neuland::Cal2HitPar)
static const Double_t c
R3BNeulandMultiplicityScikit(TString model, TString inputCluster="NeulandClusters", TString output="NeulandMultiplicity")