Skip to content

Commit 3ab77c2

Browse files
authored
[PWGHF] Add table for ML study to treeCreator (#15591)
1 parent c5d1739 commit 3ab77c2

File tree

1 file changed

+153
-12
lines changed

1 file changed

+153
-12
lines changed

PWGHF/TableProducer/treeCreatorXicToXiPiPi.cxx

Lines changed: 153 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ DECLARE_SOA_COLUMN(PtPi1, ptPi1, float);
7878
DECLARE_SOA_COLUMN(ImpactParameterPi1, impactParameterPi1, float); //! Normalised impact parameter of Pi1 (prong2)
7979
DECLARE_SOA_COLUMN(ImpactParameterNormalisedPi1, impactParameterNormalisedPi1, float); //! Normalised impact parameter of Pi1 (prong2)
8080
DECLARE_SOA_COLUMN(MaxNormalisedDeltaIP, maxNormalisedDeltaIP, float); //! Maximum normalized difference between measured and expected impact parameter of candidate prongs
81+
DECLARE_SOA_COLUMN(MlScoreBkg, mlScoreBkg, float); //! ML score for background class
82+
DECLARE_SOA_COLUMN(MlScorePrompt, mlScorePrompt, float); //! ML score for prompt signal class
83+
DECLARE_SOA_COLUMN(MlScoreNonPrompt, mlScoreNonPrompt, float); //! ML score for non-prompt signal class (3-class model only, -1 otherwise)
8184
} // namespace full
8285

