diff --git a/bindings/pyroot/pythonizations/test/ml_dataloader.py b/bindings/pyroot/pythonizations/test/ml_dataloader.py index 491bcaf8ae5aa..748d3517c5d73 100644 --- a/bindings/pyroot/pythonizations/test/ml_dataloader.py +++ b/bindings/pyroot/pythonizations/test/ml_dataloader.py @@ -4570,6 +4570,44 @@ def test16_vector_padding(self): self.teardown_file(self.file_name5) raise + def test17_shuffled_split_varies_with_seed(self): + self.create_file1() + self.create_file2() + + try: + df1 = ROOT.RDataFrame(self.tree_name, self.file_name1) + df2 = ROOT.RDataFrame(self.tree_name, self.file_name2) + + dl1 = ROOT.Experimental.ML.RDataLoader( + [df1, df2], + batch_size=3, + target="b2", + shuffle=True, + drop_remainder=False, + set_seed=42, + ) + + dl2 = ROOT.Experimental.ML.RDataLoader( + [df1, df2], + batch_size=3, + target="b2", + shuffle=True, + drop_remainder=False, + set_seed=43, + ) + + _, gen_val1 = dl1.train_test_split(0.4) + _, gen_val2 = dl2.train_test_split(0.4) + + val_1_collected = sorted([v for x, y in gen_val1.as_numpy() for v in x.flatten().tolist()]) + val_2_collected = sorted([v for x, y in gen_val2.as_numpy() for v in x.flatten().tolist()]) + + self.assertNotEqual(val_1_collected, val_2_collected) + + finally: + self.teardown_file(self.file_name1) + self.teardown_file(self.file_name2) + class DataLoaderRandomUndersampling(unittest.TestCase): file_name1 = "major.root" diff --git a/tree/ml/inc/ROOT/ML/RClusterLoader.hxx b/tree/ml/inc/ROOT/ML/RClusterLoader.hxx index fabad80b3be33..090b097923c1a 100644 --- a/tree/ml/inc/ROOT/ML/RClusterLoader.hxx +++ b/tree/ml/inc/ROOT/ML/RClusterLoader.hxx @@ -227,18 +227,27 @@ public: // --- Shuffled path // Every cluster contributes a prefix to training and a suffix to validation. // Cost: Each cluster is read twice per epoch, only when validation split is more than 0. - // TODO(staider) Swicth between prefix or suffix for validation randomly per cluster + // We generate a random boolean value to decide whether the training set gets the prefix + // or suffix of each cluster to ensure better shuffling across runs when splitting. + std::mt19937 g(fSetSeed); + std::uniform_int_distribution coin(0, 1); + for (const RClusterRange &c : fAllClusters) { const std::size_t sz = c.GetNumEntries(); const std::size_t trainSz = static_cast((1.0f - fValidationSplit) * sz); const std::size_t valSz = sz - trainSz; + // Randomly assign prefix or suffix to training + bool trainIsPrefix = coin(g); + const uint64_t trainStart = trainIsPrefix ? c.start : c.start + static_cast(valSz); + const uint64_t valStart = trainIsPrefix ? c.start + static_cast(trainSz) : c.start; + if (trainSz > 0) { - fTrainingClusters.push_back({c.rdfIdx, c.start, c.start + static_cast(trainSz)}); + fTrainingClusters.push_back({c.rdfIdx, trainStart, trainStart + static_cast(trainSz)}); fNumTrainingEntries += trainSz; } if (valSz > 0) { - fValidationClusters.push_back({c.rdfIdx, c.start + static_cast(trainSz), c.end}); + fValidationClusters.push_back({c.rdfIdx, valStart, valStart + static_cast(valSz)}); fNumValidationEntries += valSz; } } @@ -392,14 +401,29 @@ public: std::min(static_cast(totalFiltered * (1.0f - fValidationSplit)), trainRemaining); const std::size_t valCount = totalFiltered - trainCount; + bool trainIsPrefix = true; + if (fShuffle) { + // If shuffling is enabled, we generate a random boolean value to decide whether the training set + // gets the prefix or suffix of each cluster to ensure better shuffling across runs when splitting. + std::mt19937 g(fSetSeed + fAccumulatedFilteredForTrain); // vary per cluster + std::uniform_int_distribution coin(0, 1); + trainIsPrefix = coin(g); + } + // The boundary is the raw entry index of the first entry assigned to validation. // Stable across epochs since the same filter always produces the same ordered entries. - const std::uint64_t boundary = (valCount > 0) ? rdfEntries[trainCount] : endRow; + const std::uint64_t trainBoundaryEntry = trainIsPrefix ? rdfEntries[trainCount] : rdfEntries[valCount]; + const std::uint64_t boundary = (valCount > 0) ? trainBoundaryEntry : endRow; + + const std::uint64_t trainStart = trainIsPrefix ? startRow : boundary; + const std::uint64_t trainEnd = trainIsPrefix ? boundary : endRow; + const std::uint64_t valStart = trainIsPrefix ? boundary : startRow; + const std::uint64_t valEnd = trainIsPrefix ? endRow : boundary; if (trainCount > 0) - fTrainingClusters.push_back({rdfIdx, startRow, boundary, trainCount}); + fTrainingClusters.push_back({rdfIdx, trainStart, trainEnd, trainCount}); if (valCount > 0) - fValidationClusters.push_back({rdfIdx, boundary, endRow, valCount}); + fValidationClusters.push_back({rdfIdx, valStart, valEnd, valCount}); fAccumulatedFilteredForTrain += trainCount; return trainCount;