From 1525347d973d964e8825dd5ac477d50362d842d9 Mon Sep 17 00:00:00 2001 From: Luugaaa Date: Fri, 18 Jul 2025 10:46:14 -0400 Subject: [PATCH 1/4] deterministic fix --- .../transforms/intensity/brightness.py | 3 ++- .../transforms/intensity/contrast.py | 3 ++- .../transforms/intensity/gamma.py | 9 +++++--- .../transforms/intensity/gaussian_noise.py | 17 +++++++------- .../transforms/intensity/inversion.py | 3 ++- .../nnunet/random_binary_operator.py | 2 +- .../nnunet/remove_connected_components.py | 2 +- .../transforms/noise/gaussian_blur.py | 2 +- batchgeneratorsv2/transforms/noise/rician.py | 8 ++++--- .../transforms/spatial/low_resolution.py | 6 +++-- .../transforms/spatial/mirroring.py | 2 +- .../transforms/spatial/spatial.py | 23 ++++++------------- batchgeneratorsv2/transforms/utils/random.py | 2 +- 13 files changed, 41 insertions(+), 41 deletions(-) diff --git a/batchgeneratorsv2/transforms/intensity/brightness.py b/batchgeneratorsv2/transforms/intensity/brightness.py index 48521ff..668a79f 100644 --- a/batchgeneratorsv2/transforms/intensity/brightness.py +++ b/batchgeneratorsv2/transforms/intensity/brightness.py @@ -14,7 +14,8 @@ def __init__(self, multiplier_range: RandomScalar, synchronize_channels: bool, p def get_parameters(self, **data_dict) -> dict: shape = data_dict['image'].shape - apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0] + apply_to_channel_np = np.where(np.random.rand(shape[0]) < self.p_per_channel)[0] + apply_to_channel = torch.from_numpy(apply_to_channel_np) if self.synchronize_channels: multipliers = torch.Tensor([sample_scalar(self.multiplier_range, image=data_dict['image'], channel=None)] * len(apply_to_channel)) else: diff --git a/batchgeneratorsv2/transforms/intensity/contrast.py b/batchgeneratorsv2/transforms/intensity/contrast.py index 8376e4d..819299a 100644 --- a/batchgeneratorsv2/transforms/intensity/contrast.py +++ b/batchgeneratorsv2/transforms/intensity/contrast.py @@ -36,7 +36,8 @@ def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchroni def get_parameters(self, **data_dict) -> dict: shape = data_dict['image'].shape - apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0] + apply_to_channel_np = np.where(np.random.rand(shape[0]) < self.p_per_channel)[0] + apply_to_channel = torch.from_numpy(apply_to_channel_np) if self.synchronize_channels: multipliers = torch.Tensor([sample_scalar(self.contrast_range, image=data_dict['image'], channel=None)] * len(apply_to_channel)) else: diff --git a/batchgeneratorsv2/transforms/intensity/gamma.py b/batchgeneratorsv2/transforms/intensity/gamma.py index ffc336f..6fe3222 100644 --- a/batchgeneratorsv2/transforms/intensity/gamma.py +++ b/batchgeneratorsv2/transforms/intensity/gamma.py @@ -5,6 +5,7 @@ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform +import numpy as np class GammaTransform(ImageOnlyTransform): def __init__(self, gamma: RandomScalar, p_invert_image: float, synchronize_channels: bool, p_per_channel: float, @@ -18,9 +19,11 @@ def __init__(self, gamma: RandomScalar, p_invert_image: float, synchronize_chann def get_parameters(self, **data_dict) -> dict: shape = data_dict['image'].shape - apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0] - retain_stats = torch.rand(len(apply_to_channel)) < self.p_retain_stats - invert_image = torch.rand(len(apply_to_channel)) < self.p_invert_image + apply_to_channel_np 
= np.where(np.random.rand(shape[0]) < self.p_per_channel)[0] + apply_to_channel = torch.from_numpy(apply_to_channel_np) + + retain_stats = torch.from_numpy(np.random.rand(len(apply_to_channel)) < self.p_retain_stats) + invert_image = torch.from_numpy(np.random.rand(len(apply_to_channel)) < self.p_invert_image) if self.synchronize_channels: gamma = torch.Tensor([sample_scalar(self.gamma, image=data_dict['image'], channel=None)] * len(apply_to_channel)) diff --git a/batchgeneratorsv2/transforms/intensity/gaussian_noise.py b/batchgeneratorsv2/transforms/intensity/gaussian_noise.py index 54fa523..7c0f32f 100644 --- a/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +++ b/batchgeneratorsv2/transforms/intensity/gaussian_noise.py @@ -5,6 +5,8 @@ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform import torch +import numpy as np + class GaussianNoiseTransform(ImageOnlyTransform): def __init__(self, @@ -19,7 +21,7 @@ def __init__(self, def get_parameters(self, **data_dict) -> dict: shape = data_dict['image'].shape dct = {} - dct['apply_to_channel'] = torch.rand(shape[0]) < self.p_per_channel + dct['apply_to_channel'] = np.random.rand(shape[0]) < self.p_per_channel dct['sigmas'] = \ [sample_scalar(self.noise_variance, data_dict['image']) for i in range(sum(dct['apply_to_channel']))] if not self.synchronize_channels \ @@ -36,15 +38,12 @@ def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor: def _sample_gaussian_noise(self, img_shape: Tuple[int, ...], **params): if not isinstance(params['sigmas'], list): num_channels = sum(params['apply_to_channel']) - # gaussian = torch.tile(torch.normal(0, params['sigmas'], size=(1, *img_shape[1:])), - # (num_channels, *[1]*(len(img_shape) - 1))) - gaussian = torch.normal(0, params['sigmas'], size=(1, *img_shape[1:])) - gaussian.expand((num_channels, *[-1]*(len(img_shape) - 1))) + noise_np = np.random.normal(0, params['sigmas'], size=(1, *img_shape[1:])) + gaussian = torch.from_numpy(noise_np.astype(np.float32)) + gaussian = gaussian.expand((num_channels, *[-1]*(len(img_shape) - 1))) else: - gaussian = [ - torch.normal(0, i, size=(1, *img_shape[1:])) for i in params['sigmas'] - ] - gaussian = torch.cat(gaussian, dim=0) + noise_np_list = [np.random.normal(0, i, size=(1, *img_shape[1:])) for i in params['sigmas']] + gaussian = torch.cat([torch.from_numpy(n.astype(np.float32)) for n in noise_np_list], dim=0) return gaussian diff --git a/batchgeneratorsv2/transforms/intensity/inversion.py b/batchgeneratorsv2/transforms/intensity/inversion.py index 845d7e0..89e8113 100644 --- a/batchgeneratorsv2/transforms/intensity/inversion.py +++ b/batchgeneratorsv2/transforms/intensity/inversion.py @@ -18,7 +18,8 @@ def get_parameters(self, **data_dict) -> dict: if np.random.uniform() < self.p_synchronize_channels: apply_to_channel = torch.arange(0, shape[0]) else: - apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0] + apply_to_channel_np = np.where(np.random.rand(shape[0]) < self.p_per_channel)[0] + apply_to_channel = torch.from_numpy(apply_to_channel_np) else: apply_to_channel = [] return { diff --git a/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py b/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py index 17a4368..3b3e377 100644 --- a/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +++ b/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py @@ -83,7 +83,7 @@ def __init__(self, def get_parameters(self, **data_dict) -> dict: # this needs to be applied in 
random order to the channels np.random.shuffle(self.channel_idx) - apply_to_channels = [self.channel_idx[i] for i, j in enumerate(torch.rand(len(self.channel_idx)) < self.p_per_label) if j] + apply_to_channels = [self.channel_idx[i] for i, j in enumerate(np.random.rand(len(self.channel_idx)) < self.p_per_label) if j] operators = [np.random.choice(self.any_of_these) for _ in apply_to_channels] strel_size = [sample_scalar(self.strel_size, image=data_dict['image'], channel=a) for a in apply_to_channels] return { diff --git a/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py b/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py index 9c5e7e6..fc67bc2 100644 --- a/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +++ b/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py @@ -27,7 +27,7 @@ def __init__(self, def get_parameters(self, **data_dict) -> dict: # this needs to be applied in random order to the channels np.random.shuffle(self.channel_idx) - apply_to_channels = [self.channel_idx[i] for i, j in enumerate(torch.rand(len(self.channel_idx)) < self.p_per_label) if j] + apply_to_channels = [self.channel_idx[i] for i, j in enumerate(np.random.rand(len(self.channel_idx)) < self.p_per_label) if j] # self.fill_with_other_class_p cannot be resolved here because we don't know how many components there are return { diff --git a/batchgeneratorsv2/transforms/noise/gaussian_blur.py b/batchgeneratorsv2/transforms/noise/gaussian_blur.py index 1a707f4..e583a69 100644 --- a/batchgeneratorsv2/transforms/noise/gaussian_blur.py +++ b/batchgeneratorsv2/transforms/noise/gaussian_blur.py @@ -99,7 +99,7 @@ def get_parameters(self, **data_dict) -> dict: shape = data_dict['image'].shape dims = len(shape) - 1 dct = {} - dct['apply_to_channel'] = torch.rand(shape[0]) < self.p_per_channel + dct['apply_to_channel'] = np.random.rand(shape[0]) < self.p_per_channel if self.synchronize_axes: dct['sigmas'] = \ [[sample_scalar(self.blur_sigma, shape, dim=None)] * dims diff --git a/batchgeneratorsv2/transforms/noise/rician.py b/batchgeneratorsv2/transforms/noise/rician.py index a7c3882..d997887 100644 --- a/batchgeneratorsv2/transforms/noise/rician.py +++ b/batchgeneratorsv2/transforms/noise/rician.py @@ -1,7 +1,7 @@ import torch from typing import Tuple from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform - +import numpy as np class RicianNoiseTransform(ImageOnlyTransform): """ @@ -21,8 +21,10 @@ def get_parameters(self, image: torch.Tensor, **kwargs) -> dict: def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor: var = params['variance'] - noise_real = torch.empty_like(img).normal_(mean=0.0, std=var) - noise_imag = torch.empty_like(img).normal_(mean=0.0, std=var) + noise_real_np = np.random.normal(0.0, var, size=img.shape).astype(np.float32) + noise_imag_np = np.random.normal(0.0, var, size=img.shape).astype(np.float32) + noise_real = torch.from_numpy(noise_real_np).to(img.device) + noise_imag = torch.from_numpy(noise_imag_np).to(img.device) min_val = img.min() shifted = img - min_val diff --git a/batchgeneratorsv2/transforms/spatial/low_resolution.py b/batchgeneratorsv2/transforms/spatial/low_resolution.py index 9ed632e..c5abbaa 100644 --- a/batchgeneratorsv2/transforms/spatial/low_resolution.py +++ b/batchgeneratorsv2/transforms/spatial/low_resolution.py @@ -6,6 +6,7 @@ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform from torch.nn.functional import interpolate +import numpy as np 
class SimulateLowResolutionTransform(ImageOnlyTransform): def __init__(self, @@ -32,9 +33,10 @@ def __init__(self, def get_parameters(self, **data_dict) -> dict: shape = data_dict['image'].shape if self.allowed_channels is None: - apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0] + apply_to_channel_np = np.where(np.random.rand(shape[0]) < self.p_per_channel)[0] + apply_to_channel = torch.from_numpy(apply_to_channel_np) else: - apply_to_channel = [i for i in self.allowed_channels if torch.rand(1) < self.p_per_channel] + apply_to_channel = [i for i in self.allowed_channels if np.random.rand() < self.p_per_channel] if self.synchronize_channels: if self.synchronize_axes: scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=None, dim=None)] * (len(shape) - 1)] * len(apply_to_channel)) diff --git a/batchgeneratorsv2/transforms/spatial/mirroring.py b/batchgeneratorsv2/transforms/spatial/mirroring.py index 89fa3b9..dcb7bbe 100644 --- a/batchgeneratorsv2/transforms/spatial/mirroring.py +++ b/batchgeneratorsv2/transforms/spatial/mirroring.py @@ -11,7 +11,7 @@ def __init__(self, allowed_axes: Tuple[int, ...]): self.allowed_axes = allowed_axes def get_parameters(self, **data_dict) -> dict: - axes = [i for i in self.allowed_axes if torch.rand(1) < 0.5] + axes = [i for i in self.allowed_axes if np.random.rand() < 0.5] return { 'axes': axes } diff --git a/batchgeneratorsv2/transforms/spatial/spatial.py b/batchgeneratorsv2/transforms/spatial/spatial.py index 6da7c17..8ae666e 100644 --- a/batchgeneratorsv2/transforms/spatial/spatial.py +++ b/batchgeneratorsv2/transforms/spatial/spatial.py @@ -115,27 +115,18 @@ def get_parameters(self, **data_dict) -> dict: dim=i, deformation_scale=deformation_scales[i]) for i in range(dim)] # doing it like this for better memory layout for blurring - offsets = torch.normal(mean=0, std=1, size=(dim, *self.patch_size)) + offsets_np = np.random.randn(dim, *self.patch_size).astype(np.float32) - # all the additional time elastic deform takes is spent here for d in range(dim): - # fft torch, slower - # for i in range(offsets.ndim - 1): - # offsets[d] = blur_dimension(offsets[d][None], sigmas[d], i, force_use_fft=True, truncate=6)[0] - - # fft numpy, this is faster o.O - tmp = np.fft.fftn(offsets[d].numpy()) + tmp = np.fft.fftn(offsets_np[d]) tmp = fourier_gaussian(tmp, sigmas[d]) - offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real) - - # tmp = offsets[d].numpy().astype(np.float64) - # gaussian_filter(tmp, sigmas[d], 0, output=tmp) - # offsets[d] = torch.from_numpy(tmp).to(offsets.dtype) - # print(offsets.dtype) + offsets_np[d] = np.fft.ifftn(tmp).real - mx = torch.max(torch.abs(offsets[d])) - offsets[d] /= (mx / np.clip(magnitude[d], a_min=1e-8, a_max=np.inf)) + mx = np.max(np.abs(offsets_np[d])) + offsets_np[d] /= (mx / np.clip(magnitude[d], a_min=1e-8, a_max=np.inf)) + spatial_dims = tuple(list(range(1, dim + 1))) + offsets = torch.from_numpy(offsets_np) offsets = torch.permute(offsets, (*spatial_dims, 0)) else: offsets = None diff --git a/batchgeneratorsv2/transforms/utils/random.py b/batchgeneratorsv2/transforms/utils/random.py index 0a55054..cf62b9f 100644 --- a/batchgeneratorsv2/transforms/utils/random.py +++ b/batchgeneratorsv2/transforms/utils/random.py @@ -13,7 +13,7 @@ def __init__(self, transform: BasicTransform, apply_probability: float = 1): self.apply_probability = apply_probability def get_parameters(self, **data_dict) -> dict: - return {"apply_transform": torch.rand(1).item() < self.apply_probability} + 
return {"apply_transform": np.random.rand() < self.apply_probability} def apply(self, data_dict: dict, **params) -> dict: if params['apply_transform']: From 61ee047161dea8b9c15905cc8044daa1db2cc52c Mon Sep 17 00:00:00 2001 From: Luugaaa Date: Fri, 18 Jul 2025 11:09:54 -0400 Subject: [PATCH 2/4] missing np --- batchgeneratorsv2/transforms/spatial/mirroring.py | 1 + 1 file changed, 1 insertion(+) diff --git a/batchgeneratorsv2/transforms/spatial/mirroring.py b/batchgeneratorsv2/transforms/spatial/mirroring.py index dcb7bbe..bc57d62 100644 --- a/batchgeneratorsv2/transforms/spatial/mirroring.py +++ b/batchgeneratorsv2/transforms/spatial/mirroring.py @@ -4,6 +4,7 @@ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform +import numpy as np class MirrorTransform(BasicTransform): def __init__(self, allowed_axes: Tuple[int, ...]): From eb14c61b0070d992241d9e203e1566ff172f77e3 Mon Sep 17 00:00:00 2001 From: Luugaaa Date: Fri, 18 Jul 2025 12:15:47 -0400 Subject: [PATCH 3/4] determinism pipeline testing --- determinism_test_pipeline.py | 228 +++++++++++++++++++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 determinism_test_pipeline.py diff --git a/determinism_test_pipeline.py b/determinism_test_pipeline.py new file mode 100644 index 0000000..52c4506 --- /dev/null +++ b/determinism_test_pipeline.py @@ -0,0 +1,228 @@ +import sys +import os +import torch +import numpy as np +from copy import deepcopy +from PIL import Image +import torchvision.transforms.functional as TF +import numpy as np + +sys.path.insert(0, os.path.abspath('.')) + +from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform, BrightnessAdditiveTransform +from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast +from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform +from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform +from batchgeneratorsv2.transforms.intensity.inversion import InvertImageTransform +from batchgeneratorsv2.transforms.intensity.random_clip import CutOffOutliersTransform +from batchgeneratorsv2.transforms.local.brightness_gradient import BrightnessGradientAdditiveTransform +from batchgeneratorsv2.transforms.local.local_contrast import LocalContrastTransform +from batchgeneratorsv2.transforms.local.local_gamma import LocalGammaTransform +from batchgeneratorsv2.transforms.local.local_smoothing import LocalSmoothingTransform +from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform +from batchgeneratorsv2.transforms.nnunet.remove_connected_components import RemoveRandomConnectedComponentFromOneHotEncodingTransform +from batchgeneratorsv2.transforms.noise.blank_rectangle import BlankRectangleTransform +from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform +from batchgeneratorsv2.transforms.noise.median_filter import MedianFilterTransform +from batchgeneratorsv2.transforms.noise.rician import RicianNoiseTransform +from batchgeneratorsv2.transforms.noise.sharpen import SharpeningTransform +from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform +from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform +from batchgeneratorsv2.transforms.spatial.rot90 import Rot90Transform +from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform +from batchgeneratorsv2.transforms.spatial.transpose import TransposeAxesTransform 
+from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms +from batchgeneratorsv2.transforms.utils.random import RandomTransform + +MASTER_SEED = 7 + +def seed_everything(seed: int): + np.random.seed(seed) + torch.manual_seed(seed) + torch.use_deterministic_algorithms(True) + +def compare_outputs(output1: dict, output2: dict) -> bool: + if output1.keys() != output2.keys(): + print(f" - โŒ FAIL: Output dictionaries have different keys.") + return False + for key in output1.keys(): + val1, val2 = output1[key], output2[key] + if isinstance(val1, torch.Tensor): + if not torch.equal(val1, val2): + diff = torch.abs(val1.float() - val2.float()).max() + print(f" - โŒ FAIL: Tensor mismatch for key '{key}'. Max difference: {diff.item()}") + return False + return True + +def run_test_on_data(transform_class, kwargs, input_data, dimension_str): + print(f"\n๐Ÿงช Testing {transform_class.__name__} ({dimension_str})...") + try: + # --- RUN 1 --- + # Seed both numpy and torch to establish a baseline + seed_everything(MASTER_SEED) + transform_run1 = transform_class(**kwargs) + output1 = transform_run1(**deepcopy(input_data)) + + # --- RUN 2 --- + # Re-seed NumPy. This simulates the real-world scenario + # where torch's RNG state is not controlled in worker processes. If a transform + # uses torch.rand, it will now fail this test. + np.random.seed(MASTER_SEED) + + transform_run2 = transform_class(**kwargs) + output2 = transform_run2(**deepcopy(input_data)) + + # --- VERIFICATION --- + if compare_outputs(output1, output2): + print(f" - โœ… [PASS] Outputs are identical.") + return True + else: + return False + except Exception as e: + print(f" - โŒ [ERROR] An exception occurred: {e}") + import traceback + traceback.print_exc() + return False + +def main(): + print("--- Starting Determinism Test Pipeline ---") + + # --- A single, comprehensive list of all transforms and their base kwargs --- + all_transforms_and_kwargs = [ + (MultiplicativeBrightnessTransform, {'multiplier_range': (0.8, 1.2), 'synchronize_channels': False}), + (BrightnessAdditiveTransform, {'mu': 0.0, 'sigma': 0.1, 'per_channel': False}), + (ContrastTransform, {'contrast_range': BGContrast((0.9, 1.1)), 'preserve_range': True, 'synchronize_channels': False}), + (GammaTransform, {'gamma': (0.8, 1.2), 'p_invert_image': 0.0, 'synchronize_channels': False, 'p_per_channel': 1, 'p_retain_stats': 0.5}), + (GaussianNoiseTransform, {'noise_variance': (0, 0.05)}), + (InvertImageTransform, {'p_invert_image': 0.1, 'p_synchronize_channels': 0.5, 'p_per_channel': 0.8}), + (CutOffOutliersTransform, {'percentile_lower': (0.1, 1.0), 'percentile_upper': (99.0, 99.9)}), + + (BrightnessGradientAdditiveTransform, {'scale': (40, 80), 'max_strength': (0.1, 0.3)}), + (LocalContrastTransform, {'scale': (40, 80), 'new_contrast': (0.8, 1.2)}), + (LocalGammaTransform, {'scale': (40, 80), 'gamma': (0.8, 1.2)}), + (LocalSmoothingTransform, {'scale': (40, 80), 'kernel_size': (0.5, 1.5)}), + + (ApplyRandomBinaryOperatorTransform, {'channel_idx': [0], 'strel_size': (1, 2)}), + (RemoveRandomConnectedComponentFromOneHotEncodingTransform, {'channel_idx': [0], 'fill_with_other_class_p': 0.5}), + + (BlankRectangleTransform, {'rectangle_size': ((2, 5), (2, 5), (2, 5)), 'rectangle_value': (0, 1), 'num_rectangles': (1, 2)}), + (GaussianBlurTransform, {'blur_sigma': (0.5, 1.5), 'benchmark': False}), + (MedianFilterTransform, {'filter_size': 3}), + (RicianNoiseTransform, {'noise_variance': (0, 0.05)}), + (SharpeningTransform, {'strength': (0.1, 0.2)}), + + 
(SimulateLowResolutionTransform, {'scale': (0.8, 1.0), 'synchronize_channels': False, 'synchronize_axes': False, 'ignore_axes': None}), + (MirrorTransform, {'allowed_axes': (0, 1, 2)}), + (Rot90Transform, {'num_axis_combinations': 1, 'allowed_axes': {0, 1, 2}}), + (SpatialTransform, {'patch_size': (12, 12, 12), 'patch_center_dist_from_border': [6, 6, 6], 'random_crop': True, 'p_elastic_deform': 0.2, 'p_rotation': 0.2, 'p_scaling': 0.2}), + (TransposeAxesTransform, {'allowed_axes': {0, 1, 2}}), + ] + + passed_count = 0 + failed_count = 0 + + # =============================================================== + # Part 1: Test all transforms on 3D data + # =============================================================== + print("\n--- Part 1: Testing all transforms on 3D data ---") + seed_everything(MASTER_SEED) + input_data_3d = {'image': torch.randn(2, 16, 16, 16)} + for transform_class, kwargs in all_transforms_and_kwargs: + if run_test_on_data(transform_class, kwargs, input_data_3d, "3D"): + passed_count += 1 + else: + failed_count += 1 + + # =============================================================== + # Part 2: Test all transforms on 2D data + # =============================================================== + print("\n\n--- Part 2: Testing all transforms on 2D data ---") + seed_everything(MASTER_SEED) + input_data_2d = {'image': torch.randn(3, 64, 64)} + for transform_class, kwargs_base in all_transforms_and_kwargs: + + # Skip transforms that are 3D only + if transform_class == SpatialTransform: + print(f"\n๐Ÿงช Skipping {transform_class.__name__} (3D-only)...") + continue + + # Adapt kwargs for 2D compatibility + if 'allowed_axes' in kwargs_base: kwargs_base['allowed_axes'] = {ax for ax in kwargs_base['allowed_axes'] if ax < 2} + if 'rectangle_size' in kwargs_base: kwargs_base['rectangle_size'] = kwargs_base['rectangle_size'][:2] + + if run_test_on_data(transform_class, kwargs_base, input_data_2d, "2D"): + passed_count += 1 + else: + failed_count += 1 + + # =============================================================== + # Part 3: Run composed pipeline on sample_image.jpg using ONLY changed transforms + # =============================================================== + print("\n\n--- Part 3: Testing composed pipeline on sample_image.jpg (fixed transforms only) ---") + try: + with Image.open("sample_image.jpg") as img: + # Resize to be square to ensure all spatial transforms are compatible + img = img.resize((512, 512)) + input_tensor_2d = TF.to_tensor(img.convert("RGB")) + TF.to_pil_image(input_tensor_2d).save("sample_image_original.png") + print("โœ… Loaded, resized, and saved 'sample_image.jpg' as 'sample_image_original.png'.") + + # Define the set of transform classes whose internal logic we fixed + changed_transforms = { + ContrastTransform, + GammaTransform, + GaussianNoiseTransform, + InvertImageTransform, + MultiplicativeBrightnessTransform, + ApplyRandomBinaryOperatorTransform, + RemoveRandomConnectedComponentFromOneHotEncodingTransform, + GaussianBlurTransform, + RicianNoiseTransform, + SimulateLowResolutionTransform, + MirrorTransform + } + + # Filter the main list to get only the transforms we changed + transforms_to_compose = [] + for cls, kw_base in all_transforms_and_kwargs: + if cls in changed_transforms and cls != SpatialTransform: + kw = deepcopy(kw_base) + # Adapt kwargs for 2D on the fly + if 'allowed_axes' in kw: kw['allowed_axes'] = {ax for ax in kw['allowed_axes'] if ax < 2} + if 'rectangle_size' in kw: kw['rectangle_size'] = kw['rectangle_size'][:2] + 
transforms_to_compose.append(cls(**kw))
+
+        # Wrap all selected transforms in RandomTransform with 100% probability, better for comparison with original batchgeneratorsv2 code
+        composed_2d_transforms = ComposeTransforms([
+            RandomTransform(t, 1.0) for t in transforms_to_compose
+        ])
+
+        print(f"✅ Created a composed pipeline with {len(transforms_to_compose)} fixed transforms.")
+
+        seed_everything(MASTER_SEED)
+        final_output = composed_2d_transforms(**{'image': input_tensor_2d})
+        TF.to_pil_image(final_output['image']).save("sample_image_augmented.png")
+        print("✅ Saved 'sample_image_augmented.png'.")
+
+
+    except FileNotFoundError:
+        print("❌ WARNING: 'sample_image.jpg' not found. Skipping Part 3.")
+    except Exception as e:
+        print(f"  - ❌ [ERROR] An exception occurred during the composed 2D test: {e}")
+        import traceback
+        traceback.print_exc()
+        failed_count += 1
+
+    # --- Final Summary ---
+    print("\n--- Test Pipeline Finished ---")
+    print(f"Total Checks: {passed_count + failed_count} | ✅ Passed: {passed_count} | ❌ Failed: {failed_count}")
+    print("--------------------------------")
+
+    return failed_count == 0
+
+if __name__ == "__main__":
+    if main():
+        print("\n🎉 All augmentation tests passed and are deterministic!")
+    else:
+        print("\n🔥 Some augmentations failed the determinism check. Please review the logs above.")
+        sys.exit(1)
\ No newline at end of file

From df02edb27d3b413ca384104228938064a4911893 Mon Sep 17 00:00:00 2001
From: Luugaaa
Date: Fri, 18 Jul 2025 13:28:52 -0400
Subject: [PATCH 4/4] note on benchmark

---
 batchgeneratorsv2/transforms/noise/gaussian_blur.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/batchgeneratorsv2/transforms/noise/gaussian_blur.py b/batchgeneratorsv2/transforms/noise/gaussian_blur.py
index e583a69..0f2ffcb 100644
--- a/batchgeneratorsv2/transforms/noise/gaussian_blur.py
+++ b/batchgeneratorsv2/transforms/noise/gaussian_blur.py
@@ -77,7 +77,7 @@ def __init__(self,
                  benchmark: bool = False
                  ):
         """
-        uses separable gaussian filters for all the speed
+        Uses separable gaussian filters for all the speed. Note: benchmark=True will likely make the transform non-deterministic.
 
         blur_sigma, if callable, will be called as blur_sigma(image, shape, dim) where shape is (c, x(, y, z) and dim is 1, 2 or 3 for x, y and z, respectively)
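
---
Usage note (not part of the patches): every transform above now draws its
randomness from np.random, so in practice reproducibility comes down to
seeding NumPy in each DataLoader worker — exactly the scenario the test
pipeline simulates by re-seeding only NumPy before the second run. Below is
a minimal sketch of the standard PyTorch worker-seeding pattern; DummyDataset
is a placeholder for illustration and not part of this series.

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class DummyDataset(Dataset):
    # Placeholder dataset standing in for a real training dataset.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(2, 16, 16, 16)


def worker_init_fn(worker_id: int) -> None:
    # Inside a worker, torch.initial_seed() already includes the per-worker
    # offset; reducing it mod 2**32 yields a valid, distinct NumPy seed, so
    # the NumPy-based transforms patched above are reproducible per worker.
    np.random.seed(torch.initial_seed() % 2 ** 32)


loader = DataLoader(DummyDataset(), batch_size=2, num_workers=4,
                    worker_init_fn=worker_init_fn)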