Skip to content

Commit 44cf89d

Browse files
committed
Avoid having intermediary vectors for the NN-based TPC PID
1 parent 0bba10f commit 44cf89d

4 files changed

Lines changed: 406 additions & 71 deletions

File tree

Common/Tools/PID/pidTPCModule.h

Lines changed: 94 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ class pidTPCModule
414414
network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value, strtoul(headers["Valid-From"].c_str(), NULL, 0), strtoul(headers["Valid-Until"].c_str(), NULL, 0));
415415
std::vector<float> dummyInput(network.getNumInputNodes(), 1.);
416416
network.evalModel(dummyInput); /// Init the model evaluations
417+
setupColumnInputNetwork();
417418
LOGP(info, "Retrieved NN corrections for production tag {}, pass number {}, and NN-Version {}", headers["LPMProductionTag"], headers["RecoPassName"], headers["NN-Version"]);
418419
} else {
419420
LOG(fatal) << "No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata["RecoPassName"] << ". Please ensure that the requested pass has dedicated NN corrections available";
@@ -427,6 +428,7 @@ class pidTPCModule
427428
network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value);
428429
std::vector<float> dummyInput(network.getNumInputNodes(), 1.);
429430
network.evalModel(dummyInput); // This is an initialisation and might reduce the overhead of the model
431+
setupColumnInputNetwork();
430432
}
431433
} else {
432434
return;
@@ -438,6 +440,22 @@ class pidTPCModule
438440
}
439441
} // end init
440442

443+
//__________________________________________________
444+
void setupColumnInputNetwork()
445+
{
446+
int nInputs = network.getNumInputNodes();
447+
std::vector<std::string> colNames;
448+
colNames.reserve(nInputs);
449+
// Column names must match the order used in createNetworkPrediction
450+
const char* baseNames[] = {"tpcInnerParam", "tgl", "signed1Pt", "mass",
451+
"multNorm", "nclsNorm", "occupancyNorm",
452+
"hadronicRateNorm", "phiMod"};
453+
for (int i = 0; i < nInputs; i++) {
454+
colNames.emplace_back(baseNames[i]);
455+
}
456+
network.setupColumnInputs(colNames);
457+
}
458+
441459
//__________________________________________________
442460
template <typename TCCDB, typename M, typename T, typename B>
443461
std::vector<float> createNetworkPrediction(TCCDB& ccdb, soa::Join<aod::Collisions, aod::EvSels> const& collisions, M const& mults, T const& tracks, B const& bcs, const size_t size)
@@ -489,6 +507,7 @@ class pidTPCModule
489507
network.initModel(pidTPCopts.networkPathLocally.value, pidTPCopts.enableNetworkOptimizations.value, pidTPCopts.networkSetNumThreads.value, strtoul(headers["Valid-From"].c_str(), NULL, 0), strtoul(headers["Valid-Until"].c_str(), NULL, 0));
490508
std::vector<float> dummyInput(network.getNumInputNodes(), 1.);
491509
network.evalModel(dummyInput);
510+
setupColumnInputNetwork();
492511
LOGP(info, "Retrieved NN corrections for production tag {}, pass number {}, NN-Version number {}", headers["LPMProductionTag"], headers["RecoPassName"], headers["NN-Version"]);
493512
} else {
494513
LOG(fatal) << "No valid NN object found matching retrieved Bethe-Bloch parametrisation for pass " << metadata["RecoPassName"] << ". Please ensure that the requested pass has dedicated NN corrections available";
@@ -497,19 +516,14 @@ class pidTPCModule
497516
}
498517

499518
// Defining some network parameters
500-
int input_dimensions = network.getNumInputNodes();
519+
int input_dimensions = network.getNumColumns();
501520
int output_dimensions = network.getNumOutputNodes();
502-
const uint64_t track_prop_size = input_dimensions * size;
503521
const uint64_t prediction_size = output_dimensions * size;
504522

505523
network_prediction = std::vector<float>(prediction_size * 9); // For each mass hypotheses
506524
const float nNclNormalization = response->GetNClNormalization();
507525
float duration_network = 0;
508526

509-
std::vector<float> track_properties(track_prop_size);
510-
uint64_t counter_track_props = 0;
511-
int loop_counter = 0;
512-
513527
// To load the Hadronic rate once for each collision
514528
float hadronicRateBegin = 0.;
515529
std::vector<float> hadronicRateForCollision(collisions.size(), 0.0f);
@@ -530,88 +544,98 @@ class pidTPCModule
530544
hadronicRateBegin = 0.0f;
531545
}
532546