8386
DECLARE_SOA_TABLE(HfCandXicToXiPiPiLites, "AOD", "HFXICXI2PILITE",
@@ -186,6 +189,37 @@ DECLARE_SOA_TABLE(HfCandXicToXiPiPiLiteKfs, "AOD", "HFXICXI2PILITKF",
186189
hf_cand_xic_to_xi_pi_pi::DcaXYPi0Xi,
187190
hf_cand_xic_to_xi_pi_pi::DcaXYPi1Xi);
188191

192+
DECLARE_SOA_TABLE(HfCandXicToXiPiPiLiteMLs, "AOD", "HFXICXI2PIMLITE",
193+
full::ParticleFlag,
194+
hf_cand_mc_flag::OriginMcRec,
195+
full::CandidateSelFlag,
196+
full::Y,
197+
full::Eta,
198+
full::Phi,
199+
full::P,
200+
full::Pt,
201+
full::M,
202+
hf_cand_xic_to_xi_pi_pi::InvMassXi,
203+
hf_cand_xic_to_xi_pi_pi::InvMassLambda,
204+
full::DecayLength,
205+
full::DecayLengthXY,
206+
full::Cpa,
207+
full::CpaXY,
208+
hf_cand_xic_to_xi_pi_pi::CpaXi,
209+
hf_cand_xic_to_xi_pi_pi::CpaXYXi,
210+
hf_cand_xic_to_xi_pi_pi::CpaLambda,
211+
hf_cand_xic_to_xi_pi_pi::CpaXYLambda,
212+
full::ImpactParameterXi,
213+
full::ImpactParameterNormalisedXi,
214+
full::ImpactParameterPi0,
215+
full::ImpactParameterNormalisedPi0,
216+
full::ImpactParameterPi1,
217+
full::ImpactParameterNormalisedPi1,
218+
full::MaxNormalisedDeltaIP,
219+
full::MlScoreBkg,
220+
full::MlScorePrompt,
221+
full::MlScoreNonPrompt);
222+
189223
DECLARE_SOA_TABLE(HfCandXicToXiPiPiFulls, "AOD", "HFXICXI2PIFULL",
190224
full::ParticleFlag,
191225
hf_cand_mc_flag::OriginMcRec,
@@ -343,6 +377,7 @@ DECLARE_SOA_TABLE(HfCandXicToXiPiPiFullPs, "AOD", "HFXICXI2PIFULLP",
343377
struct HfTreeCreatorXicToXiPiPi {
344378
Produces<o2::aod::HfCandXicToXiPiPiLites> rowCandidateLite;
345379
Produces<o2::aod::HfCandXicToXiPiPiLiteKfs> rowCandidateLiteKf;
380+
Produces<o2::aod::HfCandXicToXiPiPiLiteMLs> rowCandidateLiteMl;
346381
Produces<o2::aod::HfCandXicToXiPiPiFulls> rowCandidateFull;
347382
Produces<o2::aod::HfCandXicToXiPiPiFullKfs> rowCandidateFullKf;
348383
Produces<o2::aod::HfCandXicToXiPiPiFullPs> rowCandidateFullParticles;
@@ -356,10 +391,14 @@ struct HfTreeCreatorXicToXiPiPi {
356391
Configurable<float> downSampleBkgFactor{"downSampleBkgFactor", 1., "Fraction of background candidates to keep for ML trainings"};
357392
Configurable<float> ptMaxForDownSample{"ptMaxForDownSample", 10., "Maximum pt for the application of the downsampling factor"};
358393

394+
static constexpr int kNumBinaryClasses = 2;
395+
359396
using SelectedCandidates = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfSelXicToXiPiPi>>;
360397
using SelectedCandidatesKf = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfCandXicKF, aod::HfSelXicToXiPiPi>>;
398+
using SelectedCandidatesML = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfMlXicToXiPiPi, aod::HfSelXicToXiPiPi>>;
361399
using SelectedCandidatesMc = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfCandXicMcRec, aod::HfSelXicToXiPiPi>>;
362400
using SelectedCandidatesKfMc = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfCandXicKF, aod::HfCandXicMcRec, aod::HfSelXicToXiPiPi>>;
401+
using SelectedCandidatesMcML = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfMlXicToXiPiPi, aod::HfCandXicMcRec, aod::HfSelXicToXiPiPi>>;
363402
using MatchedGenXicToXiPiPi = soa::Filtered<soa::Join<aod::McParticles, aod::HfCandXicMcGen>>;
364403

365404
Filter filterSelectCandidates = aod::hf_sel_candidate_xic::isSelXicToXiPiPi >= selectionFlagXic;
@@ -372,9 +411,13 @@ struct HfTreeCreatorXicToXiPiPi {
372411

373412
void init(InitContext const&)
374413
{
414+
std::array<bool, 6> doprocess{doprocessData, doprocessDataKf, doprocessDataWithML, doprocessMc, doprocessMcKf, doprocessMcWithML};
415+
if (std::accumulate(doprocess.begin(), doprocess.end(), 0) != 1) {
416+
LOGP(fatal, "Only one process function can be enabled at a time.");
417+
}
375418
}
376419

377-
template <bool DoMc, bool DoKf, typename T>
420+
template <bool DoMc, bool DoKf, bool DoMl, typename T>
378421
void fillCandidateTable(const T& candidate)
379422
{
380423
int8_t particleFlag = candidate.sign();
@@ -383,7 +426,7 @@ struct HfTreeCreatorXicToXiPiPi {
383426
particleFlag = candidate.flagMcMatchRec();
384427
originMc = candidate.originMcRec();
385428
}
386-
if constexpr (!DoKf) {
429+
if constexpr (!DoKf && !DoMl) {
387430
if (fillCandidateLiteTable) {
388431
rowCandidateLite(
389432
particleFlag,
@@ -484,7 +527,7 @@ struct HfTreeCreatorXicToXiPiPi {
484527
candidate.nSigTofPiFromLambda(),
485528
candidate.nSigTofPrFromLambda());
486529
}
487-
} else {
530+
} else if constexpr (DoKf) {
488531
if (fillCandidateLiteTable) {
489532
rowCandidateLiteKf(
490533
particleFlag,
@@ -636,6 +679,47 @@ struct HfTreeCreatorXicToXiPiPi {
636679
candidate.dcaXYPi1Xi());
637680
}
638681
}
682+
if constexpr (DoMl) {
683+
float mlScoreBkg = -1.f, mlScorePrompt = -1.f, mlScoreNonPrompt = -1.f;
684+
const int scoreSize = static_cast<int>(candidate.mlProbXicToXiPiPi().size());
685+
if (scoreSize > 0) {
686+
mlScoreBkg = candidate.mlProbXicToXiPiPi()[0];
687+
mlScorePrompt = candidate.mlProbXicToXiPiPi()[1];
688+
if (scoreSize > kNumBinaryClasses) {
689+
mlScoreNonPrompt = candidate.mlProbXicToXiPiPi()[2];
690+
}
691+
}
692+
rowCandidateLiteMl(
693+
particleFlag,
694+
originMc,
695+
candidate.isSelXicToXiPiPi(),
696+
candidate.y(o2::constants::physics::MassXiCPlus),
697+
candidate.eta(),
698+
candidate.phi(),
699+
candidate.p(),
700+
candidate.pt(),
701+
candidate.invMassXicPlus(),
702+
candidate.invMassXi(),
703+
candidate.invMassLambda(),
704+
candidate.decayLength(),
705+
candidate.decayLengthXY(),
706+
candidate.cpa(),
707+
candidate.cpaXY(),
708+
candidate.cpaXi(),
709+
candidate.cpaXYXi(),
710+
candidate.cpaLambda(),
711+
candidate.cpaXYLambda(),
712+
candidate.impactParameter0(),
713+
candidate.impactParameterNormalised0(),
714+
candidate.impactParameter1(),
715+
candidate.impactParameterNormalised1(),
716+
candidate.impactParameter2(),
717+
candidate.impactParameterNormalised2(),
718+
candidate.maxNormalisedDeltaIP(),
719+
mlScoreBkg,
720+
mlScorePrompt,
721+
mlScoreNonPrompt);
722+
}
639723
}
640724

641725
void processData(SelectedCandidates const& candidates)
@@ -653,10 +737,10 @@ struct HfTreeCreatorXicToXiPiPi {
653737
continue;
654738
}
655739
}
656-
fillCandidateTable<false, false>(candidate);
740+
fillCandidateTable<false, false, false>(candidate);
657741
}
658742
}
659-
PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processData, "Process data with DCAFitter reconstruction", true);
743+
PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processData, "Process data with DCAFitter reconstruction", false);
660744

661745
void processDataKf(SelectedCandidatesKf const& candidates)
662746
{
@@ -673,11 +757,22 @@ struct HfTreeCreatorXicToXiPiPi {
673757
continue;
674758
}
675759
}
676-
fillCandidateTable<false, true>(candidate);
760+
fillCandidateTable<false, true, false>(candidate);
677761
}
678762
}
679763
PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processDataKf, "Process data with KFParticle reconstruction", false);
680764

765+
void processDataWithML(SelectedCandidatesML const& candidates)
766+
{
767+
// Filling candidate properties
768+
rowCandidateLiteMl.reserve(candidates.size());
769+
770+
for (const auto& candidate : candidates) {
771+
fillCandidateTable<false, false, true>(candidate);
772+
}
773+
}
774+
PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processDataWithML, "Process data with DCAFitter reconstruction and ML", false);
775+
681776
void processMc(SelectedCandidatesMc const& candidates,
682777
MatchedGenXicToXiPiPi const& particles)
683778
{
@@ -689,7 +784,7 @@ struct HfTreeCreatorXicToXiPiPi {
689784
rowCandidateFull.reserve(recSig.size());
690785
}
691786
for (const auto& candidate : recSig) {
692-
fillCandidateTable<true, false>(candidate);
787+
fillCandidateTable<true, false, false>(candidate);
693788
}
694789
} else if (fillOnlyBackground) {
695790
if (fillCandidateLiteTable) {
@@ -702,7 +797,7 @@ struct HfTreeCreatorXicToXiPiPi {
702797
if (candidate.pt() < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) {
703798
continue;
704799
}
705-
fillCandidateTable<true, false>(candidate);
800+
fillCandidateTable<true, false, false>(candidate);
706801
}
707802
} else {
708803
if (fillCandidateLiteTable) {
@@ -711,7 +806,7 @@ struct HfTreeCreatorXicToXiPiPi {
711806
rowCandidateFull.reserve(candidates.size());
712807
}
713808
for (const auto& candidate : candidates) {
714-
fillCandidateTable<true, false>(candidate);
809+
fillCandidateTable<true, false, false>(candidate);
715810
}
716811
}
717812

@@ -743,7 +838,7 @@ struct HfTreeCreatorXicToXiPiPi {
743838
rowCandidateFull.reserve(recSigKf.size());
744839
}
745840
for (const auto& candidate : recSigKf) {
746-
fillCandidateTable<true, true>(candidate);
841+
fillCandidateTable<true, true, false>(candidate);
747842
}
748843
} else if (fillOnlyBackground) {
749844
if (fillCandidateLiteTable) {
@@ -756,7 +851,7 @@ struct HfTreeCreatorXicToXiPiPi {
756851
if (candidate.pt() < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) {
757852
continue;
758853
}
759-
fillCandidateTable<true, true>(candidate);
854+
fillCandidateTable<true, true, false>(candidate);
760855
}
761856
} else {
762857
if (fillCandidateLiteTable) {
@@ -765,7 +860,7 @@ struct HfTreeCreatorXicToXiPiPi {
765860
rowCandidateFull.reserve(candidates.size());
766861
}
767862
for (const auto& candidate : candidates) {
768-
fillCandidateTable<true, true>(candidate);
863+
fillCandidateTable<true, true, false>(candidate);
769864
}
770865
}
771866

@@ -785,6 +880,52 @@ struct HfTreeCreatorXicToXiPiPi {
785880
}
786881
}
787882
PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processMcKf, "Process MC with KF Particle reconstruction", false);
883+
884+
void processMcWithML(SelectedCandidatesMcML const& candidates,
885+
MatchedGenXicToXiPiPi const& particles)
886+
{
887+
// Filling candidate properties
888+
rowCandidateLiteMl.reserve(candidates.size());
889+
if (fillOnlySignal) {
890+
for (const auto& candidate : candidates) {
891+
if (candidate.flagMcMatchRec() == int8_t(0)) {
892+
continue;
893+
}
894+
fillCandidateTable<true, false, true>(candidate);
895+
}
896+
} else if (fillOnlyBackground) {
897+
for (const auto& candidate : candidates) {
898+
if (candidate.flagMcMatchRec() != int8_t(0)) {
899+
continue;
900+
}
901+
float const pseudoRndm = candidate.ptProng1() * 1000. - static_cast<int64_t>(candidate.ptProng1() * 1000);
902+
if (candidate.pt() < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) {
903+
continue;
904+
}
905+
fillCandidateTable<true, false, true>(candidate);
906+
}
907+
} else {
908+
for (const auto& candidate : candidates) {
909+
fillCandidateTable<true, false, true>(candidate);
910+
}
911+
}
912+
913+
if (fillGenParticleTable) {
914+
rowCandidateFullParticles.reserve(particles.size());
915+
for (const auto& particle : particles) {
916+
rowCandidateFullParticles(
917+
particle.flagMcMatchGen(),
918+
particle.originMcGen(),
919+
particle.pdgBhadMotherPart(),
920+
particle.pt(),
921+
particle.eta(),
922+
particle.phi(),
923+
RecoDecay::y(particle.pVector(), o2::constants::physics::MassXiCPlus),
924+
particle.decayLengthMcGen());
925+
}
926+
}
927+
}
928+
PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processMcWithML, "Process MC with DCAFitter reconstruction and ML", false);
788929
};
789930

790931
WorkflowSpec defineDataProcessing(ConfigContext const& cfgc)

0 commit comments

Comments
 (0)