From 1525347d973d964e8825dd5ac477d50362d842d9 Mon Sep 17 00:00:00 2001
From: Luugaaa
Date: Fri, 18 Jul 2025 10:46:14 -0400
Subject: [PATCH 1/4] deterministic fix
---
.../transforms/intensity/brightness.py | 3 ++-
.../transforms/intensity/contrast.py | 3 ++-
.../transforms/intensity/gamma.py | 9 +++++---
.../transforms/intensity/gaussian_noise.py | 17 +++++++-------
.../transforms/intensity/inversion.py | 3 ++-
.../nnunet/random_binary_operator.py | 2 +-
.../nnunet/remove_connected_components.py | 2 +-
.../transforms/noise/gaussian_blur.py | 2 +-
batchgeneratorsv2/transforms/noise/rician.py | 8 ++++---
.../transforms/spatial/low_resolution.py | 6 +++--
.../transforms/spatial/mirroring.py | 2 +-
.../transforms/spatial/spatial.py | 23 ++++++-------------
batchgeneratorsv2/transforms/utils/random.py | 2 +-
13 files changed, 41 insertions(+), 41 deletions(-)
diff --git a/batchgeneratorsv2/transforms/intensity/brightness.py b/batchgeneratorsv2/transforms/intensity/brightness.py
index 48521ff..668a79f 100644
--- a/batchgeneratorsv2/transforms/intensity/brightness.py
+++ b/batchgeneratorsv2/transforms/intensity/brightness.py
@@ -14,7 +14,8 @@ def __init__(self, multiplier_range: RandomScalar, synchronize_channels: bool, p
def get_parameters(self, **data_dict) -> dict:
shape = data_dict['image'].shape
- apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
+ apply_to_channel_np = np.where(np.random.rand(shape[0]) < self.p_per_channel)[0]
+ apply_to_channel = torch.from_numpy(apply_to_channel_np)
if self.synchronize_channels:
multipliers = torch.Tensor([sample_scalar(self.multiplier_range, image=data_dict['image'], channel=None)] * len(apply_to_channel))
else:
diff --git a/batchgeneratorsv2/transforms/intensity/contrast.py b/batchgeneratorsv2/transforms/intensity/contrast.py
index 8376e4d..819299a 100644
--- a/batchgeneratorsv2/transforms/intensity/contrast.py
+++ b/batchgeneratorsv2/transforms/intensity/contrast.py
@@ -36,7 +36,8 @@ def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchroni
def get_parameters(self, **data_dict) -> dict:
shape = data_dict['image'].shape
- apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
+ apply_to_channel_np = np.where(np.random.rand(shape[0]) < self.p_per_channel)[0]
+ apply_to_channel = torch.from_numpy(apply_to_channel_np)
if self.synchronize_channels:
multipliers = torch.Tensor([sample_scalar(self.contrast_range, image=data_dict['image'], channel=None)] * len(apply_to_channel))
else:
diff --git a/batchgeneratorsv2/transforms/intensity/gamma.py b/batchgeneratorsv2/transforms/intensity/gamma.py
index ffc336f..6fe3222 100644
--- a/batchgeneratorsv2/transforms/intensity/gamma.py
+++ b/batchgeneratorsv2/transforms/intensity/gamma.py
@@ -5,6 +5,7 @@
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
+import numpy as np
class GammaTransform(ImageOnlyTransform):
def __init__(self, gamma: RandomScalar, p_invert_image: float, synchronize_channels: bool, p_per_channel: float,
@@ -18,9 +19,11 @@ def __init__(self, gamma: RandomScalar, p_invert_image: float, synchronize_chann
def get_parameters(self, **data_dict) -> dict:
shape = data_dict['image'].shape
- apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
- retain_stats = torch.rand(len(apply_to_channel)) < self.p_retain_stats
- invert_image = torch.rand(len(apply_to_channel)) < self.p_invert_image
+ apply_to_channel_np = np.where(np.random.rand(shape[0]) < self.p_per_channel)[0]
+ apply_to_channel = torch.from_numpy(apply_to_channel_np)
+
+ retain_stats = torch.from_numpy(np.random.rand(len(apply_to_channel)) < self.p_retain_stats)
+ invert_image = torch.from_numpy(np.random.rand(len(apply_to_channel)) < self.p_invert_image)
if self.synchronize_channels:
gamma = torch.Tensor([sample_scalar(self.gamma, image=data_dict['image'], channel=None)] * len(apply_to_channel))
diff --git a/batchgeneratorsv2/transforms/intensity/gaussian_noise.py b/batchgeneratorsv2/transforms/intensity/gaussian_noise.py
index 54fa523..7c0f32f 100644
--- a/batchgeneratorsv2/transforms/intensity/gaussian_noise.py
+++ b/batchgeneratorsv2/transforms/intensity/gaussian_noise.py
@@ -5,6 +5,8 @@
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
import torch
+import numpy as np
+
class GaussianNoiseTransform(ImageOnlyTransform):
def __init__(self,
@@ -19,7 +21,7 @@ def __init__(self,
def get_parameters(self, **data_dict) -> dict:
shape = data_dict['image'].shape
dct = {}
- dct['apply_to_channel'] = torch.rand(shape[0]) < self.p_per_channel
+ dct['apply_to_channel'] = np.random.rand(shape[0]) < self.p_per_channel
dct['sigmas'] = \
[sample_scalar(self.noise_variance, data_dict['image'])
for i in range(sum(dct['apply_to_channel']))] if not self.synchronize_channels \
@@ -36,15 +38,12 @@ def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
def _sample_gaussian_noise(self, img_shape: Tuple[int, ...], **params):
if not isinstance(params['sigmas'], list):
num_channels = sum(params['apply_to_channel'])
- # gaussian = torch.tile(torch.normal(0, params['sigmas'], size=(1, *img_shape[1:])),
- # (num_channels, *[1]*(len(img_shape) - 1)))
- gaussian = torch.normal(0, params['sigmas'], size=(1, *img_shape[1:]))
- gaussian.expand((num_channels, *[-1]*(len(img_shape) - 1)))
+ noise_np = np.random.normal(0, params['sigmas'], size=(1, *img_shape[1:]))
+ gaussian = torch.from_numpy(noise_np.astype(np.float32))
+ gaussian = gaussian.expand((num_channels, *[-1]*(len(img_shape) - 1)))
else:
- gaussian = [
- torch.normal(0, i, size=(1, *img_shape[1:])) for i in params['sigmas']
- ]
- gaussian = torch.cat(gaussian, dim=0)
+ noise_np_list = [np.random.normal(0, i, size=(1, *img_shape[1:])) for i in params['sigmas']]
+ gaussian = torch.cat([torch.from_numpy(n.astype(np.float32)) for n in noise_np_list], dim=0)
return gaussian
diff --git a/batchgeneratorsv2/transforms/intensity/inversion.py b/batchgeneratorsv2/transforms/intensity/inversion.py
index 845d7e0..89e8113 100644
--- a/batchgeneratorsv2/transforms/intensity/inversion.py
+++ b/batchgeneratorsv2/transforms/intensity/inversion.py
@@ -18,7 +18,8 @@ def get_parameters(self, **data_dict) -> dict:
if np.random.uniform() < self.p_synchronize_channels:
apply_to_channel = torch.arange(0, shape[0])
else:
- apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
+ apply_to_channel_np = np.where(np.random.rand(shape[0]) < self.p_per_channel)[0]
+ apply_to_channel = torch.from_numpy(apply_to_channel_np)
else:
apply_to_channel = []
return {
diff --git a/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py b/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py
index 17a4368..3b3e377 100644
--- a/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py
+++ b/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py
@@ -83,7 +83,7 @@ def __init__(self,
def get_parameters(self, **data_dict) -> dict:
# this needs to be applied in random order to the channels
np.random.shuffle(self.channel_idx)
- apply_to_channels = [self.channel_idx[i] for i, j in enumerate(torch.rand(len(self.channel_idx)) < self.p_per_label) if j]
+ apply_to_channels = [self.channel_idx[i] for i, j in enumerate(np.random.rand(len(self.channel_idx)) < self.p_per_label) if j]
operators = [np.random.choice(self.any_of_these) for _ in apply_to_channels]
strel_size = [sample_scalar(self.strel_size, image=data_dict['image'], channel=a) for a in apply_to_channels]
return {
diff --git a/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py b/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py
index 9c5e7e6..fc67bc2 100644
--- a/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py
+++ b/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py
@@ -27,7 +27,7 @@ def __init__(self,
def get_parameters(self, **data_dict) -> dict:
# this needs to be applied in random order to the channels
np.random.shuffle(self.channel_idx)
- apply_to_channels = [self.channel_idx[i] for i, j in enumerate(torch.rand(len(self.channel_idx)) < self.p_per_label) if j]
+ apply_to_channels = [self.channel_idx[i] for i, j in enumerate(np.random.rand(len(self.channel_idx)) < self.p_per_label) if j]
# self.fill_with_other_class_p cannot be resolved here because we don't know how many components there are
return {
diff --git a/batchgeneratorsv2/transforms/noise/gaussian_blur.py b/batchgeneratorsv2/transforms/noise/gaussian_blur.py
index 1a707f4..e583a69 100644
--- a/batchgeneratorsv2/transforms/noise/gaussian_blur.py
+++ b/batchgeneratorsv2/transforms/noise/gaussian_blur.py
@@ -99,7 +99,7 @@ def get_parameters(self, **data_dict) -> dict:
shape = data_dict['image'].shape
dims = len(shape) - 1
dct = {}
- dct['apply_to_channel'] = torch.rand(shape[0]) < self.p_per_channel
+ dct['apply_to_channel'] = np.random.rand(shape[0]) < self.p_per_channel
if self.synchronize_axes:
dct['sigmas'] = \
[[sample_scalar(self.blur_sigma, shape, dim=None)] * dims
diff --git a/batchgeneratorsv2/transforms/noise/rician.py b/batchgeneratorsv2/transforms/noise/rician.py
index a7c3882..d997887 100644
--- a/batchgeneratorsv2/transforms/noise/rician.py
+++ b/batchgeneratorsv2/transforms/noise/rician.py
@@ -1,7 +1,7 @@
import torch
from typing import Tuple
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
-
+import numpy as np
class RicianNoiseTransform(ImageOnlyTransform):
"""
@@ -21,8 +21,10 @@ def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
var = params['variance']
- noise_real = torch.empty_like(img).normal_(mean=0.0, std=var)
- noise_imag = torch.empty_like(img).normal_(mean=0.0, std=var)
+ noise_real_np = np.random.normal(0.0, var, size=img.shape).astype(np.float32)
+ noise_imag_np = np.random.normal(0.0, var, size=img.shape).astype(np.float32)
+ noise_real = torch.from_numpy(noise_real_np).to(img.device)
+ noise_imag = torch.from_numpy(noise_imag_np).to(img.device)
min_val = img.min()
shifted = img - min_val
diff --git a/batchgeneratorsv2/transforms/spatial/low_resolution.py b/batchgeneratorsv2/transforms/spatial/low_resolution.py
index 9ed632e..c5abbaa 100644
--- a/batchgeneratorsv2/transforms/spatial/low_resolution.py
+++ b/batchgeneratorsv2/transforms/spatial/low_resolution.py
@@ -6,6 +6,7 @@
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
from torch.nn.functional import interpolate
+import numpy as np
class SimulateLowResolutionTransform(ImageOnlyTransform):
def __init__(self,
@@ -32,9 +33,10 @@ def __init__(self,
def get_parameters(self, **data_dict) -> dict:
shape = data_dict['image'].shape
if self.allowed_channels is None:
- apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
+ apply_to_channel_np = np.where(np.random.rand(shape[0]) < self.p_per_channel)[0]
+ apply_to_channel = torch.from_numpy(apply_to_channel_np)
else:
- apply_to_channel = [i for i in self.allowed_channels if torch.rand(1) < self.p_per_channel]
+ apply_to_channel = [i for i in self.allowed_channels if np.random.rand() < self.p_per_channel]
if self.synchronize_channels:
if self.synchronize_axes:
scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=None, dim=None)] * (len(shape) - 1)] * len(apply_to_channel))
diff --git a/batchgeneratorsv2/transforms/spatial/mirroring.py b/batchgeneratorsv2/transforms/spatial/mirroring.py
index 89fa3b9..dcb7bbe 100644
--- a/batchgeneratorsv2/transforms/spatial/mirroring.py
+++ b/batchgeneratorsv2/transforms/spatial/mirroring.py
@@ -11,7 +11,7 @@ def __init__(self, allowed_axes: Tuple[int, ...]):
self.allowed_axes = allowed_axes
def get_parameters(self, **data_dict) -> dict:
- axes = [i for i in self.allowed_axes if torch.rand(1) < 0.5]
+ axes = [i for i in self.allowed_axes if np.random.rand() < 0.5]
return {
'axes': axes
}
diff --git a/batchgeneratorsv2/transforms/spatial/spatial.py b/batchgeneratorsv2/transforms/spatial/spatial.py
index 6da7c17..8ae666e 100644
--- a/batchgeneratorsv2/transforms/spatial/spatial.py
+++ b/batchgeneratorsv2/transforms/spatial/spatial.py
@@ -115,27 +115,18 @@ def get_parameters(self, **data_dict) -> dict:
dim=i, deformation_scale=deformation_scales[i])
for i in range(dim)]
# doing it like this for better memory layout for blurring
- offsets = torch.normal(mean=0, std=1, size=(dim, *self.patch_size))
+ offsets_np = np.random.randn(dim, *self.patch_size).astype(np.float32)
- # all the additional time elastic deform takes is spent here
for d in range(dim):
- # fft torch, slower
- # for i in range(offsets.ndim - 1):
- # offsets[d] = blur_dimension(offsets[d][None], sigmas[d], i, force_use_fft=True, truncate=6)[0]
-
- # fft numpy, this is faster o.O
- tmp = np.fft.fftn(offsets[d].numpy())
+ tmp = np.fft.fftn(offsets_np[d])
tmp = fourier_gaussian(tmp, sigmas[d])
- offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)
-
- # tmp = offsets[d].numpy().astype(np.float64)
- # gaussian_filter(tmp, sigmas[d], 0, output=tmp)
- # offsets[d] = torch.from_numpy(tmp).to(offsets.dtype)
- # print(offsets.dtype)
+ offsets_np[d] = np.fft.ifftn(tmp).real
- mx = torch.max(torch.abs(offsets[d]))
- offsets[d] /= (mx / np.clip(magnitude[d], a_min=1e-8, a_max=np.inf))
+ mx = np.max(np.abs(offsets_np[d]))
+ offsets_np[d] /= (mx / np.clip(magnitude[d], a_min=1e-8, a_max=np.inf))
+
spatial_dims = tuple(list(range(1, dim + 1)))
+ offsets = torch.from_numpy(offsets_np)
offsets = torch.permute(offsets, (*spatial_dims, 0))
else:
offsets = None
diff --git a/batchgeneratorsv2/transforms/utils/random.py b/batchgeneratorsv2/transforms/utils/random.py
index 0a55054..cf62b9f 100644
--- a/batchgeneratorsv2/transforms/utils/random.py
+++ b/batchgeneratorsv2/transforms/utils/random.py
@@ -13,7 +13,7 @@ def __init__(self, transform: BasicTransform, apply_probability: float = 1):
self.apply_probability = apply_probability
def get_parameters(self, **data_dict) -> dict:
- return {"apply_transform": torch.rand(1).item() < self.apply_probability}
+ return {"apply_transform": np.random.rand() < self.apply_probability}
def apply(self, data_dict: dict, **params) -> dict:
if params['apply_transform']:
From 61ee047161dea8b9c15905cc8044daa1db2cc52c Mon Sep 17 00:00:00 2001
From: Luugaaa
Date: Fri, 18 Jul 2025 11:09:54 -0400
Subject: [PATCH 2/4] missing np
---
batchgeneratorsv2/transforms/spatial/mirroring.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/batchgeneratorsv2/transforms/spatial/mirroring.py b/batchgeneratorsv2/transforms/spatial/mirroring.py
index dcb7bbe..bc57d62 100644
--- a/batchgeneratorsv2/transforms/spatial/mirroring.py
+++ b/batchgeneratorsv2/transforms/spatial/mirroring.py
@@ -4,6 +4,7 @@
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
+import numpy as np
class MirrorTransform(BasicTransform):
def __init__(self, allowed_axes: Tuple[int, ...]):
From eb14c61b0070d992241d9e203e1566ff172f77e3 Mon Sep 17 00:00:00 2001
From: Luugaaa
Date: Fri, 18 Jul 2025 12:15:47 -0400
Subject: [PATCH 3/4] determinism pipeline testing
---
determinism_test_pipeline.py | 228 +++++++++++++++++++++++++++++++++++
1 file changed, 228 insertions(+)
create mode 100644 determinism_test_pipeline.py
diff --git a/determinism_test_pipeline.py b/determinism_test_pipeline.py
new file mode 100644
index 0000000..52c4506
--- /dev/null
+++ b/determinism_test_pipeline.py
@@ -0,0 +1,228 @@
+import sys
+import os
+import torch
+import numpy as np
+from copy import deepcopy
+from PIL import Image
+import torchvision.transforms.functional as TF
+import numpy as np
+
+sys.path.insert(0, os.path.abspath('.'))
+
+from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform, BrightnessAdditiveTransform
+from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast
+from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
+from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
+from batchgeneratorsv2.transforms.intensity.inversion import InvertImageTransform
+from batchgeneratorsv2.transforms.intensity.random_clip import CutOffOutliersTransform
+from batchgeneratorsv2.transforms.local.brightness_gradient import BrightnessGradientAdditiveTransform
+from batchgeneratorsv2.transforms.local.local_contrast import LocalContrastTransform
+from batchgeneratorsv2.transforms.local.local_gamma import LocalGammaTransform
+from batchgeneratorsv2.transforms.local.local_smoothing import LocalSmoothingTransform
+from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform
+from batchgeneratorsv2.transforms.nnunet.remove_connected_components import RemoveRandomConnectedComponentFromOneHotEncodingTransform
+from batchgeneratorsv2.transforms.noise.blank_rectangle import BlankRectangleTransform
+from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
+from batchgeneratorsv2.transforms.noise.median_filter import MedianFilterTransform
+from batchgeneratorsv2.transforms.noise.rician import RicianNoiseTransform
+from batchgeneratorsv2.transforms.noise.sharpen import SharpeningTransform
+from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
+from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
+from batchgeneratorsv2.transforms.spatial.rot90 import Rot90Transform
+from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
+from batchgeneratorsv2.transforms.spatial.transpose import TransposeAxesTransform
+from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
+from batchgeneratorsv2.transforms.utils.random import RandomTransform
+
+MASTER_SEED = 7
+
+def seed_everything(seed: int):
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.use_deterministic_algorithms(True)
+
+def compare_outputs(output1: dict, output2: dict) -> bool:
+ if output1.keys() != output2.keys():
+ print(f" - ❌ FAIL: Output dictionaries have different keys.")
+ return False
+ for key in output1.keys():
+ val1, val2 = output1[key], output2[key]
+ if isinstance(val1, torch.Tensor):
+ if not torch.equal(val1, val2):
+ diff = torch.abs(val1.float() - val2.float()).max()
+ print(f" - ❌ FAIL: Tensor mismatch for key '{key}'. Max difference: {diff.item()}")
+ return False
+ return True
+
+def run_test_on_data(transform_class, kwargs, input_data, dimension_str):
+ print(f"\n🧪 Testing {transform_class.__name__} ({dimension_str})...")
+ try:
+ # --- RUN 1 ---
+ # Seed both numpy and torch to establish a baseline
+ seed_everything(MASTER_SEED)
+ transform_run1 = transform_class(**kwargs)
+ output1 = transform_run1(**deepcopy(input_data))
+
+ # --- RUN 2 ---
+ # Re-seed NumPy. This simulates the real-world scenario
+ # where torch's RNG state is not controlled in worker processes. If a transform
+ # uses torch.rand, it will now fail this test.
+ np.random.seed(MASTER_SEED)
+
+ transform_run2 = transform_class(**kwargs)
+ output2 = transform_run2(**deepcopy(input_data))
+
+ # --- VERIFICATION ---
+ if compare_outputs(output1, output2):
+ print(f" - ✅ [PASS] Outputs are identical.")
+ return True
+ else:
+ return False
+ except Exception as e:
+ print(f" - ❌ [ERROR] An exception occurred: {e}")
+ import traceback
+ traceback.print_exc()
+ return False
+
+def main():
+ print("--- Starting Determinism Test Pipeline ---")
+
+ # --- A single, comprehensive list of all transforms and their base kwargs ---
+ all_transforms_and_kwargs = [
+ (MultiplicativeBrightnessTransform, {'multiplier_range': (0.8, 1.2), 'synchronize_channels': False}),
+ (BrightnessAdditiveTransform, {'mu': 0.0, 'sigma': 0.1, 'per_channel': False}),
+ (ContrastTransform, {'contrast_range': BGContrast((0.9, 1.1)), 'preserve_range': True, 'synchronize_channels': False}),
+ (GammaTransform, {'gamma': (0.8, 1.2), 'p_invert_image': 0.0, 'synchronize_channels': False, 'p_per_channel': 1, 'p_retain_stats': 0.5}),
+ (GaussianNoiseTransform, {'noise_variance': (0, 0.05)}),
+ (InvertImageTransform, {'p_invert_image': 0.1, 'p_synchronize_channels': 0.5, 'p_per_channel': 0.8}),
+ (CutOffOutliersTransform, {'percentile_lower': (0.1, 1.0), 'percentile_upper': (99.0, 99.9)}),
+
+ (BrightnessGradientAdditiveTransform, {'scale': (40, 80), 'max_strength': (0.1, 0.3)}),
+ (LocalContrastTransform, {'scale': (40, 80), 'new_contrast': (0.8, 1.2)}),
+ (LocalGammaTransform, {'scale': (40, 80), 'gamma': (0.8, 1.2)}),
+ (LocalSmoothingTransform, {'scale': (40, 80), 'kernel_size': (0.5, 1.5)}),
+
+ (ApplyRandomBinaryOperatorTransform, {'channel_idx': [0], 'strel_size': (1, 2)}),
+ (RemoveRandomConnectedComponentFromOneHotEncodingTransform, {'channel_idx': [0], 'fill_with_other_class_p': 0.5}),
+
+ (BlankRectangleTransform, {'rectangle_size': ((2, 5), (2, 5), (2, 5)), 'rectangle_value': (0, 1), 'num_rectangles': (1, 2)}),
+ (GaussianBlurTransform, {'blur_sigma': (0.5, 1.5), 'benchmark': False}),
+ (MedianFilterTransform, {'filter_size': 3}),
+ (RicianNoiseTransform, {'noise_variance': (0, 0.05)}),
+ (SharpeningTransform, {'strength': (0.1, 0.2)}),
+
+ (SimulateLowResolutionTransform, {'scale': (0.8, 1.0), 'synchronize_channels': False, 'synchronize_axes': False, 'ignore_axes': None}),
+ (MirrorTransform, {'allowed_axes': (0, 1, 2)}),
+ (Rot90Transform, {'num_axis_combinations': 1, 'allowed_axes': {0, 1, 2}}),
+ (SpatialTransform, {'patch_size': (12, 12, 12), 'patch_center_dist_from_border': [6, 6, 6], 'random_crop': True, 'p_elastic_deform': 0.2, 'p_rotation': 0.2, 'p_scaling': 0.2}),
+ (TransposeAxesTransform, {'allowed_axes': {0, 1, 2}}),
+ ]
+
+ passed_count = 0
+ failed_count = 0
+
+ # ===============================================================
+ # Part 1: Test all transforms on 3D data
+ # ===============================================================
+ print("\n--- Part 1: Testing all transforms on 3D data ---")
+ seed_everything(MASTER_SEED)
+ input_data_3d = {'image': torch.randn(2, 16, 16, 16)}
+ for transform_class, kwargs in all_transforms_and_kwargs:
+ if run_test_on_data(transform_class, kwargs, input_data_3d, "3D"):
+ passed_count += 1
+ else:
+ failed_count += 1
+
+ # ===============================================================
+ # Part 2: Test all transforms on 2D data
+ # ===============================================================
+ print("\n\n--- Part 2: Testing all transforms on 2D data ---")
+ seed_everything(MASTER_SEED)
+ input_data_2d = {'image': torch.randn(3, 64, 64)}
+ for transform_class, kwargs_base in all_transforms_and_kwargs:
+
+ # Skip transforms that are 3D only
+ if transform_class == SpatialTransform:
+ print(f"\n🧪 Skipping {transform_class.__name__} (3D-only)...")
+ continue
+
+ # Adapt kwargs for 2D compatibility
+ if 'allowed_axes' in kwargs_base: kwargs_base['allowed_axes'] = {ax for ax in kwargs_base['allowed_axes'] if ax < 2}
+ if 'rectangle_size' in kwargs_base: kwargs_base['rectangle_size'] = kwargs_base['rectangle_size'][:2]
+
+ if run_test_on_data(transform_class, kwargs_base, input_data_2d, "2D"):
+ passed_count += 1
+ else:
+ failed_count += 1
+
+ # ===============================================================
+ # Part 3: Run composed pipeline on sample_image.jpg using ONLY changed transforms
+ # ===============================================================
+ print("\n\n--- Part 3: Testing composed pipeline on sample_image.jpg (fixed transforms only) ---")
+ try:
+ with Image.open("sample_image.jpg") as img:
+ # Resize to be square to ensure all spatial transforms are compatible
+ img = img.resize((512, 512))
+ input_tensor_2d = TF.to_tensor(img.convert("RGB"))
+ TF.to_pil_image(input_tensor_2d).save("sample_image_original.png")
+ print("✅ Loaded, resized, and saved 'sample_image.jpg' as 'sample_image_original.png'.")
+
+ # Define the set of transform classes whose internal logic we fixed
+ changed_transforms = {
+ ContrastTransform,
+ GammaTransform,
+ GaussianNoiseTransform,
+ InvertImageTransform,
+ MultiplicativeBrightnessTransform,
+ ApplyRandomBinaryOperatorTransform,
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform,
+ GaussianBlurTransform,
+ RicianNoiseTransform,
+ SimulateLowResolutionTransform,
+ MirrorTransform
+ }
+
+ # Filter the main list to get only the transforms we changed
+ transforms_to_compose = []
+ for cls, kw_base in all_transforms_and_kwargs:
+ if cls in changed_transforms and cls != SpatialTransform:
+ kw = deepcopy(kw_base)
+ # Adapt kwargs for 2D on the fly
+ if 'allowed_axes' in kw: kw['allowed_axes'] = {ax for ax in kw['allowed_axes'] if ax < 2}
+ if 'rectangle_size' in kw: kw['rectangle_size'] = kw['rectangle_size'][:2]
+ transforms_to_compose.append(cls(**kw))
+
+ # Wrap all selected transforms in RandomTransform with 100% probability, better for comparison with original batchgeneratorsv2 code
+ composed_2d_transforms = ComposeTransforms([
+ RandomTransform(t, 1.0) for t in transforms_to_compose
+ ])
+
+ print(f"✅ Created a composed pipeline with {len(transforms_to_compose)} fixed transforms.")
+
+ seed_everything(MASTER_SEED)
+ final_output = composed_2d_transforms(**{'image': input_tensor_2d})
+ TF.to_pil_image(final_output['image']).save("sample_image_augmented.png")
+ print("✅ Saved 'sample_image_augmented.png'.")
+
+
+ except FileNotFoundError:
+ print("⚠️ WARNING: 'sample_image.jpg' not found. Skipping Part 3.")
+ except Exception as e:
+ print(f" - ❌ [ERROR] An exception occurred during the composed 2D test: {e}")
+ import traceback
+ traceback.print_exc()
+ failed_count += 1
+
+ # --- Final Summary ---
+ print("\n--- Test Pipeline Finished ---")
+ print(f"Total Checks: {passed_count + failed_count} | ✅ Passed: {passed_count} | ❌ Failed: {failed_count}")
+ print("--------------------------------")
+
+ return failed_count == 0
+
+if __name__ == "__main__":
+ if main():
+ print("\n🎉 All augmentation tests passed and are deterministic!")
+ else:
+ print("\n💥 Some augmentations failed the determinism check. Please review the logs above.")
+ sys.exit(1)
\ No newline at end of file
From df02edb27d3b413ca384104228938064a4911893 Mon Sep 17 00:00:00 2001
From: Luugaaa
Date: Fri, 18 Jul 2025 13:28:52 -0400
Subject: [PATCH 4/4] note on benchmark
---
batchgeneratorsv2/transforms/noise/gaussian_blur.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/batchgeneratorsv2/transforms/noise/gaussian_blur.py b/batchgeneratorsv2/transforms/noise/gaussian_blur.py
index e583a69..0f2ffcb 100644
--- a/batchgeneratorsv2/transforms/noise/gaussian_blur.py
+++ b/batchgeneratorsv2/transforms/noise/gaussian_blur.py
@@ -77,7 +77,7 @@ def __init__(self,
benchmark: bool = False
):
"""
- uses separable gaussian filters for all the speed
+ Uses separable gaussian filters for all the speed. Note: benchmark=True will likely make the transform non-deterministic.
blur_sigma, if callable, will be called as blur_sigma(image, shape, dim) where shape is (c, x(, y, z) and dim i
s 1, 2 or 3 for x, y and z, respectively)