533-
// Filling a std::vector<float> to be evaluated by the network
534-
// Evaluation on single tracks brings huge overhead: Thus evaluation is done on one large vector
547+
// Extract per-column data in a single pass over tracks (instead of 9x)
535548
static constexpr int NParticleTypes = 9;
536549
constexpr int ExpectedInputDimensionsNNV2 = 7;
537550
constexpr int ExpectedInputDimensionsNNV3 = 8;
538551
constexpr int ExpectedInputDimensionsNNV4 = 9;
539-
constexpr auto NetworkVersionV2 = "2";
540-
constexpr auto NetworkVersionV3 = "3";
541-
constexpr auto NetworkVersionV4 = "4";
542-
for (int j = 0; j < NParticleTypes; j++) { // Loop over particle number for which network correction is used
543-
for (auto const& trk : tracks) {
544-
if (!trk.hasTPC()) {
552+
553+
const float hadronicRateDivisor = (collsys == CollisionSystemType::kCollSyspp) ? 1500.f : 50.f;
554+
555+
std::vector<float> colTpcInnerParam, colTgl, colSigned1Pt, colMass;
556+
std::vector<float> colMultNorm, colNclsNorm;
557+
std::vector<float> colOccupancyNorm, colHadronicRateNorm, colPhiMod;
558+
colTpcInnerParam.reserve(size);
559+
colTgl.reserve(size);
560+
colSigned1Pt.reserve(size);
561+
colMultNorm.reserve(size);
562+
colNclsNorm.reserve(size);
563+
if (input_dimensions >= ExpectedInputDimensionsNNV2) {
564+
colOccupancyNorm.reserve(size);
565+
}
566+
if (input_dimensions >= ExpectedInputDimensionsNNV3) {
567+
colHadronicRateNorm.reserve(size);
568+
}
569+
if (input_dimensions >= ExpectedInputDimensionsNNV4) {
570+
colPhiMod.reserve(size);
571+
}
572+
573+
for (auto const& trk : tracks) {
574+
if (!trk.hasTPC()) {
575+
continue;
576+
}
577+
if (pidTPCopts.skipTPCOnly) {
578+
if (!trk.hasITS() && !trk.hasTRD() && !trk.hasTOF()) {
545579
continue;
546580
}
547-
if (pidTPCopts.skipTPCOnly) {
548-
if (!trk.hasITS() && !trk.hasTRD() && !trk.hasTOF()) {
549-
continue;
550-
}
551-
}
552-
track_properties[counter_track_props] = trk.tpcInnerParam();
553-
track_properties[counter_track_props + 1] = trk.tgl();
554-
track_properties[counter_track_props + 2] = trk.signed1Pt();
555-
track_properties[counter_track_props + 3] = o2::track::pid_constants::sMasses[j];
556-
track_properties[counter_track_props + 4] = trk.has_collision() ? mults[trk.collisionId()] / 11000. : 1.;
557-
track_properties[counter_track_props + 5] = std::sqrt(nNclNormalization / trk.tpcNClsFound());
558-
if (input_dimensions == ExpectedInputDimensionsNNV2 && networkVersion == NetworkVersionV2) {
559-
track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.;
560-
}
561-
if (input_dimensions == ExpectedInputDimensionsNNV3 && networkVersion == NetworkVersionV3) {
562-
track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.;
563-
if (trk.has_collision()) {
564-
if (collsys == CollisionSystemType::kCollSyspp) {
565-
track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 1500.;
566-
} else {
567-
track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 50.;
568-
}
569-
} else {
570-
// asign Hadronic Rate at beginning of run if track does not belong to a collision
571-
if (collsys == CollisionSystemType::kCollSyspp) {
572-
track_properties[counter_track_props + 7] = hadronicRateBegin / 1500.;
573-
} else {
574-
track_properties[counter_track_props + 7] = hadronicRateBegin / 50.;
575-
}
576-
}
581+
}
582+
colTpcInnerParam.push_back(trk.tpcInnerParam());
583+
colTgl.push_back(trk.tgl());
584+
colSigned1Pt.push_back(trk.signed1Pt());
585+
colMultNorm.push_back(trk.has_collision() ? mults[trk.collisionId()] / 11000.f : 1.f);
586+
colNclsNorm.push_back(std::sqrt(nNclNormalization / trk.tpcNClsFound()));
587+
if (input_dimensions >= ExpectedInputDimensionsNNV2) {
588+
colOccupancyNorm.push_back(trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000.f : 1.f);
589+
}
590+
if (input_dimensions >= ExpectedInputDimensionsNNV3) {
591+
if (trk.has_collision()) {
592+
colHadronicRateNorm.push_back(hadronicRateForCollision[trk.collisionId()] / hadronicRateDivisor);
593+
} else {
594+
colHadronicRateNorm.push_back(hadronicRateBegin / hadronicRateDivisor);
577595
}
596+
}
597+
if (input_dimensions >= ExpectedInputDimensionsNNV4) {
598+
colPhiMod.push_back(std::fmod(std::fmod(trk.phi(), 2.f * static_cast<float>(M_PI)) + 2.f * static_cast<float>(M_PI), static_cast<float>(M_PI) / 9.0f));
599+
}
600+
}
578601

579-
if (input_dimensions == ExpectedInputDimensionsNNV4 && networkVersion == NetworkVersionV4) {
580-
track_properties[counter_track_props + 6] = trk.has_collision() ? collisions.iteratorAt(trk.collisionId()).ft0cOccupancyInTimeRange() / 60000. : 1.;
581-
if (trk.has_collision()) {
582-
if (collsys == CollisionSystemType::kCollSyspp) {
583-
track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 1500.;
584-
} else {
585-
track_properties[counter_track_props + 7] = hadronicRateForCollision[trk.collisionId()] / 50.;
586-
}
587-
} else {
588-
// asign Hadronic Rate at beginning of run if track does not belong to a collision
589-
if (collsys == CollisionSystemType::kCollSyspp) {
590-
track_properties[counter_track_props + 7] = hadronicRateBegin / 1500.;
591-
} else {
592-
track_properties[counter_track_props + 7] = hadronicRateBegin / 50.;
593-
}
594-
}
595-
track_properties[counter_track_props + 8] = std::fmod(std::fmod(trk.phi(), 2 * M_PI) + 2 * M_PI, M_PI / 9.0);
596-
}
597-
counter_track_props += input_dimensions;
602+
const int64_t nValidTracks = static_cast<int64_t>(colTpcInnerParam.size());
603+
colMass.resize(nValidTracks);
604+
auto memInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
605+
606+
// Evaluate network once per hypothesis, passing columns as separate tensors
607+
for (int j = 0; j < NParticleTypes; j++) {
608+
std::fill(colMass.begin(), colMass.end(), o2::track::pid_constants::sMasses[j]);
609+
610+
// Build column tensors (zero-copy wrapping existing vectors)
611+
std::vector<Ort::Value> inputTensors;
612+
inputTensors.reserve(input_dimensions);
613+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, colTpcInnerParam.data(), nValidTracks, &nValidTracks, 1));
614+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, colTgl.data(), nValidTracks, &nValidTracks, 1));
615+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, colSigned1Pt.data(), nValidTracks, &nValidTracks, 1));
616+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, colMass.data(), nValidTracks, &nValidTracks, 1));
617+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, colMultNorm.data(), nValidTracks, &nValidTracks, 1));
618+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, colNclsNorm.data(), nValidTracks, &nValidTracks, 1));
619+
if (input_dimensions >= ExpectedInputDimensionsNNV2) {
620+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, colOccupancyNorm.data(), nValidTracks, &nValidTracks, 1));
621+
}
622+
if (input_dimensions >= ExpectedInputDimensionsNNV3) {
623+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, colHadronicRateNorm.data(), nValidTracks, &nValidTracks, 1));
624+
}
625+
if (input_dimensions >= ExpectedInputDimensionsNNV4) {
626+
inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memInfo, colPhiMod.data(), nValidTracks, &nValidTracks, 1));
598627
}
599628

