@@ -78,6 +78,9 @@ DECLARE_SOA_COLUMN(PtPi1, ptPi1, float);
7878DECLARE_SOA_COLUMN (ImpactParameterPi1, impactParameterPi1, float ); // ! Normalised impact parameter of Pi1 (prong2)
7979DECLARE_SOA_COLUMN (ImpactParameterNormalisedPi1, impactParameterNormalisedPi1, float ); // ! Normalised impact parameter of Pi1 (prong2)
8080DECLARE_SOA_COLUMN (MaxNormalisedDeltaIP, maxNormalisedDeltaIP, float ); // ! Maximum normalized difference between measured and expected impact parameter of candidate prongs
81+ DECLARE_SOA_COLUMN (MlScoreBkg, mlScoreBkg, float ); // ! ML score for background class
82+ DECLARE_SOA_COLUMN (MlScorePrompt, mlScorePrompt, float ); // ! ML score for prompt signal class
83+ DECLARE_SOA_COLUMN (MlScoreNonPrompt, mlScoreNonPrompt, float ); // ! ML score for non-prompt signal class (3-class model only, -1 otherwise)
8184} // namespace full
8285
8386DECLARE_SOA_TABLE (HfCandXicToXiPiPiLites, " AOD" , " HFXICXI2PILITE" ,
@@ -186,6 +189,37 @@ DECLARE_SOA_TABLE(HfCandXicToXiPiPiLiteKfs, "AOD", "HFXICXI2PILITKF",
186189 hf_cand_xic_to_xi_pi_pi::DcaXYPi0Xi,
187190 hf_cand_xic_to_xi_pi_pi::DcaXYPi1Xi);
188191
192+ DECLARE_SOA_TABLE (HfCandXicToXiPiPiLiteMLs, " AOD" , " HFXICXI2PIMLITE" ,
193+ full::ParticleFlag,
194+ hf_cand_mc_flag::OriginMcRec,
195+ full::CandidateSelFlag,
196+ full::Y,
197+ full::Eta,
198+ full::Phi,
199+ full::P,
200+ full::Pt,
201+ full::M,
202+ hf_cand_xic_to_xi_pi_pi::InvMassXi,
203+ hf_cand_xic_to_xi_pi_pi::InvMassLambda,
204+ full::DecayLength,
205+ full::DecayLengthXY,
206+ full::Cpa,
207+ full::CpaXY,
208+ hf_cand_xic_to_xi_pi_pi::CpaXi,
209+ hf_cand_xic_to_xi_pi_pi::CpaXYXi,
210+ hf_cand_xic_to_xi_pi_pi::CpaLambda,
211+ hf_cand_xic_to_xi_pi_pi::CpaXYLambda,
212+ full::ImpactParameterXi,
213+ full::ImpactParameterNormalisedXi,
214+ full::ImpactParameterPi0,
215+ full::ImpactParameterNormalisedPi0,
216+ full::ImpactParameterPi1,
217+ full::ImpactParameterNormalisedPi1,
218+ full::MaxNormalisedDeltaIP,
219+ full::MlScoreBkg,
220+ full::MlScorePrompt,
221+ full::MlScoreNonPrompt);
222+
189223DECLARE_SOA_TABLE (HfCandXicToXiPiPiFulls, " AOD" , " HFXICXI2PIFULL" ,
190224 full::ParticleFlag,
191225 hf_cand_mc_flag::OriginMcRec,
@@ -343,6 +377,7 @@ DECLARE_SOA_TABLE(HfCandXicToXiPiPiFullPs, "AOD", "HFXICXI2PIFULLP",
343377struct HfTreeCreatorXicToXiPiPi {
344378 Produces<o2::aod::HfCandXicToXiPiPiLites> rowCandidateLite;
345379 Produces<o2::aod::HfCandXicToXiPiPiLiteKfs> rowCandidateLiteKf;
380+ Produces<o2::aod::HfCandXicToXiPiPiLiteMLs> rowCandidateLiteMl;
346381 Produces<o2::aod::HfCandXicToXiPiPiFulls> rowCandidateFull;
347382 Produces<o2::aod::HfCandXicToXiPiPiFullKfs> rowCandidateFullKf;
348383 Produces<o2::aod::HfCandXicToXiPiPiFullPs> rowCandidateFullParticles;
@@ -356,10 +391,14 @@ struct HfTreeCreatorXicToXiPiPi {
356391 Configurable<float > downSampleBkgFactor{" downSampleBkgFactor" , 1 ., " Fraction of background candidates to keep for ML trainings" };
357392 Configurable<float > ptMaxForDownSample{" ptMaxForDownSample" , 10 ., " Maximum pt for the application of the downsampling factor" };
358393
394+ static constexpr int kNumBinaryClasses = 2 ;
395+
359396 using SelectedCandidates = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfSelXicToXiPiPi>>;
360397 using SelectedCandidatesKf = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfCandXicKF, aod::HfSelXicToXiPiPi>>;
398+ using SelectedCandidatesML = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfMlXicToXiPiPi, aod::HfSelXicToXiPiPi>>;
361399 using SelectedCandidatesMc = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfCandXicMcRec, aod::HfSelXicToXiPiPi>>;
362400 using SelectedCandidatesKfMc = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfCandXicKF, aod::HfCandXicMcRec, aod::HfSelXicToXiPiPi>>;
401+ using SelectedCandidatesMcML = soa::Filtered<soa::Join<aod::HfCandXic, aod::HfMlXicToXiPiPi, aod::HfCandXicMcRec, aod::HfSelXicToXiPiPi>>;
363402 using MatchedGenXicToXiPiPi = soa::Filtered<soa::Join<aod::McParticles, aod::HfCandXicMcGen>>;
364403
365404 Filter filterSelectCandidates = aod::hf_sel_candidate_xic::isSelXicToXiPiPi >= selectionFlagXic;
@@ -372,9 +411,13 @@ struct HfTreeCreatorXicToXiPiPi {
372411
373412 void init (InitContext const &)
374413 {
414+ std::array<bool , 6 > doprocess{doprocessData, doprocessDataKf, doprocessDataWithML, doprocessMc, doprocessMcKf, doprocessMcWithML};
415+ if (std::accumulate (doprocess.begin (), doprocess.end (), 0 ) != 1 ) {
416+ LOGP (fatal, " Only one process function can be enabled at a time." );
417+ }
375418 }
376419
377- template <bool DoMc, bool DoKf, typename T>
420+ template <bool DoMc, bool DoKf, bool DoMl, typename T>
378421 void fillCandidateTable (const T& candidate)
379422 {
380423 int8_t particleFlag = candidate.sign ();
@@ -383,7 +426,7 @@ struct HfTreeCreatorXicToXiPiPi {
383426 particleFlag = candidate.flagMcMatchRec ();
384427 originMc = candidate.originMcRec ();
385428 }
386- if constexpr (!DoKf) {
429+ if constexpr (!DoKf && !DoMl ) {
387430 if (fillCandidateLiteTable) {
388431 rowCandidateLite (
389432 particleFlag,
@@ -484,7 +527,7 @@ struct HfTreeCreatorXicToXiPiPi {
484527 candidate.nSigTofPiFromLambda (),
485528 candidate.nSigTofPrFromLambda ());
486529 }
487- } else {
530+ } else if constexpr (DoKf) {
488531 if (fillCandidateLiteTable) {
489532 rowCandidateLiteKf (
490533 particleFlag,
@@ -636,6 +679,47 @@ struct HfTreeCreatorXicToXiPiPi {
636679 candidate.dcaXYPi1Xi ());
637680 }
638681 }
682+ if constexpr (DoMl) {
683+ float mlScoreBkg = -1 .f , mlScorePrompt = -1 .f , mlScoreNonPrompt = -1 .f ;
684+ const int scoreSize = static_cast <int >(candidate.mlProbXicToXiPiPi ().size ());
685+ if (scoreSize > 0 ) {
686+ mlScoreBkg = candidate.mlProbXicToXiPiPi ()[0 ];
687+ mlScorePrompt = candidate.mlProbXicToXiPiPi ()[1 ];
688+ if (scoreSize > kNumBinaryClasses ) {
689+ mlScoreNonPrompt = candidate.mlProbXicToXiPiPi ()[2 ];
690+ }
691+ }
692+ rowCandidateLiteMl (
693+ particleFlag,
694+ originMc,
695+ candidate.isSelXicToXiPiPi (),
696+ candidate.y (o2::constants::physics::MassXiCPlus),
697+ candidate.eta (),
698+ candidate.phi (),
699+ candidate.p (),
700+ candidate.pt (),
701+ candidate.invMassXicPlus (),
702+ candidate.invMassXi (),
703+ candidate.invMassLambda (),
704+ candidate.decayLength (),
705+ candidate.decayLengthXY (),
706+ candidate.cpa (),
707+ candidate.cpaXY (),
708+ candidate.cpaXi (),
709+ candidate.cpaXYXi (),
710+ candidate.cpaLambda (),
711+ candidate.cpaXYLambda (),
712+ candidate.impactParameter0 (),
713+ candidate.impactParameterNormalised0 (),
714+ candidate.impactParameter1 (),
715+ candidate.impactParameterNormalised1 (),
716+ candidate.impactParameter2 (),
717+ candidate.impactParameterNormalised2 (),
718+ candidate.maxNormalisedDeltaIP (),
719+ mlScoreBkg,
720+ mlScorePrompt,
721+ mlScoreNonPrompt);
722+ }
639723 }
640724
641725 void processData (SelectedCandidates const & candidates)
@@ -653,10 +737,10 @@ struct HfTreeCreatorXicToXiPiPi {
653737 continue ;
654738 }
655739 }
656- fillCandidateTable<false , false >(candidate);
740+ fillCandidateTable<false , false , false >(candidate);
657741 }
658742 }
659- PROCESS_SWITCH (HfTreeCreatorXicToXiPiPi, processData, " Process data with DCAFitter reconstruction" , true );
743+ PROCESS_SWITCH (HfTreeCreatorXicToXiPiPi, processData, " Process data with DCAFitter reconstruction" , false );
660744
661745 void processDataKf (SelectedCandidatesKf const & candidates)
662746 {
@@ -673,11 +757,22 @@ struct HfTreeCreatorXicToXiPiPi {
673757 continue ;
674758 }
675759 }
676- fillCandidateTable<false , true >(candidate);
760+ fillCandidateTable<false , true , false >(candidate);
677761 }
678762 }
679763 PROCESS_SWITCH (HfTreeCreatorXicToXiPiPi, processDataKf, " Process data with KFParticle reconstruction" , false );
680764
765+ void processDataWithML (SelectedCandidatesML const & candidates)
766+ {
767+ // Filling candidate properties
768+ rowCandidateLiteMl.reserve (candidates.size ());
769+
770+ for (const auto & candidate : candidates) {
771+ fillCandidateTable<false , false , true >(candidate);
772+ }
773+ }
774+ PROCESS_SWITCH (HfTreeCreatorXicToXiPiPi, processDataWithML, " Process data with DCAFitter reconstruction and ML" , false );
775+
681776 void processMc (SelectedCandidatesMc const & candidates,
682777 MatchedGenXicToXiPiPi const & particles)
683778 {
@@ -689,7 +784,7 @@ struct HfTreeCreatorXicToXiPiPi {
689784 rowCandidateFull.reserve (recSig.size ());
690785 }
691786 for (const auto & candidate : recSig) {
692- fillCandidateTable<true , false >(candidate);
787+ fillCandidateTable<true , false , false >(candidate);
693788 }
694789 } else if (fillOnlyBackground) {
695790 if (fillCandidateLiteTable) {
@@ -702,7 +797,7 @@ struct HfTreeCreatorXicToXiPiPi {
702797 if (candidate.pt () < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) {
703798 continue ;
704799 }
705- fillCandidateTable<true , false >(candidate);
800+ fillCandidateTable<true , false , false >(candidate);
706801 }
707802 } else {
708803 if (fillCandidateLiteTable) {
@@ -711,7 +806,7 @@ struct HfTreeCreatorXicToXiPiPi {
711806 rowCandidateFull.reserve (candidates.size ());
712807 }
713808 for (const auto & candidate : candidates) {
714- fillCandidateTable<true , false >(candidate);
809+ fillCandidateTable<true , false , false >(candidate);
715810 }
716811 }
717812
@@ -743,7 +838,7 @@ struct HfTreeCreatorXicToXiPiPi {
743838 rowCandidateFull.reserve (recSigKf.size ());
744839 }
745840 for (const auto & candidate : recSigKf) {
746- fillCandidateTable<true , true >(candidate);
841+ fillCandidateTable<true , true , false >(candidate);
747842 }
748843 } else if (fillOnlyBackground) {
749844 if (fillCandidateLiteTable) {
@@ -756,7 +851,7 @@ struct HfTreeCreatorXicToXiPiPi {
756851 if (candidate.pt () < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) {
757852 continue ;
758853 }
759- fillCandidateTable<true , true >(candidate);
854+ fillCandidateTable<true , true , false >(candidate);
760855 }
761856 } else {
762857 if (fillCandidateLiteTable) {
@@ -765,7 +860,7 @@ struct HfTreeCreatorXicToXiPiPi {
765860 rowCandidateFull.reserve (candidates.size ());
766861 }
767862 for (const auto & candidate : candidates) {
768- fillCandidateTable<true , true >(candidate);
863+ fillCandidateTable<true , true , false >(candidate);
769864 }
770865 }
771866
@@ -785,6 +880,52 @@ struct HfTreeCreatorXicToXiPiPi {
785880 }
786881 }
787882 PROCESS_SWITCH (HfTreeCreatorXicToXiPiPi, processMcKf, " Process MC with KF Particle reconstruction" , false );
883+
884+ void processMcWithML (SelectedCandidatesMcML const & candidates,
885+ MatchedGenXicToXiPiPi const & particles)
886+ {
887+ // Filling candidate properties
888+ rowCandidateLiteMl.reserve (candidates.size ());
889+ if (fillOnlySignal) {
890+ for (const auto & candidate : candidates) {
891+ if (candidate.flagMcMatchRec () == int8_t (0 )) {
892+ continue ;
893+ }
894+ fillCandidateTable<true , false , true >(candidate);
895+ }
896+ } else if (fillOnlyBackground) {
897+ for (const auto & candidate : candidates) {
898+ if (candidate.flagMcMatchRec () != int8_t (0 )) {
899+ continue ;
900+ }
901+ float const pseudoRndm = candidate.ptProng1 () * 1000 . - static_cast <int64_t >(candidate.ptProng1 () * 1000 );
902+ if (candidate.pt () < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) {
903+ continue ;
904+ }
905+ fillCandidateTable<true , false , true >(candidate);
906+ }
907+ } else {
908+ for (const auto & candidate : candidates) {
909+ fillCandidateTable<true , false , true >(candidate);
910+ }
911+ }
912+
913+ if (fillGenParticleTable) {
914+ rowCandidateFullParticles.reserve (particles.size ());
915+ for (const auto & particle : particles) {
916+ rowCandidateFullParticles (
917+ particle.flagMcMatchGen (),
918+ particle.originMcGen (),
919+ particle.pdgBhadMotherPart (),
920+ particle.pt (),
921+ particle.eta (),
922+ particle.phi (),
923+ RecoDecay::y (particle.pVector (), o2::constants::physics::MassXiCPlus),
924+ particle.decayLengthMcGen ());
925+ }
926+ }
927+ }
928+ PROCESS_SWITCH (HfTreeCreatorXicToXiPiPi, processMcWithML, " Process MC with DCAFitter reconstruction and ML" , false );
788929};
789930
790931WorkflowSpec defineDataProcessing (ConfigContext const & cfgc)
0 commit comments