3333#include <Framework/InitContext.h>
3434#include <Framework/runDataProcessing.h>
3535
36+ #include <algorithm>
3637#include <cstdint>
3738#include <cstdlib>
39+ #include <vector>
3840
3941using namespace o2;
4042using namespace o2::framework;
@@ -68,8 +70,8 @@ DECLARE_SOA_COLUMN(DecayLength, decayLength, float);
6870DECLARE_SOA_COLUMN(DecayLengthXY, decayLengthXY, float);
6971DECLARE_SOA_COLUMN(DecayLengthNormalised, decayLengthNormalised, float);
7072DECLARE_SOA_COLUMN(DecayLengthXYNormalised, decayLengthXYNormalised, float);
71- DECLARE_SOA_COLUMN(CPA , cpa, float);
72- DECLARE_SOA_COLUMN(CPAXY , cpaXY, float);
73+ DECLARE_SOA_COLUMN(Cpa , cpa, float);
74+ DECLARE_SOA_COLUMN(CpaXY , cpaXY, float);
7375DECLARE_SOA_COLUMN(Ct, ct, float);
7476DECLARE_SOA_COLUMN(PtV0Pos, ptV0Pos, float);
7577DECLARE_SOA_COLUMN(PtV0Neg, ptV0Neg, float);
@@ -84,6 +86,9 @@ DECLARE_SOA_COLUMN(V0CtLambda, v0CtLambda, float);
8486DECLARE_SOA_COLUMN(FlagMc, flagMc, int8_t);
8587DECLARE_SOA_COLUMN(OriginMcRec, originMcRec, int8_t);
8688DECLARE_SOA_COLUMN(OriginMcGen, originMcGen, int8_t);
89+ DECLARE_SOA_COLUMN(MlScoreFirstClass, mlScoreFirstClass, float);
90+ DECLARE_SOA_COLUMN(MlScoreSecondClass, mlScoreSecondClass, float);
91+ DECLARE_SOA_COLUMN(MlScoreThirdClass, mlScoreThirdClass, float);
8792// Events
8893DECLARE_SOA_COLUMN(IsEventReject, isEventReject, int);
8994DECLARE_SOA_COLUMN(RunNumber, runNumber, int);
@@ -118,15 +123,18 @@ DECLARE_SOA_TABLE(HfCandCascLites, "AOD", "HFCANDCASCLITE",
118123 full::NSigmaTOFPr0,
119124 full::M,
120125 full::Pt,
121- full::CPA ,
122- full::CPAXY ,
126+ full::Cpa ,
127+ full::CpaXY ,
123128 full::Ct,
124129 full::Eta,
125130 full::Phi,
126131 full::Y,
127132 full::E,
128133 full::FlagMc,
129- full::OriginMcRec);
134+ full::OriginMcRec,
135+ full::MlScoreFirstClass,
136+ full::MlScoreSecondClass,
137+ full::MlScoreThirdClass);
130138
131139DECLARE_SOA_TABLE(HfCandCascFulls, "AOD", "HFCANDCASCFULL",
132140 collision::BCId,
@@ -188,15 +196,18 @@ DECLARE_SOA_TABLE(HfCandCascFulls, "AOD", "HFCANDCASCFULL",
188196 full::M,
189197 full::Pt,
190198 full::P,
191- full::CPA ,
192- full::CPAXY ,
199+ full::Cpa ,
200+ full::CpaXY ,
193201 full::Ct,
194202 full::Eta,
195203 full::Phi,
196204 full::Y,
197205 full::E,
198206 full::FlagMc,
199- full::OriginMcRec);
207+ full::OriginMcRec,
208+ full::MlScoreFirstClass,
209+ full::MlScoreSecondClass,
210+ full::MlScoreThirdClass);
200211
201212DECLARE_SOA_TABLE(HfCandCascFullEs, "AOD", "HFCANDCASCFULLE",
202213 collision::BCId,
@@ -228,23 +239,56 @@ struct HfTreeCreatorLcToK0sP {
228239 Configurable<float> ptMaxForDownSample{"ptMaxForDownSample", 24., "Maximum pt for the application of the downsampling factor"};
229240 Configurable<bool> fillOnlySignal{"fillOnlySignal", false, "Flag to fill derived tables with signal for ML trainings"};
230241 Configurable<bool> fillOnlyBackground{"fillOnlyBackground", false, "Flag to fill derived tables with background for ML trainings"};
242+ Configurable<bool> applyMl{"applyMl", false, "Whether ML was used in candidateSelectorLc"};
243+
244+ constexpr static float UndefValueFloat = -999.f;
231245
232246 HfHelper hfHelper;
233247
234- Filter filterSelectCandidates = aod::hf_sel_candidate_lc_to_k0s_p::isSelLcToK0sP >= 1;
235248 using TracksWPid = soa::Join<aod::Tracks, aod::TracksPidPr>;
236249 using SelectedCandidatesMc = soa::Filtered<soa::Join<aod::HfCandCascade, aod::HfCandCascadeMcRec, aod::HfSelLcToK0sP>>;
237-
238- Partition<SelectedCandidatesMc> recSig = nabs(aod::hf_cand_casc::flagMcMatchRec) != int8_t(0);
239- Partition<SelectedCandidatesMc> recBkg = nabs(aod::hf_cand_casc::flagMcMatchRec) == int8_t(0);
250+ Filter filterSelectCandidates = aod::hf_sel_candidate_lc_to_k0s_p::isSelLcToK0sP >= 1;
240251
241252 void init(InitContext const&)
242253 {
243254 }
244255
256+ /// \brief function to get ML score values for the current candidate and assign them to input parameters
257+ /// \param candidate candidate instance
258+ /// \param candidateMlScore instance of handler of vectors with ML scores associated with the current candidate
259+ /// \param mlScoreFirstClass ML score for belonging to the first class
260+ /// \param mlScoreSecondClass ML score for belonging to the second class
261+ /// \param mlScoreThirdClass ML score for belonging to the third class
262+ void assignMlScores(aod::HfMlLcToK0sP::iterator const& candidateMlScore, float& mlScoreFirstClass, float& mlScoreSecondClass, float& mlScoreThirdClass)
263+ {
264+ std::vector<float> mlScores;
265+ std::copy(candidateMlScore.mlProbLcToK0sP().begin(), candidateMlScore.mlProbLcToK0sP().end(), std::back_inserter(mlScores));
266+
267+ constexpr int IndexFirstClass{0};
268+ constexpr int IndexSecondClass{1};
269+ constexpr int IndexThirdClass{2};
270+ if (mlScores.size() == 0) {
271+ return; // when candidateSelectorLcK0sP rejects a candidate by "usual", non-ML cut, the ml score vector remains empty
272+ }
273+ mlScoreFirstClass = mlScores.at(IndexFirstClass);
274+ mlScoreSecondClass = mlScores.at(IndexSecondClass);
275+ if (mlScores.size() > IndexThirdClass) {
276+ mlScoreThirdClass = mlScores.at(IndexThirdClass);
277+ }
278+ }
279+
245280 template <typename T, typename U>
246- void fillCandidate(const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec)
281+ void fillCandidate(const T& candidate, const U& bach, int8_t flagMc, int8_t originMcRec, aod::HfMlLcToK0sP::iterator const& candidateMlScore )
247282 {
283+
284+ float mlScoreFirstClass{UndefValueFloat};
285+ float mlScoreSecondClass{UndefValueFloat};
286+ float mlScoreThirdClass{UndefValueFloat};
287+
288+ if (applyMl) {
289+ assignMlScores(candidateMlScore, mlScoreFirstClass, mlScoreSecondClass, mlScoreThirdClass);
290+ }
291+
248292 if (fillCandidateLiteTable) {
249293 rowCandidateLite(
250294 candidate.chi2PCA(),
@@ -283,7 +327,10 @@ struct HfTreeCreatorLcToK0sP {
283327 hfHelper.yLc(candidate),
284328 hfHelper.eLc(candidate),
285329 flagMc,
286- originMcRec);
330+ originMcRec,
331+ mlScoreFirstClass,
332+ mlScoreSecondClass,
333+ mlScoreThirdClass);
287334 } else {
288335 rowCandidateFull(
289336 bach.collision().bcId(),
@@ -353,7 +400,10 @@ struct HfTreeCreatorLcToK0sP {
353400 hfHelper.yLc(candidate),
354401 hfHelper.eLc(candidate),
355402 flagMc,
356- originMcRec);
403+ originMcRec,
404+ mlScoreFirstClass,
405+ mlScoreSecondClass,
406+ mlScoreThirdClass);
357407 }
358408 }
359409 template <typename T>
@@ -370,52 +420,41 @@ struct HfTreeCreatorLcToK0sP {
370420 void processMc(aod::Collisions const& collisions,
371421 aod::McCollisions const&,
372422 SelectedCandidatesMc const& candidates,
423+ aod::HfMlLcToK0sP const& candidateMlScores,
373424 soa::Join<aod::McParticles, aod::HfCandCascadeMcGen> const& particles,
374425 TracksWPid const&)
375426 {
376427
428+ if (applyMl && candidateMlScores.size() == 0) {
429+ LOG(fatal) << "ML enabled but table with the ML scores is empty! Please check your configurables.";
430+ return;
431+ }
432+
377433 // Filling event properties
378434 rowCandidateFullEvents.reserve(collisions.size());
379435 for (const auto& collision : collisions) {
380436 fillEvent(collision);
381437 }
382438
383- if (fillOnlySignal) {
384- if (fillCandidateLiteTable) {
385- rowCandidateLite.reserve(recSig.size());
386- } else {
387- rowCandidateFull.reserve(recSig.size());
388- }
389- for (const auto& candidate : recSig) {
390- auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
391- fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec());
392- }
393- } else if (fillOnlyBackground) {
394- if (fillCandidateLiteTable) {
395- rowCandidateLite.reserve(recBkg.size());
396- } else {
397- rowCandidateFull.reserve(recBkg.size());
398- }
399- for (const auto& candidate : recBkg) {
400- if (downSampleBkgFactor < 1.) {
401- float pseudoRndm = candidate.ptProng0() * 1000. - static_cast<int64_t>(candidate.ptProng0() * 1000);
402- if (candidate.pt() < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) {
403- continue;
404- }
405- }
406- auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
407- fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec());
408- }
439+ if (fillCandidateLiteTable) {
440+ rowCandidateLite.reserve(candidates.size());
409441 } else {
410- // Filling candidate properties
411- if (fillCandidateLiteTable) {
412- rowCandidateLite.reserve(candidates.size());
442+ rowCandidateFull.reserve(candidates.size());
443+ }
444+
445+ int iCand{0};
446+ for (const auto& candidate : candidates) {
447+ auto candidateMlScore = candidateMlScores.rawIteratorAt(iCand);
448+ ++iCand;
449+ auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
450+ const int flag = candidate.flagMcMatchRec();
451+
452+ if (fillOnlySignal && flag != 0) {
453+ fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore);
454+ } else if (fillOnlyBackground && flag == 0) {
455+ fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore);
413456 } else {
414- rowCandidateFull.reserve(candidates.size());
415- }
416- for (const auto& candidate : candidates) {
417- auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
418- fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec());
457+ fillCandidate(candidate, bach, candidate.flagMcMatchRec(), candidate.originMcRec(), candidateMlScore);
419458 }
420459 }
421460
@@ -439,9 +478,15 @@ struct HfTreeCreatorLcToK0sP {
439478
440479 void processData(aod::Collisions const& collisions,
441480 soa::Join<aod::HfCandCascade, aod::HfSelLcToK0sP> const& candidates,
481+ aod::HfMlLcToK0sP const& candidateMlScores,
442482 TracksWPid const&)
443483 {
444484
485+ if (applyMl && candidateMlScores.size() == 0) {
486+ LOG(fatal) << "ML enabled but table with the ML scores is empty! Please check your configurables.";
487+ return;
488+ }
489+
445490 // Filling event properties
446491 rowCandidateFullEvents.reserve(collisions.size());
447492 for (const auto& collision : collisions) {
@@ -454,11 +499,15 @@ struct HfTreeCreatorLcToK0sP {
454499 } else {
455500 rowCandidateFull.reserve(candidates.size());
456501 }
502+
503+ int iCand{0};
457504 for (const auto& candidate : candidates) {
505+ auto candidateMlScore = candidateMlScores.rawIteratorAt(iCand);
506+ ++iCand;
458507 auto bach = candidate.prong0_as<TracksWPid>(); // bachelor
459508 double pseudoRndm = bach.pt() * 1000. - static_cast<int16_t>(bach.pt() * 1000);
460509 if (candidate.isSelLcToK0sP() >= 1 && pseudoRndm < downSampleBkgFactor) {
461- fillCandidate(candidate, bach, 0, 0);
510+ fillCandidate(candidate, bach, 0, 0, candidateMlScore );
462511 }
463512 }
464513 }
0 commit comments