600629
auto start_network_eval = std::chrono::high_resolution_clock::now();
601-
float* output_network = network.evalModel(track_properties);
630+
float* output_network = network.evalModel<float>(inputTensors);
602631
auto stop_network_eval = std::chrono::high_resolution_clock::now();
603632
duration_network += std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_eval - start_network_eval).count();
604633
for (uint64_t k = 0; k < prediction_size; k += output_dimensions) {
605634
for (int l = 0; l < output_dimensions; l++) {
606-
network_prediction[k + l + prediction_size * loop_counter] = output_network[k + l];
635+
network_prediction[k + l + prediction_size * j] = output_network[k + l];
607636
}
608637
}
609-
610-
counter_track_props = 0;
611-
loop_counter += 1;
612638
}
613-
track_properties.clear();
614-
615639
auto stop_network_total = std::chrono::high_resolution_clock::now();
616640
LOG(debug) << "Neural Network for the TPC PID response correction: Time per track (eval ONNX): " << duration_network / (size * 9) << "ns ; Total time (eval ONNX): " << duration_network / 1000000000 << " s";
617641
LOG(debug) << "Neural Network for the TPC PID response correction: Time per track (eval + overhead): " << std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_total - start_network_total).count() / (size * 9) << "ns ; Total time (eval + overhead): " << std::chrono::duration<float, std::ratio<1, 1000000000>>(stop_network_total - start_network_total).count() / 1000000000 << " s";

Tools/ML/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@
1111

1212
o2physics_add_library(MLCore
1313
SOURCES model.cxx
14-
PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore ONNXRuntime::ONNXRuntime
14+
PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore ONNXRuntime::ONNXRuntime ONNX::onnx_proto
1515
)

0 commit comments

Comments
 (0)