@@ -414,6 +414,7 @@ class pidTPCModule
414414 network.initModel (pidTPCopts.networkPathLocally .value , pidTPCopts.enableNetworkOptimizations .value , pidTPCopts.networkSetNumThreads .value , strtoul (headers[" Valid-From" ].c_str (), NULL , 0 ), strtoul (headers[" Valid-Until" ].c_str (), NULL , 0 ));
415415 std::vector<float > dummyInput (network.getNumInputNodes (), 1 .);
416416 network.evalModel (dummyInput); // / Init the model evaluations
417+ setupColumnInputNetwork ();
417418 LOGP (info, " Retrieved NN corrections for production tag {}, pass number {}, and NN-Version {}" , headers[" LPMProductionTag" ], headers[" RecoPassName" ], headers[" NN-Version" ]);
418419 } else {
419420 LOG (fatal) << " No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata[" RecoPassName" ] << " . Please ensure that the requested pass has dedicated NN corrections available" ;
@@ -427,6 +428,7 @@ class pidTPCModule
427428 network.initModel (pidTPCopts.networkPathLocally .value , pidTPCopts.enableNetworkOptimizations .value , pidTPCopts.networkSetNumThreads .value );
428429 std::vector<float > dummyInput (network.getNumInputNodes (), 1 .);
429430 network.evalModel (dummyInput); // This is an initialisation and might reduce the overhead of the model
431+ setupColumnInputNetwork ();
430432 }
431433 } else {
432434 return ;
@@ -438,6 +440,22 @@ class pidTPCModule
438440 }
439441 } // end init
440442
443+ // __________________________________________________
444+ void setupColumnInputNetwork ()
445+ {
446+ int nInputs = network.getNumInputNodes ();
447+ std::vector<std::string> colNames;
448+ colNames.reserve (nInputs);
449+ // Column names must match the order used in createNetworkPrediction
450+ const char * baseNames[] = {" tpcInnerParam" , " tgl" , " signed1Pt" , " mass" ,
451+ " multNorm" , " nclsNorm" , " occupancyNorm" ,
452+ " hadronicRateNorm" , " phiMod" };
453+ for (int i = 0 ; i < nInputs; i++) {
454+ colNames.emplace_back (baseNames[i]);
455+ }
456+ network.setupColumnInputs (colNames);
457+ }
458+
441459 // __________________________________________________
442460 template <typename TCCDB, typename M, typename T, typename B>
443461 std::vector<float > createNetworkPrediction (TCCDB& ccdb, soa::Join<aod::Collisions, aod::EvSels> const & collisions, M const & mults, T const & tracks, B const & bcs, const size_t size)
@@ -489,6 +507,7 @@ class pidTPCModule
489507 network.initModel (pidTPCopts.networkPathLocally .value , pidTPCopts.enableNetworkOptimizations .value , pidTPCopts.networkSetNumThreads .value , strtoul (headers[" Valid-From" ].c_str (), NULL , 0 ), strtoul (headers[" Valid-Until" ].c_str (), NULL , 0 ));
490508 std::vector<float > dummyInput (network.getNumInputNodes (), 1 .);
491509 network.evalModel (dummyInput);
510+ setupColumnInputNetwork ();
492511 LOGP (info, " Retrieved NN corrections for production tag {}, pass number {}, NN-Version number {}" , headers[" LPMProductionTag" ], headers[" RecoPassName" ], headers[" NN-Version" ]);
493512 } else {
494513 LOG (fatal) << " No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata[" RecoPassName" ] << " . Please ensure that the requested pass has dedicated NN corrections available" ;
@@ -497,19 +516,14 @@ class pidTPCModule
497516 }
498517
499518 // Defining some network parameters
500- int input_dimensions = network.getNumInputNodes ();
519+ int input_dimensions = network.getNumColumns ();
501520 int output_dimensions = network.getNumOutputNodes ();
502- const uint64_t track_prop_size = input_dimensions * size;
503521 const uint64_t prediction_size = output_dimensions * size;
504522
505523 network_prediction = std::vector<float >(prediction_size * 9 ); // For each mass hypotheses
506524 const float nNclNormalization = response->GetNClNormalization ();
507525 float duration_network = 0 ;
508526
509- std::vector<float > track_properties (track_prop_size);
510- uint64_t counter_track_props = 0 ;
511- int loop_counter = 0 ;
512-
513527 // To load the Hadronic rate once for each collision
514528 float hadronicRateBegin = 0 .;
515529 std::vector<float > hadronicRateForCollision (collisions.size (), 0 .0f );
@@ -530,88 +544,98 @@ class pidTPCModule
530544 hadronicRateBegin = 0 .0f ;
531545 }
532546
533- // Filling a std::vector<float> to be evaluated by the network
534- // Evaluation on single tracks brings huge overhead: Thus evaluation is done on one large vector
547+ // Extract per-column data in a single pass over tracks (instead of 9x)
535548 static constexpr int NParticleTypes = 9 ;
536549 constexpr int ExpectedInputDimensionsNNV2 = 7 ;
537550 constexpr int ExpectedInputDimensionsNNV3 = 8 ;
538551 constexpr int ExpectedInputDimensionsNNV4 = 9 ;
539- constexpr auto NetworkVersionV2 = " 2" ;
540- constexpr auto NetworkVersionV3 = " 3" ;
541- constexpr auto NetworkVersionV4 = " 4" ;
542- for (int j = 0 ; j < NParticleTypes; j++) { // Loop over particle number for which network correction is used
543- for (auto const & trk : tracks) {
544- if (!trk.hasTPC ()) {
552+
553+ const float hadronicRateDivisor = (collsys == CollisionSystemType::kCollSyspp ) ? 1500 .f : 50 .f ;
554+
555+ std::vector<float > colTpcInnerParam, colTgl, colSigned1Pt, colMass;
556+ std::vector<float > colMultNorm, colNclsNorm;
557+ std::vector<float > colOccupancyNorm, colHadronicRateNorm, colPhiMod;
558+ colTpcInnerParam.reserve (size);
559+ colTgl.reserve (size);
560+ colSigned1Pt.reserve (size);
561+ colMultNorm.reserve (size);
562+ colNclsNorm.reserve (size);
563+ if (input_dimensions >= ExpectedInputDimensionsNNV2) {
564+ colOccupancyNorm.reserve (size);
565+ }
566+ if (input_dimensions >= ExpectedInputDimensionsNNV3) {
567+ colHadronicRateNorm.reserve (size);
568+ }
569+ if (input_dimensions >= ExpectedInputDimensionsNNV4) {
570+ colPhiMod.reserve (size);
571+ }
572+
573+ for (auto const & trk : tracks) {
574+ if (!trk.hasTPC ()) {
575+ continue ;
576+ }
577+ if (pidTPCopts.skipTPCOnly ) {
578+ if (!trk.hasITS () && !trk.hasTRD () && !trk.hasTOF ()) {
545579 continue ;
546580 }
547- if (pidTPCopts.skipTPCOnly ) {
548- if (!trk.hasITS () && !trk.hasTRD () && !trk.hasTOF ()) {
549- continue ;
550- }
551- }
552- track_properties[counter_track_props] = trk.tpcInnerParam ();
553- track_properties[counter_track_props + 1 ] = trk.tgl ();
554- track_properties[counter_track_props + 2 ] = trk.signed1Pt ();
555- track_properties[counter_track_props + 3 ] = o2::track::pid_constants::sMasses [j];
556- track_properties[counter_track_props + 4 ] = trk.has_collision () ? mults[trk.collisionId ()] / 11000 . : 1 .;
557- track_properties[counter_track_props + 5 ] = std::sqrt (nNclNormalization / trk.tpcNClsFound ());
558- if (input_dimensions == ExpectedInputDimensionsNNV2 && networkVersion == NetworkVersionV2) {
559- track_properties[counter_track_props + 6 ] = trk.has_collision () ? collisions.iteratorAt (trk.collisionId ()).ft0cOccupancyInTimeRange () / 60000 . : 1 .;
560- }
561- if (input_dimensions == ExpectedInputDimensionsNNV3 && networkVersion == NetworkVersionV3) {
562- track_properties[counter_track_props + 6 ] = trk.has_collision () ? collisions.iteratorAt (trk.collisionId ()).ft0cOccupancyInTimeRange () / 60000 . : 1 .;
563- if (trk.has_collision ()) {
564- if (collsys == CollisionSystemType::kCollSyspp ) {
565- track_properties[counter_track_props + 7 ] = hadronicRateForCollision[trk.collisionId ()] / 1500 .;
566- } else {
567- track_properties[counter_track_props + 7 ] = hadronicRateForCollision[trk.collisionId ()] / 50 .;
568- }
569- } else {
570- // asign Hadronic Rate at beginning of run if track does not belong to a collision
571- if (collsys == CollisionSystemType::kCollSyspp ) {
572- track_properties[counter_track_props + 7 ] = hadronicRateBegin / 1500 .;
573- } else {
574- track_properties[counter_track_props + 7 ] = hadronicRateBegin / 50 .;
575- }
576- }
581+ }
582+ colTpcInnerParam.push_back (trk.tpcInnerParam ());
583+ colTgl.push_back (trk.tgl ());
584+ colSigned1Pt.push_back (trk.signed1Pt ());
585+ colMultNorm.push_back (trk.has_collision () ? mults[trk.collisionId ()] / 11000 .f : 1 .f );
586+ colNclsNorm.push_back (std::sqrt (nNclNormalization / trk.tpcNClsFound ()));
587+ if (input_dimensions >= ExpectedInputDimensionsNNV2) {
588+ colOccupancyNorm.push_back (trk.has_collision () ? collisions.iteratorAt (trk.collisionId ()).ft0cOccupancyInTimeRange () / 60000 .f : 1 .f );
589+ }
590+ if (input_dimensions >= ExpectedInputDimensionsNNV3) {
591+ if (trk.has_collision ()) {
592+ colHadronicRateNorm.push_back (hadronicRateForCollision[trk.collisionId ()] / hadronicRateDivisor);
593+ } else {
594+ colHadronicRateNorm.push_back (hadronicRateBegin / hadronicRateDivisor);
577595 }
596+ }
597+ if (input_dimensions >= ExpectedInputDimensionsNNV4) {
598+ colPhiMod.push_back (std::fmod (std::fmod (trk.phi (), 2 .f * static_cast <float >(M_PI)) + 2 .f * static_cast <float >(M_PI), static_cast <float >(M_PI) / 9 .0f ));
599+ }
600+ }
578601
579- if (input_dimensions == ExpectedInputDimensionsNNV4 && networkVersion == NetworkVersionV4) {
580- track_properties[counter_track_props + 6 ] = trk.has_collision () ? collisions.iteratorAt (trk.collisionId ()).ft0cOccupancyInTimeRange () / 60000 . : 1 .;
581- if (trk.has_collision ()) {
582- if (collsys == CollisionSystemType::kCollSyspp ) {
583- track_properties[counter_track_props + 7 ] = hadronicRateForCollision[trk.collisionId ()] / 1500 .;
584- } else {
585- track_properties[counter_track_props + 7 ] = hadronicRateForCollision[trk.collisionId ()] / 50 .;
586- }
587- } else {
588- // asign Hadronic Rate at beginning of run if track does not belong to a collision
589- if (collsys == CollisionSystemType::kCollSyspp ) {
590- track_properties[counter_track_props + 7 ] = hadronicRateBegin / 1500 .;
591- } else {
592- track_properties[counter_track_props + 7 ] = hadronicRateBegin / 50 .;
593- }
594- }
595- track_properties[counter_track_props + 8 ] = std::fmod (std::fmod (trk.phi (), 2 * M_PI) + 2 * M_PI, M_PI / 9.0 );
596- }
597- counter_track_props += input_dimensions;
602+ const int64_t nValidTracks = static_cast <int64_t >(colTpcInnerParam.size ());
603+ colMass.resize (nValidTracks);
604+ auto memInfo = Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
605+
606+ // Evaluate network once per hypothesis, passing columns as separate tensors
607+ for (int j = 0 ; j < NParticleTypes; j++) {
608+ std::fill (colMass.begin (), colMass.end (), o2::track::pid_constants::sMasses [j]);
609+
610+ // Build column tensors (zero-copy wrapping existing vectors)
611+ std::vector<Ort::Value> inputTensors;
612+ inputTensors.reserve (input_dimensions);
613+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, colTpcInnerParam.data (), nValidTracks, &nValidTracks, 1 ));
614+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, colTgl.data (), nValidTracks, &nValidTracks, 1 ));
615+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, colSigned1Pt.data (), nValidTracks, &nValidTracks, 1 ));
616+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, colMass.data (), nValidTracks, &nValidTracks, 1 ));
617+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, colMultNorm.data (), nValidTracks, &nValidTracks, 1 ));
618+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, colNclsNorm.data (), nValidTracks, &nValidTracks, 1 ));
619+ if (input_dimensions >= ExpectedInputDimensionsNNV2) {
620+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, colOccupancyNorm.data (), nValidTracks, &nValidTracks, 1 ));
621+ }
622+ if (input_dimensions >= ExpectedInputDimensionsNNV3) {
623+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, colHadronicRateNorm.data (), nValidTracks, &nValidTracks, 1 ));
624+ }
625+ if (input_dimensions >= ExpectedInputDimensionsNNV4) {
626+ inputTensors.emplace_back (Ort::Value::CreateTensor<float >(memInfo, colPhiMod.data (), nValidTracks, &nValidTracks, 1 ));
598627 }
599628
600629 auto start_network_eval = std::chrono::high_resolution_clock::now ();
601- float * output_network = network.evalModel (track_properties );
630+ float * output_network = network.evalModel < float >(inputTensors );
602631 auto stop_network_eval = std::chrono::high_resolution_clock::now ();
603632 duration_network += std::chrono::duration<float , std::ratio<1 , 1000000000 >>(stop_network_eval - start_network_eval).count ();
604633 for (uint64_t k = 0 ; k < prediction_size; k += output_dimensions) {
605634 for (int l = 0 ; l < output_dimensions; l++) {
606- network_prediction[k + l + prediction_size * loop_counter ] = output_network[k + l];
635+ network_prediction[k + l + prediction_size * j ] = output_network[k + l];
607636 }
608637 }
609-
610- counter_track_props = 0 ;
611- loop_counter += 1 ;
612638 }
613- track_properties.clear ();
614-
615639 auto stop_network_total = std::chrono::high_resolution_clock::now ();
616640 LOG (debug) << " Neural Network for the TPC PID response correction: Time per track (eval ONNX): " << duration_network / (size * 9 ) << " ns ; Total time (eval ONNX): " << duration_network / 1000000000 << " s" ;
617641 LOG (debug) << " Neural Network for the TPC PID response correction: Time per track (eval + overhead): " << std::chrono::duration<float , std::ratio<1 , 1000000000 >>(stop_network_total - start_network_total).count () / (size * 9 ) << " ns ; Total time (eval + overhead): " << std::chrono::duration<float , std::ratio<1 , 1000000000 >>(stop_network_total - start_network_total).count () / 1000000000 << " s" ;
0 commit comments