Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions bindings/pyroot/pythonizations/test/ml_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4570,6 +4570,44 @@ def test16_vector_padding(self):
self.teardown_file(self.file_name5)
raise

def test17_shuffled_split_varies_with_seed(self):
    """Check that the shuffled validation split depends on the seed.

    Two loaders are built over the same pair of dataframes with identical
    settings except for ``set_seed`` (42 and 43). Each loader is split with
    ``train_test_split(0.4)``, and the flattened, sorted feature values
    yielded by the two validation generators are asserted to be unequal.
    The files created for the test are removed in all cases.
    """
    self.create_file1()
    self.create_file2()

    def build_loader(frames, seed):
        # Only the seed differs between the two loaders under test.
        return ROOT.Experimental.ML.RDataLoader(
            frames,
            batch_size=3,
            target="b2",
            shuffle=True,
            drop_remainder=False,
            set_seed=seed,
        )

    def sorted_features(generator):
        # Flatten the feature part of every batch into one sorted list;
        # the target part of each batch is ignored.
        values = []
        for features, _target in generator.as_numpy():
            values.extend(features.flatten().tolist())
        values.sort()
        return values

    try:
        frame_a = ROOT.RDataFrame(self.tree_name, self.file_name1)
        frame_b = ROOT.RDataFrame(self.tree_name, self.file_name2)

        loader_seed_42 = build_loader([frame_a, frame_b], 42)
        loader_seed_43 = build_loader([frame_a, frame_b], 43)

        _, validation_seed_42 = loader_seed_42.train_test_split(0.4)
        _, validation_seed_43 = loader_seed_43.train_test_split(0.4)

        collected_seed_42 = sorted_features(validation_seed_42)
        collected_seed_43 = sorted_features(validation_seed_43)

        self.assertNotEqual(collected_seed_42, collected_seed_43)

    finally:
        self.teardown_file(self.file_name1)
        self.teardown_file(self.file_name2)


class DataLoaderRandomUndersampling(unittest.TestCase):
file_name1 = "major.root"
Expand Down
36 changes: 30 additions & 6 deletions tree/ml/inc/ROOT/ML/RClusterLoader.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,27 @@ public:
// --- Shuffled path
// Every cluster contributes a prefix to training and a suffix to validation.
// Cost: Each cluster is read twice per epoch, only when validation split is more than 0.
// TODO(staider) Switch between prefix or suffix for validation randomly per cluster
// We generate a random boolean value to decide whether the training set gets the prefix
// or suffix of each cluster to ensure better shuffling across runs when splitting.
std::mt19937 g(fSetSeed);
std::uniform_int_distribution<int> coin(0, 1);

for (const RClusterRange &c : fAllClusters) {
const std::size_t sz = c.GetNumEntries();
const std::size_t trainSz = static_cast<std::size_t>((1.0f - fValidationSplit) * sz);
const std::size_t valSz = sz - trainSz;

// Randomly assign prefix or suffix to training
bool trainIsPrefix = coin(g);
const uint64_t trainStart = trainIsPrefix ? c.start : c.start + static_cast<std::uint64_t>(valSz);
const uint64_t valStart = trainIsPrefix ? c.start + static_cast<std::uint64_t>(trainSz) : c.start;

if (trainSz > 0) {
fTrainingClusters.push_back({c.rdfIdx, c.start, c.start + static_cast<std::uint64_t>(trainSz)});
fTrainingClusters.push_back({c.rdfIdx, trainStart, trainStart + static_cast<std::uint64_t>(trainSz)});
fNumTrainingEntries += trainSz;
}
if (valSz > 0) {
fValidationClusters.push_back({c.rdfIdx, c.start + static_cast<std::uint64_t>(trainSz), c.end});
fValidationClusters.push_back({c.rdfIdx, valStart, valStart + static_cast<std::uint64_t>(valSz)});
fNumValidationEntries += valSz;
}
}
Expand Down Expand Up @@ -392,14 +401,29 @@ public:
std::min(static_cast<std::size_t>(totalFiltered * (1.0f - fValidationSplit)), trainRemaining);
const std::size_t valCount = totalFiltered - trainCount;

bool trainIsPrefix = true;
if (fShuffle) {
// If shuffling is enabled, we generate a random boolean value to decide whether the training set
// gets the prefix or suffix of each cluster to ensure better shuffling across runs when splitting.
std::mt19937 g(fSetSeed + fAccumulatedFilteredForTrain); // vary per cluster
std::uniform_int_distribution<int> coin(0, 1);
trainIsPrefix = coin(g);
}

// The boundary is the raw entry index of the first entry assigned to validation.
// Stable across epochs since the same filter always produces the same ordered entries.
const std::uint64_t boundary = (valCount > 0) ? rdfEntries[trainCount] : endRow;
const std::uint64_t trainBoundaryEntry = trainIsPrefix ? rdfEntries[trainCount] : rdfEntries[valCount];
const std::uint64_t boundary = (valCount > 0) ? trainBoundaryEntry : endRow;

const std::uint64_t trainStart = trainIsPrefix ? startRow : boundary;
const std::uint64_t trainEnd = trainIsPrefix ? boundary : endRow;
const std::uint64_t valStart = trainIsPrefix ? boundary : startRow;
const std::uint64_t valEnd = trainIsPrefix ? endRow : boundary;

if (trainCount > 0)
fTrainingClusters.push_back({rdfIdx, startRow, boundary, trainCount});
fTrainingClusters.push_back({rdfIdx, trainStart, trainEnd, trainCount});
if (valCount > 0)
fValidationClusters.push_back({rdfIdx, boundary, endRow, valCount});
fValidationClusters.push_back({rdfIdx, valStart, valEnd, valCount});

fAccumulatedFilteredForTrain += trainCount;
return trainCount;
Expand Down
Loading