From 14439ab793c98abdae430237021cb3cc4d6cb050 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 10:08:41 +0530 Subject: [PATCH 01/13] refactor unet2d condition model tests. --- .../unets/test_models_unet_2d_condition.py | 962 +++++++++--------- 1 file changed, 504 insertions(+), 458 deletions(-) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 4dbb8ca7c075..3557977ebf5e 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -52,17 +52,24 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ( +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + IPAdapterTesterMixin, LoraHotSwappingForModelTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, ModelTesterMixin, TorchCompileTesterMixin, - UNetTesterMixin, + TrainingTesterMixin, ) if is_peft_available(): from peft import LoraConfig - from peft.tuners.tuners_utils import BaseTunerLayer + + from ..testing_utils.lora import check_if_lora_correctly_set logger = logging.get_logger(__name__) @@ -82,16 +89,6 @@ def get_unet_lora_config(): return unet_lora_config -def check_if_lora_correctly_set(model) -> bool: - """ - Checks if the LoRA layers are correctly set with peft - """ - for module in model.modules(): - if isinstance(module, BaseTunerLayer): - return True - return False - - def create_ip_adapter_state_dict(model): # "ip_adapter" (cross-attention weights) ip_cross_attn_state_dict = {} @@ -354,34 +351,24 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): return custom_diffusion_attn_procs -class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DConditionModel - main_input_name = "sample" - # We override the items here because the unet under consideration is small. 
- model_split_percents = [0.5, 0.34, 0.4] +class UNet2DConditionTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet2DConditionModel testing.""" @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (16, 16) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + def model_class(self): + return UNet2DConditionModel @property - def input_shape(self): + def output_shape(self) -> tuple[int, int, int]: return (4, 16, 16) @property - def output_shape(self): - return (4, 16, 16) + def model_split_percents(self) -> list[float]: + return [0.5, 0.34, 0.4] - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + """Return UNet2D model initialization arguments.""" + return { "block_out_channels": (4, 8), "norm_num_groups": 4, "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), @@ -393,24 +380,21 @@ def prepare_init_args_and_inputs_for_common(self): "layers_per_block": 1, "sample_size": 16, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_enable_works(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + """Return dummy inputs for UNet2D model.""" + batch_size = 4 + num_channels = 4 + sizes = (16, 16) - model.enable_xformers_memory_efficient_attention() + return { + "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device), + "timestep": torch.tensor([10]).to(torch_device), + "encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device), + } - assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ - == "XFormersAttnProcessor" - ), "xformers is not enabled" +class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin): def test_model_with_attention_head_dim_tuple(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -518,163 +502,6 @@ def test_model_with_class_embeddings_concat(self): expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - def test_model_attention_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["block_out_channels"] = (16, 32) - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - model.set_attention_slice("auto") - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - - model.set_attention_slice("max") - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - - model.set_attention_slice(2) - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - - def test_model_sliceable_head_dim(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["block_out_channels"] = (16, 32) - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - - 
def check_sliceable_dim_attr(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - assert isinstance(module.sliceable_head_dim, int) - - for child in module.children(): - check_sliceable_dim_attr(child) - - # retrieve number of attention layers - for module in model.children(): - check_sliceable_dim_attr(module) - - def test_gradient_checkpointing_is_applied(self): - expected_set = { - "CrossAttnUpBlock2D", - "CrossAttnDownBlock2D", - "UNetMidBlock2DCrossAttn", - "UpBlock2D", - "Transformer2DModel", - "DownBlock2D", - } - attention_head_dim = (8, 16) - block_out_channels = (16, 32) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels - ) - - def test_special_attn_proc(self): - class AttnEasyProc(torch.nn.Module): - def __init__(self, num): - super().__init__() - self.weight = torch.nn.Parameter(torch.tensor(num)) - self.is_run = False - self.number = 0 - self.counter = 0 - - def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - query = attn.to_q(hidden_states) - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - hidden_states += self.weight - - self.is_run = True - self.counter += 1 - self.number = number - - return hidden_states - - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["block_out_channels"] = (16, 32) - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - - processor = AttnEasyProc(5.0) - - model.set_attn_processor(processor) - model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample - - assert processor.counter == 8 - assert processor.is_run - assert processor.number == 123 - - @parameterized.expand( - [ - # fmt: off - [torch.bool], - [torch.long], - [torch.float], - # fmt: on - ] - ) - def test_model_xattn_mask(self, mask_dtype): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)}) - model.to(torch_device) - model.eval() - - cond = inputs_dict["encoder_hidden_states"] - with torch.no_grad(): - full_cond_out = model(**inputs_dict).sample - assert full_cond_out is not None - - keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype) - full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample - assert full_cond_keepallmask_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), ( - "a 'keep all' mask should give the same result as no mask" - ) - - trunc_cond = cond[:, :-1, :] - trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": 
trunc_cond}).sample - assert not trunc_cond_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), ( - "discarding the last token from our cond should change the result" - ) - - batch, tokens, _ = cond.shape - mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype) - masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample - assert masked_cond_out.allclose(trunc_cond_out, rtol=1e-05, atol=1e-05), ( - "masking the last token from our cond should be equivalent to truncating that token out of the condition" - ) - # see diffusers.models.attention_processor::Attention#prepare_attention_mask # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks. # since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric. @@ -683,7 +510,8 @@ def test_model_xattn_mask(self, mask_dtype): reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length." ) def test_model_xattn_padding(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)}) model.to(torch_device) @@ -705,9 +533,9 @@ def test_model_xattn_padding(self): "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." ) - def test_custom_diffusion_processors(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + def test_pickle(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -716,123 +544,512 @@ def test_custom_diffusion_processors(self): model.to(torch_device) with torch.no_grad(): - sample1 = model(**inputs_dict).sample - - custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) - - # make sure we can set a list of attention processors - model.set_attn_processor(custom_diffusion_attn_procs) - model.to(torch_device) - - # test that attn processors can be set to itself - model.set_attn_processor(model.attn_processors) - - with torch.no_grad(): - sample2 = model(**inputs_dict).sample + sample = model(**inputs_dict).sample - assert (sample1 - sample2).abs().max() < 3e-3 + sample_copy = copy.copy(sample) - def test_custom_diffusion_save_load(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + assert (sample - sample_copy).abs().max() < 1e-4 - init_dict["block_out_channels"] = (16, 32) - init_dict["attention_head_dim"] = (8, 16) + def test_asymmetrical_unet(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + # Add asymmetry to configs + init_dict["transformer_layers_per_block"] = [[3, 2], 1] + init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1] torch.manual_seed(0) model = self.model_class(**init_dict) model.to(torch_device) - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) - model.set_attn_processor(custom_diffusion_attn_procs) - - with torch.no_grad(): - sample = model(**inputs_dict).sample - 
- with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin") - new_model.to(torch_device) + output = model(**inputs_dict).sample + expected_shape = inputs_dict["sample"].shape - with torch.no_grad(): - new_sample = new_model(**inputs_dict).sample + # Check if input and output shapes are the same + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - assert (sample - new_sample).abs().max() < 1e-4 - # custom diffusion and no custom diffusion should be the same - assert (sample - old_sample).abs().max() < 3e-3 +class TestUNet2DConditionHubLoading(UNet2DConditionTesterConfig): + """Hub checkpoint loading tests for UNet2DConditionModel.""" - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), + ] ) - def test_custom_diffusion_xformers_on_off(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + @require_torch_accelerator + def test_load_sharded_checkpoint_from_hub(self, repo_id, variant): + inputs_dict = self.get_dummy_inputs() + loaded_model = self.model_class.from_pretrained(repo_id, variant=variant) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) - init_dict["block_out_channels"] = (16, 32) - init_dict["attention_head_dim"] = (8, 16) + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), + ] + ) + @require_torch_accelerator + def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant): + inputs_dict = self.get_dummy_inputs() + loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + def test_load_sharded_checkpoint_from_hub_local(self): + inputs_dict = self.get_dummy_inputs() + ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") + loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + def test_load_sharded_checkpoint_from_hub_local_subfolder(self): + inputs_dict = self.get_dummy_inputs() + ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") + loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + @parameterized.expand( 
+ [ + ("hf-internal-testing/unet2d-sharded-dummy", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), + ] + ) + def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant): + inputs_dict = self.get_dummy_inputs() + loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto") + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), + ] + ) + def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant): + inputs_dict = self.get_dummy_inputs() + loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto") + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + def test_load_sharded_checkpoint_device_map_from_hub_local(self): + inputs_dict = self.get_dummy_inputs() + ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") + loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto") + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self): + inputs_dict = self.get_dummy_inputs() + ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") + loaded_model = self.model_class.from_pretrained( + ckpt_path, local_files_only=True, subfolder="unet", device_map="auto" + ) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + +class TestUNet2DConditionLoRA(UNet2DConditionTesterConfig, LoraTesterMixin): + """LoRA adapter tests for UNet2DConditionModel.""" + + @require_peft_backend + def test_load_attn_procs_raise_warning(self): + """Test that deprecated load_attn_procs method raises FutureWarning.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + model.to(torch_device) + + # forward pass without LoRA + with torch.no_grad(): + non_lora_sample = model(**inputs_dict).sample + + unet_lora_config = get_unet_lora_config() + model.add_adapter(unet_lora_config) + + assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." + + # forward pass with LoRA + with torch.no_grad(): + lora_sample_1 = model(**inputs_dict).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + model.unload_lora() + + with self.assertWarns(FutureWarning) as warning: + model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + warning_message = str(warning.warnings[0].message) + assert "Using the `load_attn_procs()` method has been deprecated" in warning_message + + # import to still check for the rest of the stuff. + assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." + + with torch.no_grad(): + lora_sample_2 = model(**inputs_dict).sample + + assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), ( + "LoRA injected UNet should produce different results." 
+ ) + assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), ( + "Loading from a saved checkpoint should produce identical results." + ) + + @require_peft_backend + def test_save_attn_procs_raise_warning(self): + """Test that deprecated save_attn_procs method raises FutureWarning.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.to(torch_device) + + unet_lora_config = get_unet_lora_config() + model.add_adapter(unet_lora_config) + + assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." + + with tempfile.TemporaryDirectory() as tmpdirname: + with self.assertWarns(FutureWarning) as warning: + model.save_attn_procs(tmpdirname) + + warning_message = str(warning.warnings[0].message) + assert "Using the `save_attn_procs()` method has been deprecated" in warning_message + + +class TestUNet2DConditionMemory(UNet2DConditionTesterConfig, MemoryTesterMixin): + """Memory optimization tests for UNet2DConditionModel.""" + + +class TestUNet2DConditionTraining(UNet2DConditionTesterConfig, TrainingTesterMixin): + """Training tests for UNet2DConditionModel.""" + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "CrossAttnUpBlock2D", + "CrossAttnDownBlock2D", + "UNetMidBlock2DCrossAttn", + "UpBlock2D", + "Transformer2DModel", + "DownBlock2D", + } + attention_head_dim = (8, 16) + block_out_channels = (16, 32) + super().test_gradient_checkpointing_is_applied( + expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels + ) + + +class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin): + """Attention processor tests for UNet2DConditionModel.""" + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_enable_works(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + + model.enable_xformers_memory_efficient_attention() + + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersAttnProcessor" + ), "xformers is not enabled" + + def test_model_attention_slicing(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + init_dict["block_out_channels"] = (16, 32) + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + model.set_attention_slice("auto") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice("max") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice(2) + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + def test_model_sliceable_head_dim(self): + init_dict = self.get_init_dict() + + init_dict["block_out_channels"] = (16, 32) + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + + def check_sliceable_dim_attr(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + assert isinstance(module.sliceable_head_dim, int) + + for child in module.children(): + check_sliceable_dim_attr(child) + + # retrieve number of attention layers + for module in model.children(): + check_sliceable_dim_attr(module) + + def test_special_attn_proc(self): + class AttnEasyProc(torch.nn.Module): + def __init__(self, num): + super().__init__() + 
self.weight = torch.nn.Parameter(torch.tensor(num)) + self.is_run = False + self.number = 0 + self.counter = 0 + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states += self.weight + + self.is_run = True + self.counter += 1 + self.number = number + + return hidden_states + + # enable deterministic behavior for gradient checkpointing + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + init_dict["block_out_channels"] = (16, 32) + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + processor = AttnEasyProc(5.0) + + model.set_attn_processor(processor) + model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample + + assert processor.counter == 8 + assert processor.is_run + assert processor.number == 123 + + @parameterized.expand( + [ + # fmt: off + [torch.bool], + [torch.long], + [torch.float], + # fmt: on + ] + ) + def test_model_xattn_mask(self, mask_dtype): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)}) + model.to(torch_device) + model.eval() + + cond = inputs_dict["encoder_hidden_states"] + with torch.no_grad(): + full_cond_out = model(**inputs_dict).sample + assert full_cond_out is not None + + keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype) + full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample + assert full_cond_keepallmask_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), ( + "a 'keep all' mask should give the same result as no mask" + ) + + trunc_cond = cond[:, :-1, :] + trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample + assert not trunc_cond_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), ( + "discarding the last token from our cond should change the result" + ) + + batch, tokens, _ = cond.shape + mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype) + masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample + assert masked_cond_out.allclose(trunc_cond_out, rtol=1e-05, atol=1e-05), ( + "masking the last token from our cond should be equivalent to truncating that token out of the condition" + ) + + +class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig): + """Custom Diffusion processor tests for UNet2DConditionModel.""" + + def test_custom_diffusion_processors(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + init_dict["block_out_channels"] = (16, 32) + 
init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) + + # make sure we can set a list of attention processors + model.set_attn_processor(custom_diffusion_attn_procs) + model.to(torch_device) + + # test that attn processors can be set to itself + model.set_attn_processor(model.attn_processors) + + with torch.no_grad(): + sample2 = model(**inputs_dict).sample + + assert (sample1 - sample2).abs().max() < 3e-3 + + def test_custom_diffusion_save_load(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + init_dict["block_out_channels"] = (16, 32) + init_dict["attention_head_dim"] = (8, 16) torch.manual_seed(0) model = self.model_class(**init_dict) model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) model.set_attn_processor(custom_diffusion_attn_procs) - # default with torch.no_grad(): sample = model(**inputs_dict).sample - model.enable_xformers_memory_efficient_attention() - on_sample = model(**inputs_dict).sample + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname, safe_serialization=False) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin") + new_model.to(torch_device) - model.disable_xformers_memory_efficient_attention() - off_sample = model(**inputs_dict).sample + with torch.no_grad(): + new_sample = new_model(**inputs_dict).sample - assert (sample - on_sample).abs().max() < 1e-4 - assert (sample - off_sample).abs().max() < 1e-4 + assert (sample - new_sample).abs().max() < 1e-4 - def test_pickle(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + # custom diffusion and no custom diffusion should be the same + assert (sample - old_sample).abs().max() < 3e-3 + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_custom_diffusion_xformers_on_off(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) + torch.manual_seed(0) model = self.model_class(**init_dict) model.to(torch_device) + custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) + model.set_attn_processor(custom_diffusion_attn_procs) + # default with torch.no_grad(): sample = model(**inputs_dict).sample - sample_copy = copy.copy(sample) + model.enable_xformers_memory_efficient_attention() + on_sample = model(**inputs_dict).sample - assert (sample - sample_copy).abs().max() < 1e-4 + model.disable_xformers_memory_efficient_attention() + off_sample = model(**inputs_dict).sample - def test_asymmetrical_unet(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - # Add asymmetry to configs - init_dict["transformer_layers_per_block"] = [[3, 2], 1] - init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1] + assert (sample - on_sample).abs().max() < 
1e-4 + assert (sample - off_sample).abs().max() < 1e-4 - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - output = model(**inputs_dict).sample - expected_shape = inputs_dict["sample"].shape +class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterMixin): + """IP Adapter tests for UNet2DConditionModel.""" - # Check if input and output shapes are the same - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + @property + def ip_adapter_processor_cls(self): + return IPAdapterAttnProcessor + + def create_ip_adapter_state_dict(self, model): + return create_ip_adapter_state_dict(model) + + def modify_inputs_for_ip_adapter(self, model, inputs_dict): + batch_size = inputs_dict["encoder_hidden_states"].shape[0] + # for ip-adapter image_embeds has shape [batch_size, num_image, embed_dim] + cross_attention_dim = getattr(model.config, "cross_attention_dim", 8) + image_embeds = floats_tensor((batch_size, 1, cross_attention_dim)).to(torch_device) + inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]} + return inputs_dict def test_ip_adapter(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -905,7 +1122,8 @@ def test_ip_adapter(self): assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) def test_ip_adapter_plus(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -977,185 +1195,13 @@ def test_ip_adapter_plus(self): assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4) assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) - @parameterized.expand( - [ - ("hf-internal-testing/unet2d-sharded-dummy", None), - ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), - ] - ) - @require_torch_accelerator - def test_load_sharded_checkpoint_from_hub(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained(repo_id, variant=variant) - loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @parameterized.expand( - [ - ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), - ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), - ] - ) - @require_torch_accelerator - def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant) - loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - def test_load_sharded_checkpoint_from_hub_local(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") - loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True) - loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert 
new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - def test_load_sharded_checkpoint_from_hub_local_subfolder(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") - loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True) - loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - @parameterized.expand( - [ - ("hf-internal-testing/unet2d-sharded-dummy", None), - ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), - ] - ) - def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto") - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - @parameterized.expand( - [ - ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), - ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), - ] - ) - def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto") - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - def test_load_sharded_checkpoint_device_map_from_hub_local(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") - loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto") - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") - loaded_model = self.model_class.from_pretrained( - ckpt_path, local_files_only=True, subfolder="unet", device_map="auto" - ) - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_peft_backend - def test_load_attn_procs_raise_warning(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - # forward pass without LoRA - with torch.no_grad(): - non_lora_sample = model(**inputs_dict).sample - - unet_lora_config = get_unet_lora_config() - model.add_adapter(unet_lora_config) - - assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." 
- - # forward pass with LoRA - with torch.no_grad(): - lora_sample_1 = model(**inputs_dict).sample - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) - model.unload_lora() - - with self.assertWarns(FutureWarning) as warning: - model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - warning_message = str(warning.warnings[0].message) - assert "Using the `load_attn_procs()` method has been deprecated" in warning_message - - # import to still check for the rest of the stuff. - assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." - - with torch.no_grad(): - lora_sample_2 = model(**inputs_dict).sample - - assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), ( - "LoRA injected UNet should produce different results." - ) - assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), ( - "Loading from a saved checkpoint should produce identical results." - ) - - @require_peft_backend - def test_save_attn_procs_raise_warning(self): - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - unet_lora_config = get_unet_lora_config() - model.add_adapter(unet_lora_config) - - assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." - - with tempfile.TemporaryDirectory() as tmpdirname: - with self.assertWarns(FutureWarning) as warning: - model.save_attn_procs(tmpdirname) - - warning_message = str(warning.warnings[0].message) - assert "Using the `save_attn_procs()` method has been deprecated" in warning_message - - -class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = UNet2DConditionModel - - def prepare_init_args_and_inputs_for_common(self): - return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common() +class UNet2DConditionModelCompileTests(UNet2DConditionTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for UNet2DConditionModel.""" -class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): - model_class = UNet2DConditionModel - def prepare_init_args_and_inputs_for_common(self): - return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common() +class UNet2DConditionModelLoRAHotSwapTests(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin): + """LoRA hot-swapping tests for UNet2DConditionModel.""" @slow From 0e42a3ff9348dd5b4604c6d99a8fca2d37f4b365 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 11:59:33 +0530 Subject: [PATCH 02/13] fix tests --- tests/models/test_modeling_common.py | 7 ++- tests/models/testing_utils/common.py | 38 ++++++++----- .../unets/test_models_unet_2d_condition.py | 56 ++++++++++--------- 3 files changed, 58 insertions(+), 43 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 1b1a51d1e26f..cd0040bddc34 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -465,7 +465,8 @@ def _accepts_norm_num_groups(model_class): def test_forward_with_norm_groups(self): if not self._accepts_norm_num_groups(self.model_class): pytest.skip(f"Test not supported for {self.model_class.__name__}") - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["norm_num_groups"] = 16 init_dict["block_out_channels"] = (16, 32) @@ 
-480,9 +481,9 @@ def test_forward_with_norm_groups(self): if isinstance(output, dict): output = output.to_tuple()[0] - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" class ModelTesterMixin: diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 7036bb16203d..6842b9ee30d2 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -287,8 +287,9 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5): f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}" ) - image = model(**self.get_dummy_inputs(), return_dict=False)[0] - new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] + inputs_dict = self.get_dummy_inputs() + image = model(**inputs_dict, return_dict=False)[0] + new_image = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") @@ -308,8 +309,9 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): new_model.to(torch_device) - image = model(**self.get_dummy_inputs(), return_dict=False)[0] - new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] + inputs_dict = self.get_dummy_inputs() + image = model(**inputs_dict, return_dict=False)[0] + new_image = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") @@ -337,8 +339,9 @@ def test_determinism(self, atol=1e-5, rtol=0): model.to(torch_device) model.eval() - first = model(**self.get_dummy_inputs(), return_dict=False)[0] - second = model(**self.get_dummy_inputs(), return_dict=False)[0] + inputs_dict = self.get_dummy_inputs() + first = model(**inputs_dict, return_dict=False)[0] + second = model(**inputs_dict, return_dict=False)[0] first_flat = first.flatten() second_flat = second.flatten() @@ -395,8 +398,9 @@ def recursive_check(tuple_object, dict_object): model.to(torch_device) model.eval() - outputs_dict = model(**self.get_dummy_inputs()) - outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False) + inputs_dict = self.get_dummy_inputs() + outputs_dict = model(**inputs_dict) + outputs_tuple = model(**inputs_dict, return_dict=False) recursive_check(outputs_tuple, outputs_dict) @@ -523,8 +527,10 @@ def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0): new_model = new_model.to(torch_device) torch.manual_seed(0) - inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new, return_dict=False)[0] + # Re-create inputs only if they contain a generator (which needs to be reset) + if "generator" in inputs_dict: + inputs_dict = self.get_dummy_inputs() + new_output = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close( base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load" @@ -563,8 +569,10 @@ def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0): new_model = new_model.to(torch_device) torch.manual_seed(0) - inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new, return_dict=False)[0] + # Re-create inputs only if they contain a generator (which needs to be reset) + if "generator" in inputs_dict: + 
inputs_dict = self.get_dummy_inputs() + new_output = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close( base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load" @@ -614,8 +622,10 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt model_parallel = model_parallel.to(torch_device) torch.manual_seed(0) - inputs_dict_parallel = self.get_dummy_inputs() - output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0] + # Re-create inputs only if they contain a generator (which needs to be reset) + if "generator" in inputs_dict: + inputs_dict = self.get_dummy_inputs() + output_parallel = model_parallel(**inputs_dict, return_dict=False)[0] assert_tensors_close( base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading" diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 3557977ebf5e..7ffef80542cc 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -20,6 +20,7 @@ import unittest from collections import OrderedDict +import pytest import torch from huggingface_hub import snapshot_download from parameterized import parameterized @@ -366,6 +367,10 @@ def output_shape(self) -> tuple[int, int, int]: def model_split_percents(self) -> list[float]: return [0.5, 0.34, 0.4] + @property + def main_input_name(self) -> str: + return "sample" + def get_init_dict(self) -> dict: """Return UNet2D model initialization arguments.""" return { @@ -396,7 +401,8 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin): def test_model_with_attention_head_dim_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -411,12 +417,13 @@ def test_model_with_attention_head_dim_tuple(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_model_with_use_linear_projection(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["use_linear_projection"] = True @@ -430,12 +437,13 @@ def test_model_with_use_linear_projection(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_model_with_cross_attention_dim_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["cross_attention_dim"] = (8, 8) @@ -449,12 +457,13 @@ def test_model_with_cross_attention_dim_tuple(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None 
expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_model_with_simple_projection(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() batch_size, _, _, sample_size = inputs_dict["sample"].shape @@ -473,12 +482,13 @@ def test_model_with_simple_projection(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_model_with_class_embeddings_concat(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() batch_size, _, _, sample_size = inputs_dict["sample"].shape @@ -498,13 +508,13 @@ def test_model_with_class_embeddings_concat(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" # see diffusers.models.attention_processor::Attention#prepare_attention_mask # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks. - # since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric. + # since the use-case (somebody passes in a too-short cross-attn mask) is pretty small, # maybe it's fine that this only works for the unclip use-case. @mark.skip( reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length." @@ -565,7 +575,7 @@ def test_asymmetrical_unet(self): expected_shape = inputs_dict["sample"].shape # Check if input and output shapes are the same - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" class TestUNet2DConditionHubLoading(UNet2DConditionTesterConfig): @@ -706,12 +716,9 @@ def test_load_attn_procs_raise_warning(self): model.save_attn_procs(tmpdirname) model.unload_lora() - with self.assertWarns(FutureWarning) as warning: + with pytest.warns(FutureWarning, match="Using the `load_attn_procs\\(\\)` method has been deprecated"): model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - warning_message = str(warning.warnings[0].message) - assert "Using the `load_attn_procs()` method has been deprecated" in warning_message - # import to still check for the rest of the stuff. assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." @@ -738,11 +745,8 @@ def test_save_attn_procs_raise_warning(self): assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." 
with tempfile.TemporaryDirectory() as tmpdirname: - with self.assertWarns(FutureWarning) as warning: - model.save_attn_procs(tmpdirname) - - warning_message = str(warning.warnings[0].message) - assert "Using the `save_attn_procs()` method has been deprecated" in warning_message + with pytest.warns(FutureWarning, match="Using the `save_attn_procs\\(\\)` method has been deprecated"): + model.save_attn_procs(os.path.join(tmpdirname)) class TestUNet2DConditionMemory(UNet2DConditionTesterConfig, MemoryTesterMixin): From 2b67fb65ef5579cfed91c5679f3ffbf451a5d370 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 13:10:04 +0530 Subject: [PATCH 03/13] up --- tests/models/unets/test_models_unet_2d_condition.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 7ffef80542cc..0d1d9e65ce32 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -765,11 +765,7 @@ def test_gradient_checkpointing_is_applied(self): "Transformer2DModel", "DownBlock2D", } - attention_head_dim = (8, 16) - block_out_channels = (16, 32) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels - ) + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin): @@ -988,7 +984,7 @@ def test_custom_diffusion_save_load(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")) torch.manual_seed(0) new_model = self.model_class(**init_dict) new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin") @@ -1038,7 +1034,7 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM @property def ip_adapter_processor_cls(self): - return IPAdapterAttnProcessor + return (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0) def create_ip_adapter_state_dict(self, model): return create_ip_adapter_state_dict(model) From 46d44b73d8d703070912896ee47ff1b60f385305 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 13:30:54 +0530 Subject: [PATCH 04/13] fix --- .../kandinsky2_2/test_kandinsky_combined.py | 12 ++++++++++++ .../pipelines/kandinsky2_2/test_kandinsky_inpaint.py | 3 +++ .../pipelines/kandinsky3/test_kandinsky3_img2img.py | 3 +++ 3 files changed, 18 insertions(+) diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py index 62f5853da9a5..fbb82ccd6d61 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py @@ -155,12 +155,18 @@ def test_save_load_local(self): def test_save_load_optional_components(self): super().test_save_load_optional_components(expected_max_difference=5e-3) + @unittest.skip("TODO: revisit") def test_callback_inputs(self): pass + @unittest.skip("TODO: revisit") def test_callback_cfg(self): pass + @unittest.skip("TODO: revisit") + def test_pipeline_with_accelerator_device_map(self): + pass + class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, 
unittest.TestCase): pipeline_class = KandinskyV22Img2ImgCombinedPipeline @@ -406,8 +412,14 @@ def test_save_load_optional_components(self): def test_sequential_cpu_offload_forward_pass(self): super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4) + @unittest.skip("TODO: revisit") def test_callback_inputs(self): pass + @unittest.skip("TODO: revisit") def test_callback_cfg(self): pass + + @unittest.skip("TODO: revisit") + def test_pipeline_with_accelerator_device_map(self): + pass diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py index df1dd2d9872c..1df1688e082b 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py @@ -296,6 +296,9 @@ def callback_inputs_test(pipe, i, t, callback_kwargs): output = pipe(**inputs)[0] assert output.abs().sum() == 0 + def test_pipeline_with_accelerator_device_map(self): + return super().test_pipeline_with_accelerator_device_map(expected_max_difference=5e-3) + @slow @require_torch_accelerator diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py index d3bfa4b3082c..c789165e0741 100644 --- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py +++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py @@ -194,6 +194,9 @@ def test_inference_batch_single_identical(self): def test_save_load_dduf(self): super().test_save_load_dduf(atol=1e-3, rtol=1e-3) + def test_pipeline_with_accelerator_device_map(self): + return super().test_pipeline_with_accelerator_device_map(expected_max_difference=5e-3) + @slow @require_torch_accelerator From 3371560f1d6c84eb2fd04ddeb664203a836aa31f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 13:34:24 +0530 Subject: [PATCH 05/13] Revert "fix" This reverts commit 46d44b73d8d703070912896ee47ff1b60f385305. 
--- .../kandinsky2_2/test_kandinsky_combined.py | 12 ------------ .../pipelines/kandinsky2_2/test_kandinsky_inpaint.py | 3 --- .../pipelines/kandinsky3/test_kandinsky3_img2img.py | 3 --- 3 files changed, 18 deletions(-) diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py index fbb82ccd6d61..62f5853da9a5 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py @@ -155,18 +155,12 @@ def test_save_load_local(self): def test_save_load_optional_components(self): super().test_save_load_optional_components(expected_max_difference=5e-3) - @unittest.skip("TODO: revisit") def test_callback_inputs(self): pass - @unittest.skip("TODO: revisit") def test_callback_cfg(self): pass - @unittest.skip("TODO: revisit") - def test_pipeline_with_accelerator_device_map(self): - pass - class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = KandinskyV22Img2ImgCombinedPipeline @@ -412,14 +406,8 @@ def test_save_load_optional_components(self): def test_sequential_cpu_offload_forward_pass(self): super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4) - @unittest.skip("TODO: revisit") def test_callback_inputs(self): pass - @unittest.skip("TODO: revisit") def test_callback_cfg(self): pass - - @unittest.skip("TODO: revisit") - def test_pipeline_with_accelerator_device_map(self): - pass diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py index 1df1688e082b..df1dd2d9872c 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py @@ -296,9 +296,6 @@ def callback_inputs_test(pipe, i, t, callback_kwargs): output = pipe(**inputs)[0] assert output.abs().sum() == 0 - def test_pipeline_with_accelerator_device_map(self): - return super().test_pipeline_with_accelerator_device_map(expected_max_difference=5e-3) - @slow @require_torch_accelerator diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py index c789165e0741..d3bfa4b3082c 100644 --- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py +++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py @@ -194,9 +194,6 @@ def test_inference_batch_single_identical(self): def test_save_load_dduf(self): super().test_save_load_dduf(atol=1e-3, rtol=1e-3) - def test_pipeline_with_accelerator_device_map(self): - return super().test_pipeline_with_accelerator_device_map(expected_max_difference=5e-3) - @slow @require_torch_accelerator From ca4a7b064993e4b085ea8aa6ccf076c45bb2e52b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 15:40:24 +0530 Subject: [PATCH 06/13] up --- tests/models/testing_utils/lora.py | 30 ++++++++++--------- .../unets/test_models_unet_2d_condition.py | 4 +-- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py index 994aaed55ca7..b1fa123ffaa2 100644 --- a/tests/models/testing_utils/lora.py +++ b/tests/models/testing_utils/lora.py @@ -15,6 +15,7 @@ import gc import json +import logging import os import re @@ -23,10 +24,12 @@ import torch import torch.nn as nn +from diffusers.utils import logging as diffusers_logging from diffusers.utils.import_utils import is_peft_available from diffusers.utils.testing_utils import check_if_dicts_are_equal from 
...testing_utils import ( + CaptureLogger, assert_tensors_close, backend_empty_cache, is_lora, @@ -477,10 +480,7 @@ def test_enable_lora_hotswap_called_after_adapter_added_raises(self): with pytest.raises(RuntimeError, match=msg): model.enable_lora_hotswap(target_rank=32) - def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog): - # ensure that enable_lora_hotswap is called before loading the first adapter - import logging - + def test_enable_lora_hotswap_called_after_adapter_added_warning(self): lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) init_dict = self.get_init_dict() model = self.model_class(**init_dict).to(torch_device) @@ -488,21 +488,26 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog): msg = ( "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." ) - with caplog.at_level(logging.WARNING): + + logger = diffusers_logging.get_logger("diffusers.loaders.peft") + logger.setLevel(logging.WARNING) + with CaptureLogger(logger) as cap_logger: model.enable_lora_hotswap(target_rank=32, check_compiled="warn") - assert any(msg in record.message for record in caplog.records) - def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog): - # check possibility to ignore the error/warning - import logging + assert msg in str(cap_logger.out), f"Expected warning not found. Captured: {cap_logger.out}" + def test_enable_lora_hotswap_called_after_adapter_added_ignore(self): lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) init_dict = self.get_init_dict() model = self.model_class(**init_dict).to(torch_device) model.add_adapter(lora_config) - with caplog.at_level(logging.WARNING): + + logger = diffusers_logging.get_logger("diffusers.loaders.peft") + logger.setLevel(logging.WARNING) + with CaptureLogger(logger) as cap_logger: model.enable_lora_hotswap(target_rank=32, check_compiled="ignore") - assert len(caplog.records) == 0 + + assert cap_logger.out == "", f"Expected no warnings but found: {cap_logger.out}" def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): # check that wrong argument value raises an error @@ -515,9 +520,6 @@ def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog): - # check the error and log - import logging - # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers target_modules0 = ["to_q"] target_modules1 = ["to_q", "to_k"] diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 0d1d9e65ce32..5fdada7793fc 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1196,11 +1196,11 @@ def test_ip_adapter_plus(self): assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) -class UNet2DConditionModelCompileTests(UNet2DConditionTesterConfig, TorchCompileTesterMixin): +class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin): """Torch compile tests for UNet2DConditionModel.""" -class UNet2DConditionModelLoRAHotSwapTests(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin): +class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin): """LoRA hot-swapping tests for 
UNet2DConditionModel.""" From ea08148bbd195878a6b44ce1142de8d134ec4e2a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 15:41:57 +0530 Subject: [PATCH 07/13] recompile limit --- tests/models/testing_utils/compile.py | 3 --- tests/models/unets/test_models_unet_2d_condition.py | 3 +++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/testing_utils/compile.py b/tests/models/testing_utils/compile.py index 998b88fb469e..4787d0742b18 100644 --- a/tests/models/testing_utils/compile.py +++ b/tests/models/testing_utils/compile.py @@ -92,9 +92,6 @@ def test_torch_compile_repeated_blocks(self, recompile_limit=1): model.eval() model.compile_repeated_blocks(fullgraph=True) - if self.model_class.__name__ == "UNet2DConditionModel": - recompile_limit = 2 - with ( torch._inductor.utils.fresh_inductor_cache(), torch._dynamo.config.patch(recompile_limit=recompile_limit), diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 5fdada7793fc..a7293208d370 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1199,6 +1199,9 @@ def test_ip_adapter_plus(self): class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin): """Torch compile tests for UNet2DConditionModel.""" + def test_torch_compile_repeated_blocks(self): + return super().test_torch_compile_repeated_blocks(recompile_limit=2) + class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin): """LoRA hot-swapping tests for UNet2DConditionModel.""" From ffb254a273c3afd12caa5aa8512b14b27d72f23b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 16:05:32 +0530 Subject: [PATCH 08/13] [tests] refactor test_models_unet_1d.py to use modular testing mixins Refactor UNet1D model tests to follow the modern testing pattern using BaseModelTesterConfig and focused mixin classes (ModelTesterMixin, MemoryTesterMixin, TrainingTesterMixin, LoraTesterMixin). Both UNet1D standard and RL variants now have separate config classes and dedicated test classes organized by concern (core, memory, training, LoRA, hub loading). Co-Authored-By: Claude Opus 4.6 --- tests/models/unets/test_models_unet_1d.py | 298 +++++++++------------- 1 file changed, 123 insertions(+), 175 deletions(-) diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py index bac017e7e7d3..a52de51cf097 100644 --- a/tests/models/unets/test_models_unet_1d.py +++ b/tests/models/unets/test_models_unet_1d.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import pytest import torch @@ -26,64 +24,41 @@ slow, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin - +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + BaseModelTesterConfig, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TrainingTesterMixin, +) -class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet1DModel - main_input_name = "sample" - @property - def dummy_input(self): - batch_size = 4 - num_features = 14 - seq_len = 16 +_LAYERWISE_CASTING_XFAIL_REASON = ( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. 
The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." +) - noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) - time_step = torch.tensor([10] * batch_size).to(torch_device) - return {"sample": noise, "timestep": time_step} +class UNet1DTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet1DModel testing (standard variant).""" @property - def input_shape(self): - return (4, 14, 16) + def model_class(self): + return UNet1DModel @property def output_shape(self): - return (4, 14, 16) - - @unittest.skip("Test not supported.") - def test_ema_training(self): - pass - - @unittest.skip("Test not supported.") - def test_training(self): - pass - - @unittest.skip("Test not supported.") - def test_layerwise_casting_training(self): - pass - - def test_determinism(self): - super().test_determinism() - - def test_outputs_equivalence(self): - super().test_outputs_equivalence() - - def test_from_save_pretrained(self): - super().test_from_save_pretrained() + return (14, 16) - def test_from_save_pretrained_variant(self): - super().test_from_save_pretrained_variant() - - def test_model_from_pretrained(self): - super().test_model_from_pretrained() - - def test_output(self): - super().test_output() + @property + def main_input_name(self): + return "sample" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "block_out_channels": (8, 8, 16, 16), "in_channels": 14, "out_channels": 14, @@ -97,18 +72,48 @@ def prepare_init_args_and_inputs_for_common(self): "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"), "act_fn": "swish", } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + return { + "sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device), + "timestep": torch.tensor([10] * batch_size).to(torch_device), + } + + +class TestUNet1D(UNet1DTesterConfig, ModelTesterMixin, UNetTesterMixin): + @pytest.mark.skip("Not implemented yet for this UNet") + def test_forward_with_norm_groups(self): + pass + + +class TestUNet1DMemory(UNet1DTesterConfig, MemoryTesterMixin): + @pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON) + def test_layerwise_casting_memory(self): + super().test_layerwise_casting_memory() + + +class TestUNet1DTraining(UNet1DTesterConfig, TrainingTesterMixin): + pass + + +class TestUNet1DLoRA(UNet1DTesterConfig, LoraTesterMixin): + pass + + +class TestUNet1DHubLoading(UNet1DTesterConfig): def test_from_pretrained_hub(self): model, loading_info = UNet1DModel.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" ) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) + assert model is not None + assert len(loading_info["missing_keys"]) == 0 model.to(torch_device) - image = model(**self.dummy_input) + image = model(**self.get_dummy_inputs()) assert image is not None, "Make sure output is not None" @@ -131,12 +136,7 @@ def test_output_pretrained(self): # fmt: off expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348]) # fmt: on - 
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) - - @unittest.skip("Test not supported.") - def test_forward_with_norm_groups(self): - # Not implemented yet for this UNet - pass + assert torch.allclose(output_slice, expected_output_slice, rtol=1e-3) @slow def test_unet_1d_maestro(self): @@ -157,98 +157,29 @@ def test_unet_1d_maestro(self): assert (output_sum - 224.0896).abs() < 0.5 assert (output_max - 0.0607).abs() < 4e-4 - @pytest.mark.xfail( - reason=( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." - ), - ) - def test_layerwise_casting_inference(self): - super().test_layerwise_casting_inference() - - @pytest.mark.xfail( - reason=( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." - ), - ) - def test_layerwise_casting_memory(self): - pass +# ============================================================================= +# UNet1D RL (Value Function) Model Tests +# ============================================================================= -class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet1DModel - main_input_name = "sample" - @property - def dummy_input(self): - batch_size = 4 - num_features = 14 - seq_len = 16 - - noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) - time_step = torch.tensor([10] * batch_size).to(torch_device) - - return {"sample": noise, "timestep": time_step} +class UNet1DRLTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet1DModel testing (RL value function variant).""" @property - def input_shape(self): - return (4, 14, 16) + def model_class(self): + return UNet1DModel @property def output_shape(self): - return (4, 14, 1) - - def test_determinism(self): - super().test_determinism() - - def test_outputs_equivalence(self): - super().test_outputs_equivalence() + return (1,) - def test_from_save_pretrained(self): - super().test_from_save_pretrained() - - def test_from_save_pretrained_variant(self): - super().test_from_save_pretrained_variant() - - def test_model_from_pretrained(self): - super().test_model_from_pretrained() - - def test_output(self): - # UNetRL is a value-function is different output shape - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1)) - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - @unittest.skip("Test not supported.") - def test_ema_training(self): - pass - - @unittest.skip("Test not supported.") - def test_training(self): - pass - - @unittest.skip("Test not supported.") - def test_layerwise_casting_training(self): - pass + 
@property + def main_input_name(self): + return "sample" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "in_channels": 14, "out_channels": 14, "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"], @@ -264,18 +195,62 @@ def prepare_init_args_and_inputs_for_common(self): "time_embedding_type": "positional", "act_fn": "mish", } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + return { + "sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device), + "timestep": torch.tensor([10] * batch_size).to(torch_device), + } + + +class TestUNet1DRL(UNet1DRLTesterConfig, ModelTesterMixin, UNetTesterMixin): + @pytest.mark.skip("Not implemented yet for this UNet") + def test_forward_with_norm_groups(self): + pass + + @torch.no_grad() + def test_output(self): + # UNetRL is a value-function with different output shape (batch, 1) + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + inputs_dict = self.get_dummy_inputs() + output = model(**inputs_dict, return_dict=False)[0] + + assert output is not None + expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1)) + assert output.shape == expected_shape, "Input and output shapes do not match" + + +class TestUNet1DRLMemory(UNet1DRLTesterConfig, MemoryTesterMixin): + @pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON) + def test_layerwise_casting_memory(self): + super().test_layerwise_casting_memory() + + +class TestUNet1DRLTraining(UNet1DRLTesterConfig, TrainingTesterMixin): + pass + + +class TestUNet1DRLLoRA(UNet1DRLTesterConfig, LoraTesterMixin): + pass + + +class TestUNet1DRLHubLoading(UNet1DRLTesterConfig): def test_from_pretrained_hub(self): value_function, vf_loading_info = UNet1DModel.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" ) - self.assertIsNotNone(value_function) - self.assertEqual(len(vf_loading_info["missing_keys"]), 0) + assert value_function is not None + assert len(vf_loading_info["missing_keys"]) == 0 value_function.to(torch_device) - image = value_function(**self.dummy_input) + image = value_function(**self.get_dummy_inputs()) assert image is not None, "Make sure output is not None" @@ -299,31 +274,4 @@ def test_output_pretrained(self): # fmt: off expected_output_slice = torch.tensor([165.25] * seq_len) # fmt: on - self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) - - @unittest.skip("Test not supported.") - def test_forward_with_norm_groups(self): - # Not implemented yet for this UNet - pass - - @pytest.mark.xfail( - reason=( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." - ), - ) - def test_layerwise_casting_inference(self): - pass - - @pytest.mark.xfail( - reason=( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. 
Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." - ), - ) - def test_layerwise_casting_memory(self): - pass + assert torch.allclose(output, expected_output_slice, rtol=1e-3) From 0411da77398efad7f520a664439de45e2cff0189 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 16:08:35 +0530 Subject: [PATCH 09/13] [tests] refactor test_models_unet_2d.py to use modular testing mixins Refactor UNet2D model tests (standard, LDM, NCSN++) to follow the modern testing pattern. Each variant gets its own config class and dedicated test classes organized by concern (core, memory, training, LoRA, hub loading). Co-Authored-By: Claude Opus 4.6 --- tests/models/unets/test_models_unet_2d.py | 310 +++++++++++++--------- 1 file changed, 183 insertions(+), 127 deletions(-) diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index e289f44303f2..ddf7025c1059 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -15,12 +15,11 @@ import gc import math -import unittest +import pytest import torch from diffusers import UNet2DModel -from diffusers.utils import logging from ...testing_utils import ( backend_empty_cache, @@ -31,39 +30,41 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin - +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + BaseModelTesterConfig, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TrainingTesterMixin, +) -logger = logging.get_logger(__name__) enable_full_determinism() -class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DModel - main_input_name = "sample" - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) +# ============================================================================= +# Standard UNet2D Model Tests +# ============================================================================= - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - return {"sample": noise, "timestep": time_step} +class UNet2DTesterConfig(BaseModelTesterConfig): + """Base configuration for standard UNet2DModel testing.""" @property - def input_shape(self): - return (3, 32, 32) + def model_class(self): + return UNet2DModel @property def output_shape(self): return (3, 32, 32) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + @property + def main_input_name(self): + return "sample" + + def get_init_dict(self): + return { "block_out_channels": (4, 8), "norm_num_groups": 2, "down_block_types": ("DownBlock2D", "AttnDownBlock2D"), @@ -74,11 +75,22 @@ def prepare_init_args_and_inputs_for_common(self): "layers_per_block": 2, "sample_size": 32, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + return { + "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device), + "timestep": torch.tensor([10]).to(torch_device), + } + + +class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin): def test_mid_block_attn_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["add_attention"] = True init_dict["attn_norm_num_groups"] = 4 
@@ -87,13 +99,11 @@ def test_mid_block_attn_groups(self): model.to(torch_device) model.eval() - self.assertIsNotNone( - model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not." + assert model.mid_block.attentions[0].group_norm is not None, ( + "Mid block Attention group norm should exist but does not." ) - self.assertEqual( - model.mid_block.attentions[0].group_norm.num_groups, - init_dict["attn_norm_num_groups"], - "Mid block Attention group norm does not have the expected number of groups.", + assert model.mid_block.attentions[0].group_norm.num_groups == init_dict["attn_norm_num_groups"], ( + "Mid block Attention group norm does not have the expected number of groups." ) with torch.no_grad(): @@ -102,13 +112,15 @@ def test_mid_block_attn_groups(self): if isinstance(output, dict): output = output.to_tuple()[0] - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_mid_block_none(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + mid_none_init_dict = self.get_init_dict() + mid_none_inputs_dict = self.get_dummy_inputs() mid_none_init_dict["mid_block_type"] = None model = self.model_class(**init_dict) @@ -119,7 +131,7 @@ def test_mid_block_none(self): mid_none_model.to(torch_device) mid_none_model.eval() - self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.") + assert mid_none_model.mid_block is None, "Mid block should not exist." with torch.no_grad(): output = model(**inputs_dict) @@ -133,8 +145,14 @@ def test_mid_block_none(self): if isinstance(mid_none_output, dict): mid_none_output = mid_none_output.to_tuple()[0] - self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.") + assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different." 
+ +class TestUNet2DMemory(UNet2DTesterConfig, MemoryTesterMixin): + pass + + +class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = { "AttnUpBlock2D", @@ -143,41 +161,36 @@ def test_gradient_checkpointing_is_applied(self): "UpBlock2D", "DownBlock2D", } - # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` - attention_head_dim = 8 - block_out_channels = (16, 32) + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels - ) +class TestUNet2DLoRA(UNet2DTesterConfig, LoraTesterMixin): + pass -class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DModel - main_input_name = "sample" - @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (32, 32) +# ============================================================================= +# UNet2D LDM Model Tests +# ============================================================================= - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - return {"sample": noise, "timestep": time_step} +class UNet2DLDMTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet2DModel LDM variant testing.""" @property - def input_shape(self): - return (4, 32, 32) + def model_class(self): + return UNet2DModel @property def output_shape(self): return (4, 32, 32) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + @property + def main_input_name(self): + return "sample" + + def get_init_dict(self): + return { "sample_size": 32, "in_channels": 4, "out_channels": 4, @@ -187,17 +200,46 @@ def prepare_init_args_and_inputs_for_common(self): "down_block_types": ("DownBlock2D", "DownBlock2D"), "up_block_types": ("UpBlock2D", "UpBlock2D"), } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self): + batch_size = 4 + num_channels = 4 + sizes = (32, 32) + + return { + "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device), + "timestep": torch.tensor([10]).to(torch_device), + } + + +class TestUNet2DLDM(UNet2DLDMTesterConfig, ModelTesterMixin, UNetTesterMixin): + pass + + +class TestUNet2DLDMMemory(UNet2DLDMTesterConfig, MemoryTesterMixin): + pass + + +class TestUNet2DLDMTraining(UNet2DLDMTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"} + # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestUNet2DLDMLoRA(UNet2DLDMTesterConfig, LoraTesterMixin): + pass + + +class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig): def test_from_pretrained_hub(self): model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) + assert model is not None + assert len(loading_info["missing_keys"]) == 0 model.to(torch_device) - image = model(**self.dummy_input).sample + image = model(**self.get_dummy_inputs()).sample assert image is not None, "Make sure output is not None" @@ -205,7 +247,7 @@ def 
test_from_pretrained_hub(self): def test_from_pretrained_accelerate(self): model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) model.to(torch_device) - image = model(**self.dummy_input).sample + image = model(**self.get_dummy_inputs()).sample assert image is not None, "Make sure output is not None" @@ -265,44 +307,31 @@ def test_output_pretrained(self): expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) # fmt: on - self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3)) - - def test_gradient_checkpointing_is_applied(self): - expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"} + assert torch_all_close(output_slice, expected_output_slice, rtol=1e-3) - # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` - attention_head_dim = 32 - block_out_channels = (32, 64) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels - ) +# ============================================================================= +# NCSN++ Model Tests +# ============================================================================= -class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DModel - main_input_name = "sample" +class NCSNppTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet2DModel NCSN++ variant testing.""" @property - def dummy_input(self, sizes=(32, 32)): - batch_size = 4 - num_channels = 3 - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device) - - return {"sample": noise, "timestep": time_step} + def model_class(self): + return UNet2DModel @property - def input_shape(self): + def output_shape(self): return (3, 32, 32) @property - def output_shape(self): - return (3, 32, 32) + def main_input_name(self): + return "sample" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "block_out_channels": [32, 64, 64, 64], "in_channels": 3, "layers_per_block": 1, @@ -324,17 +353,75 @@ def prepare_init_args_and_inputs_for_common(self): "SkipUpBlock2D", ], } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + return { + "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device), + "timestep": torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device), + } + + +class TestNCSNpp(NCSNppTesterConfig, ModelTesterMixin, UNetTesterMixin): + @pytest.mark.skip("Test not supported.") + def test_forward_with_norm_groups(self): + pass + + @pytest.mark.skip( + "To make layerwise casting work with this model, we will have to update the implementation. " + "Due to potentially low usage, we don't support it here." + ) + def test_keep_in_fp32_modules(self): + pass + + @pytest.mark.skip( + "To make layerwise casting work with this model, we will have to update the implementation. " + "Due to potentially low usage, we don't support it here." 
+ ) + def test_from_save_pretrained_dtype_inference(self): + pass + + +class TestNCSNppMemory(NCSNppTesterConfig, MemoryTesterMixin): + @pytest.mark.skip( + "To make layerwise casting work with this model, we will have to update the implementation. " + "Due to potentially low usage, we don't support it here." + ) + def test_layerwise_casting_memory(self): + pass + + @pytest.mark.skip( + "To make layerwise casting work with this model, we will have to update the implementation. " + "Due to potentially low usage, we don't support it here." + ) + def test_layerwise_casting_training(self): + pass + + +class TestNCSNppTraining(NCSNppTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "UNetMidBlock2D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestNCSNppLoRA(NCSNppTesterConfig, LoraTesterMixin): + pass + + +class TestNCSNppHubLoading(NCSNppTesterConfig): @slow def test_from_pretrained_hub(self): model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) + assert model is not None + assert len(loading_info["missing_keys"]) == 0 model.to(torch_device) - inputs = self.dummy_input + inputs = self.get_dummy_inputs() noise = floats_tensor((4, 3) + (256, 256)).to(torch_device) inputs["sample"] = noise image = model(**inputs) @@ -361,7 +448,7 @@ def test_output_pretrained_ve_mid(self): expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056]) # fmt: on - self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) + assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2) def test_output_pretrained_ve_large(self): model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update") @@ -382,35 +469,4 @@ def test_output_pretrained_ve_large(self): expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256]) # fmt: on - self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) - - @unittest.skip("Test not supported.") - def test_forward_with_norm_groups(self): - # not required for this model - pass - - def test_gradient_checkpointing_is_applied(self): - expected_set = { - "UNetMidBlock2D", - } - - block_out_channels = (32, 64, 64, 64) - - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, block_out_channels=block_out_channels - ) - - def test_effective_gradient_checkpointing(self): - super().test_effective_gradient_checkpointing(skip={"time_proj.weight"}) - - @unittest.skip( - "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." - ) - def test_layerwise_casting_inference(self): - pass - - @unittest.skip( - "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." 
- ) - def test_layerwise_casting_memory(self): - pass + assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2) From ecbaed793da764e99fb1cf7e839876d229d7320f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 16:09:42 +0530 Subject: [PATCH 10/13] [tests] refactor test_models_unet_3d_condition.py to use modular testing mixins Refactor UNet3DConditionModel tests to follow the modern testing pattern with separate classes for core, attention, memory, training, and LoRA. Co-Authored-By: Claude Opus 4.6 --- .../unets/test_models_unet_3d_condition.py | 167 ++++++++++-------- 1 file changed, 97 insertions(+), 70 deletions(-) diff --git a/tests/models/unets/test_models_unet_3d_condition.py b/tests/models/unets/test_models_unet_3d_condition.py index f73e3461c38e..39dfdd920969 100644 --- a/tests/models/unets/test_models_unet_3d_condition.py +++ b/tests/models/unets/test_models_unet_3d_condition.py @@ -16,49 +16,50 @@ import unittest import numpy as np +import pytest import torch -from diffusers.models import ModelMixin, UNet3DConditionModel -from diffusers.utils import logging +from diffusers import UNet3DConditionModel from diffusers.utils.import_utils import is_xformers_available -from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ...testing_utils import ( + enable_full_determinism, + floats_tensor, + skip_mps, + torch_device, +) +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -logger = logging.get_logger(__name__) - @skip_mps -class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet3DConditionModel - main_input_name = "sample" +class UNet3DConditionTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet3DConditionModel testing.""" @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - num_frames = 4 - sizes = (16, 16) - - noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + def model_class(self): + return UNet3DConditionModel @property - def input_shape(self): + def output_shape(self): return (4, 4, 16, 16) @property - def output_shape(self): - return (4, 4, 16, 16) + def main_input_name(self): + return "sample" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "block_out_channels": (4, 8), "norm_num_groups": 4, "down_block_types": ( @@ -73,27 +74,25 @@ def prepare_init_args_and_inputs_for_common(self): "layers_per_block": 1, "sample_size": 16, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_enable_works(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + def get_dummy_inputs(self): + batch_size = 4 + num_channels = 4 + num_frames = 4 + sizes = (16, 16) - 
model.enable_xformers_memory_efficient_attention() + return { + "sample": floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device), + "timestep": torch.tensor([10]).to(torch_device), + "encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device), + } - assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ - == "XFormersAttnProcessor" - ), "xformers is not enabled" +class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin): # Overriding to set `norm_num_groups` needs to be different for this model. def test_forward_with_norm_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (32, 64) init_dict["norm_num_groups"] = 32 @@ -107,39 +106,74 @@ def test_forward_with_norm_groups(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" # Overriding since the UNet3D outputs a different structure. + @torch.no_grad() def test_determinism(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() - with torch.no_grad(): - # Warmup pass when using mps (see #372) - if torch_device == "mps" and isinstance(model, ModelMixin): - model(**self.dummy_input) + inputs_dict = self.get_dummy_inputs() - first = model(**inputs_dict) - if isinstance(first, dict): - first = first.sample + first = model(**inputs_dict) + if isinstance(first, dict): + first = first.sample - second = model(**inputs_dict) - if isinstance(second, dict): - second = second.sample + second = model(**inputs_dict) + if isinstance(second, dict): + second = second.sample out_1 = first.cpu().numpy() out_2 = second.cpu().numpy() out_1 = out_1[~np.isnan(out_1)] out_2 = out_2[~np.isnan(out_2)] max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) + assert max_diff <= 1e-5 + + def test_feed_forward_chunking(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + init_dict["block_out_channels"] = (32, 64) + init_dict["norm_num_groups"] = 32 + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict)[0] + + model.enable_forward_chunking() + with torch.no_grad(): + output_2 = model(**inputs_dict)[0] + + assert output.shape == output_2.shape, "Shape doesn't match" + assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2 + + +class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterMixin): + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_enable_works(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + + model.enable_xformers_memory_efficient_attention() + + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersAttnProcessor" + ), "xformers is not enabled" def test_model_attention_slicing(self): - init_dict, 
inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = 8 @@ -163,21 +197,14 @@ def test_model_attention_slicing(self): output = model(**inputs_dict) assert output is not None - def test_feed_forward_chunking(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - init_dict["block_out_channels"] = (32, 64) - init_dict["norm_num_groups"] = 32 - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() +class TestUNet3DConditionMemory(UNet3DConditionTesterConfig, MemoryTesterMixin): + pass - with torch.no_grad(): - output = model(**inputs_dict)[0] - model.enable_forward_chunking() - with torch.no_grad(): - output_2 = model(**inputs_dict)[0] +class TestUNet3DConditionTraining(UNet3DConditionTesterConfig, TrainingTesterMixin): + pass - self.assertEqual(output.shape, output_2.shape, "Shape doesn't match") - assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2 + +class TestUNet3DConditionLoRA(UNet3DConditionTesterConfig, LoraTesterMixin): + pass From c6e6992cdd94c958e390ba9377a95fe049084224 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 16:11:00 +0530 Subject: [PATCH 11/13] [tests] refactor test_models_unet_controlnetxs.py to use modular testing mixins Refactor UNetControlNetXSModel tests to follow the modern testing pattern with separate classes for core, memory, training, and LoRA. Specialized tests (from_unet, freeze_unet, forward_no_control, time_embedding_mixing) remain in the core test class. Co-Authored-By: Claude Opus 4.6 --- .../unets/test_models_unet_controlnetxs.py | 126 ++++++++++-------- 1 file changed, 67 insertions(+), 59 deletions(-) diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 40773536df70..8c665927a455 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -13,59 +13,44 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest - import numpy as np +import pytest import torch from torch import nn from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel -from diffusers.utils import logging from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin - +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + BaseModelTesterConfig, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TrainingTesterMixin, +) -logger = logging.get_logger(__name__) enable_full_determinism() -class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNetControlNetXSModel - main_input_name = "sample" +class UNetControlNetXSTesterConfig(BaseModelTesterConfig): + """Base configuration for UNetControlNetXSModel testing.""" @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (16, 16) - conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device) - conditioning_scale = 1 - - return { - "sample": noise, - "timestep": time_step, - "encoder_hidden_states": encoder_hidden_states, - "controlnet_cond": controlnet_cond, - "conditioning_scale": conditioning_scale, - } + def model_class(self): + return UNetControlNetXSModel @property - def input_shape(self): + def output_shape(self): return (4, 16, 16) @property - def output_shape(self): - return (4, 16, 16) + def main_input_name(self): + return "sample" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "sample_size": 16, "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), @@ -80,11 +65,23 @@ def prepare_init_args_and_inputs_for_common(self): "ctrl_max_norm_num_groups": 2, "ctrl_conditioning_embedding_out_channels": (2, 2), } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + + def get_dummy_inputs(self): + batch_size = 4 + num_channels = 4 + sizes = (16, 16) + conditioning_image_size = (3, 32, 32) + + return { + "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device), + "timestep": torch.tensor([10]).to(torch_device), + "encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device), + "controlnet_cond": floats_tensor((batch_size, *conditioning_image_size)).to(torch_device), + "conditioning_scale": 1, + } def get_dummy_unet(self): - """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter""" + """Build the underlying UNet for tests that construct UNetControlNetXSModel from UNet + Adapter.""" return UNet2DConditionModel( block_out_channels=(4, 8), layers_per_block=2, @@ -99,10 +96,16 @@ def get_dummy_unet(self): ) def get_dummy_controlnet_from_unet(self, unet, **kwargs): - """For some tests we also need the underlying ControlNetXS-Adapter. 
For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter""" - # size_ratio and conditioning_embedding_out_channels chosen to keep model small + """Build the ControlNetXS-Adapter from a UNet.""" return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs) + +class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetTesterMixin): + @pytest.mark.skip("Test not supported.") + def test_forward_with_norm_groups(self): + # UNetControlNetXSModel only supports SD/SDXL with norm_num_groups=32 + pass + def test_from_unet(self): unet = self.get_dummy_unet() controlnet = self.get_dummy_controlnet_from_unet(unet) @@ -115,7 +118,7 @@ def assert_equal_weights(module, weight_dict_prefix): assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value) # # check unet - # everything expect down,mid,up blocks + # everything except down,mid,up blocks modules_from_unet = [ "time_embedding", "conv_in", @@ -152,7 +155,7 @@ def assert_equal_weights(module, weight_dict_prefix): assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers") # # check controlnet - # everything expect down,mid,up blocks + # everything except down,mid,up blocks modules_from_controlnet = { "controlnet_cond_embedding": "controlnet_cond_embedding", "conv_in": "ctrl_conv_in", @@ -193,12 +196,12 @@ def assert_unfrozen(module): for p in module.parameters(): assert p.requires_grad - init_dict, _ = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() model = UNetControlNetXSModel(**init_dict) model.freeze_unet_params() # # check unet - # everything expect down,mid,up blocks + # everything except down,mid,up blocks modules_from_unet = [ model.base_time_embedding, model.base_conv_in, @@ -236,7 +239,7 @@ def assert_unfrozen(module): assert_frozen(u.upsamplers) # # check controlnet - # everything expect down,mid,up blocks + # everything except down,mid,up blocks modules_from_controlnet = [ model.controlnet_cond_embedding, model.ctrl_conv_in, @@ -267,16 +270,6 @@ def assert_unfrozen(module): for u in model.up_blocks: assert_unfrozen(u.ctrl_to_base) - def test_gradient_checkpointing_is_applied(self): - expected_set = { - "Transformer2DModel", - "UNetMidBlock2DCrossAttn", - "ControlNetXSCrossAttnDownBlock2D", - "ControlNetXSCrossAttnMidBlock2D", - "ControlNetXSCrossAttnUpBlock2D", - } - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - @is_flaky def test_forward_no_control(self): unet = self.get_dummy_unet() @@ -287,7 +280,7 @@ def test_forward_no_control(self): unet = unet.to(torch_device) model = model.to(torch_device) - input_ = self.dummy_input + input_ = self.get_dummy_inputs() control_specific_input = ["controlnet_cond", "conditioning_scale"] input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input} @@ -312,7 +305,7 @@ def test_time_embedding_mixing(self): model = model.to(torch_device) model_mix_time = model_mix_time.to(torch_device) - input_ = self.dummy_input + input_ = self.get_dummy_inputs() with torch.no_grad(): output = model(**input_).sample @@ -320,7 +313,22 @@ def test_time_embedding_mixing(self): assert output.shape == output_mix_time.shape - @unittest.skip("Test not supported.") - def test_forward_with_norm_groups(self): - # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups. 
- pass + +class TestUNetControlNetXSMemory(UNetControlNetXSTesterConfig, MemoryTesterMixin): + pass + + +class TestUNetControlNetXSTraining(UNetControlNetXSTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "Transformer2DModel", + "UNetMidBlock2DCrossAttn", + "ControlNetXSCrossAttnDownBlock2D", + "ControlNetXSCrossAttnMidBlock2D", + "ControlNetXSCrossAttnUpBlock2D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestUNetControlNetXSLoRA(UNetControlNetXSTesterConfig, LoraTesterMixin): + pass From 99de4ceab8e338459e8bce8df5c11ee1619c31c8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 16:13:42 +0530 Subject: [PATCH 12/13] [tests] refactor test_models_unet_spatiotemporal.py to use modular testing mixins Refactored the spatiotemporal UNet test file to follow the modern modular testing pattern with BaseModelTesterConfig and focused test classes: - UNetSpatioTemporalTesterConfig: Base configuration with model setup - TestUNetSpatioTemporal: Core model tests (ModelTesterMixin, UNetTesterMixin) - TestUNetSpatioTemporalAttention: Attention-related tests (AttentionTesterMixin) - TestUNetSpatioTemporalMemory: Memory/offloading tests (MemoryTesterMixin) - TestUNetSpatioTemporalTraining: Training tests (TrainingTesterMixin) - TestUNetSpatioTemporalLoRA: LoRA adapter tests (LoraTesterMixin) Co-Authored-By: Claude Opus 4.6 --- .../unets/test_models_unet_spatiotemporal.py | 173 +++++++++--------- 1 file changed, 88 insertions(+), 85 deletions(-) diff --git a/tests/models/unets/test_models_unet_spatiotemporal.py b/tests/models/unets/test_models_unet_spatiotemporal.py index 7df868c9e95b..26dc4d8dda78 100644 --- a/tests/models/unets/test_models_unet_spatiotemporal.py +++ b/tests/models/unets/test_models_unet_spatiotemporal.py @@ -16,10 +16,10 @@ import copy import unittest +import pytest import torch from diffusers import UNetSpatioTemporalConditionModel -from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from ...testing_utils import ( @@ -28,45 +28,36 @@ skip_mps, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin - +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TrainingTesterMixin, +) -logger = logging.get_logger(__name__) enable_full_determinism() @skip_mps -class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNetSpatioTemporalConditionModel - main_input_name = "sample" - - @property - def dummy_input(self): - batch_size = 2 - num_frames = 2 - num_channels = 4 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device) - - return { - "sample": noise, - "timestep": time_step, - "encoder_hidden_states": encoder_hidden_states, - "added_time_ids": self._get_add_time_ids(), - } +class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig): + """Base configuration for UNetSpatioTemporalConditionModel testing.""" @property - def input_shape(self): - return (2, 2, 4, 32, 32) + def model_class(self): + return UNetSpatioTemporalConditionModel @property def output_shape(self): return (4, 32, 32) + @property + def main_input_name(self): + 
return "sample" + @property def fps(self): return 6 @@ -83,8 +74,8 @@ def noise_aug_strength(self): def addition_time_embed_dim(self): return 32 - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "block_out_channels": (32, 64), "down_block_types": ( "CrossAttnDownBlockSpatioTemporal", @@ -103,8 +94,23 @@ def prepare_init_args_and_inputs_for_common(self): "projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3, "addition_time_embed_dim": self.addition_time_embed_dim, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + + def get_dummy_inputs(self): + batch_size = 2 + num_frames = 2 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device) + + return { + "sample": noise, + "timestep": time_step, + "encoder_hidden_states": encoder_hidden_states, + "added_time_ids": self._get_add_time_ids(), + } def _get_add_time_ids(self, do_classifier_free_guidance=True): add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength] @@ -124,43 +130,15 @@ def _get_add_time_ids(self, do_classifier_free_guidance=True): return add_time_ids - @unittest.skip("Number of Norm Groups is not configurable") - def test_forward_with_norm_groups(self): - pass - - @unittest.skip("Deprecated functionality") - def test_model_attention_slicing(self): - pass - - @unittest.skip("Not supported") - def test_model_with_use_linear_projection(self): - pass - - @unittest.skip("Not supported") - def test_model_with_simple_projection(self): - pass - @unittest.skip("Not supported") - def test_model_with_class_embeddings_concat(self): +class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, UNetTesterMixin): + @pytest.mark.skip("Number of Norm Groups is not configurable") + def test_forward_with_norm_groups(self): pass - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_enable_works(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - - model.enable_xformers_memory_efficient_attention() - - assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ - == "XFormersAttnProcessor" - ), "xformers is not enabled" - def test_model_with_num_attention_heads_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["num_attention_heads"] = (8, 16) model = self.model_class(**init_dict) @@ -173,12 +151,13 @@ def test_model_with_num_attention_heads_tuple(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_model_with_cross_attention_dim_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["cross_attention_dim"] = (32, 32) @@ -192,27 +171,13 @@ def 
test_model_with_cross_attention_dim_tuple(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_gradient_checkpointing_is_applied(self): - expected_set = { - "TransformerSpatioTemporalModel", - "CrossAttnDownBlockSpatioTemporal", - "DownBlockSpatioTemporal", - "UpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - "UNetMidBlockSpatioTemporal", - } - num_attention_heads = (8, 16) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, num_attention_heads=num_attention_heads - ) + assert output.shape == expected_shape, "Input and output shapes do not match" def test_pickle(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["num_attention_heads"] = (8, 16) @@ -225,3 +190,41 @@ def test_pickle(self): sample_copy = copy.copy(sample) assert (sample - sample_copy).abs().max() < 1e-4 + + +class TestUNetSpatioTemporalAttention(UNetSpatioTemporalTesterConfig, AttentionTesterMixin): + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_enable_works(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + + model.enable_xformers_memory_efficient_attention() + + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersAttnProcessor" + ), "xformers is not enabled" + + +class TestUNetSpatioTemporalMemory(UNetSpatioTemporalTesterConfig, MemoryTesterMixin): + pass + + +class TestUNetSpatioTemporalTraining(UNetSpatioTemporalTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "TransformerSpatioTemporalModel", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "UNetMidBlockSpatioTemporal", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestUNetSpatioTemporalLoRA(UNetSpatioTemporalTesterConfig, LoraTesterMixin): + pass From 5f8303fe3ca26360c8f85ec04e0708dc25e8d16d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 16:36:06 +0530 Subject: [PATCH 13/13] remove test suites that are passed. 
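Only the variant test classes whose body is just `pass` (a tester config combined with a
mixin and nothing else) are dropped; classes that override or extend a mixin test are kept.
As an illustration only, a sketch using names as they appear in the diffs below
(tests/models/unets/test_models_unet_1d.py and tests/models/testing_utils), not new API:

    # kept: overrides a mixin test to mark it as an expected failure
    class TestUNet1DMemory(UNet1DTesterConfig, MemoryTesterMixin):
        @pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
        def test_layerwise_casting_memory(self):
            super().test_layerwise_casting_memory()

    # removed by this commit: pure config + mixin combination with an empty body
    class TestUNet1DLoRA(UNet1DTesterConfig, LoraTesterMixin):
        pass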
--- tests/models/unets/test_models_unet_1d.py | 18 ------------- tests/models/unets/test_models_unet_2d.py | 25 ------------------- .../unets/test_models_unet_3d_condition.py | 16 ------------ .../unets/test_models_unet_controlnetxs.py | 10 -------- .../unets/test_models_unet_spatiotemporal.py | 10 -------- 5 files changed, 79 deletions(-) diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py index a52de51cf097..1ee6c4768c27 100644 --- a/tests/models/unets/test_models_unet_1d.py +++ b/tests/models/unets/test_models_unet_1d.py @@ -27,10 +27,8 @@ from ..test_modeling_common import UNetTesterMixin from ..testing_utils import ( BaseModelTesterConfig, - LoraTesterMixin, MemoryTesterMixin, ModelTesterMixin, - TrainingTesterMixin, ) @@ -96,14 +94,6 @@ def test_layerwise_casting_memory(self): super().test_layerwise_casting_memory() -class TestUNet1DTraining(UNet1DTesterConfig, TrainingTesterMixin): - pass - - -class TestUNet1DLoRA(UNet1DTesterConfig, LoraTesterMixin): - pass - - class TestUNet1DHubLoading(UNet1DTesterConfig): def test_from_pretrained_hub(self): model, loading_info = UNet1DModel.from_pretrained( @@ -233,14 +223,6 @@ def test_layerwise_casting_memory(self): super().test_layerwise_casting_memory() -class TestUNet1DRLTraining(UNet1DRLTesterConfig, TrainingTesterMixin): - pass - - -class TestUNet1DRLLoRA(UNet1DRLTesterConfig, LoraTesterMixin): - pass - - class TestUNet1DRLHubLoading(UNet1DRLTesterConfig): def test_from_pretrained_hub(self): value_function, vf_loading_info = UNet1DModel.from_pretrained( diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index ddf7025c1059..3d24a1f05e00 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -33,7 +33,6 @@ from ..test_modeling_common import UNetTesterMixin from ..testing_utils import ( BaseModelTesterConfig, - LoraTesterMixin, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin, @@ -148,10 +147,6 @@ def test_mid_block_none(self): assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different." 
-class TestUNet2DMemory(UNet2DTesterConfig, MemoryTesterMixin): - pass - - class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = { @@ -165,10 +160,6 @@ def test_gradient_checkpointing_is_applied(self): super().test_gradient_checkpointing_is_applied(expected_set=expected_set) -class TestUNet2DLoRA(UNet2DTesterConfig, LoraTesterMixin): - pass - - # ============================================================================= # UNet2D LDM Model Tests # ============================================================================= @@ -212,14 +203,6 @@ def get_dummy_inputs(self): } -class TestUNet2DLDM(UNet2DLDMTesterConfig, ModelTesterMixin, UNetTesterMixin): - pass - - -class TestUNet2DLDMMemory(UNet2DLDMTesterConfig, MemoryTesterMixin): - pass - - class TestUNet2DLDMTraining(UNet2DLDMTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"} @@ -227,10 +210,6 @@ def test_gradient_checkpointing_is_applied(self): super().test_gradient_checkpointing_is_applied(expected_set=expected_set) -class TestUNet2DLDMLoRA(UNet2DLDMTesterConfig, LoraTesterMixin): - pass - - class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig): def test_from_pretrained_hub(self): model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) @@ -409,10 +388,6 @@ def test_gradient_checkpointing_is_applied(self): super().test_gradient_checkpointing_is_applied(expected_set=expected_set) -class TestNCSNppLoRA(NCSNppTesterConfig, LoraTesterMixin): - pass - - class TestNCSNppHubLoading(NCSNppTesterConfig): @slow def test_from_pretrained_hub(self): diff --git a/tests/models/unets/test_models_unet_3d_condition.py b/tests/models/unets/test_models_unet_3d_condition.py index 39dfdd920969..264c67223fb3 100644 --- a/tests/models/unets/test_models_unet_3d_condition.py +++ b/tests/models/unets/test_models_unet_3d_condition.py @@ -16,7 +16,6 @@ import unittest import numpy as np -import pytest import torch from diffusers import UNet3DConditionModel @@ -32,10 +31,7 @@ from ..testing_utils import ( AttentionTesterMixin, BaseModelTesterConfig, - LoraTesterMixin, - MemoryTesterMixin, ModelTesterMixin, - TrainingTesterMixin, ) @@ -196,15 +192,3 @@ def test_model_attention_slicing(self): with torch.no_grad(): output = model(**inputs_dict) assert output is not None - - -class TestUNet3DConditionMemory(UNet3DConditionTesterConfig, MemoryTesterMixin): - pass - - -class TestUNet3DConditionTraining(UNet3DConditionTesterConfig, TrainingTesterMixin): - pass - - -class TestUNet3DConditionLoRA(UNet3DConditionTesterConfig, LoraTesterMixin): - pass diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 8c665927a455..43ac56f7f54c 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -24,8 +24,6 @@ from ..test_modeling_common import UNetTesterMixin from ..testing_utils import ( BaseModelTesterConfig, - LoraTesterMixin, - MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin, ) @@ -314,10 +312,6 @@ def test_time_embedding_mixing(self): assert output.shape == output_mix_time.shape -class TestUNetControlNetXSMemory(UNetControlNetXSTesterConfig, MemoryTesterMixin): - pass - - class TestUNetControlNetXSTraining(UNetControlNetXSTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): 
expected_set = { @@ -328,7 +322,3 @@ def test_gradient_checkpointing_is_applied(self): "ControlNetXSCrossAttnUpBlock2D", } super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - -class TestUNetControlNetXSLoRA(UNetControlNetXSTesterConfig, LoraTesterMixin): - pass diff --git a/tests/models/unets/test_models_unet_spatiotemporal.py b/tests/models/unets/test_models_unet_spatiotemporal.py index 26dc4d8dda78..1951d2c0f326 100644 --- a/tests/models/unets/test_models_unet_spatiotemporal.py +++ b/tests/models/unets/test_models_unet_spatiotemporal.py @@ -32,8 +32,6 @@ from ..testing_utils import ( AttentionTesterMixin, BaseModelTesterConfig, - LoraTesterMixin, - MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin, ) @@ -209,10 +207,6 @@ def test_xformers_enable_works(self): ), "xformers is not enabled" -class TestUNetSpatioTemporalMemory(UNetSpatioTemporalTesterConfig, MemoryTesterMixin): - pass - - class TestUNetSpatioTemporalTraining(UNetSpatioTemporalTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = { @@ -224,7 +218,3 @@ def test_gradient_checkpointing_is_applied(self): "UNetMidBlockSpatioTemporal", } super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - -class TestUNetSpatioTemporalLoRA(UNetSpatioTemporalTesterConfig, LoraTesterMixin): - pass
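
Note for reviewers new to the layout above: the sketch below is not part of the diff. It is a minimal example of the composition pattern the series converges on, i.e. one `*TesterConfig` class that only describes the model (class, init kwargs, dummy inputs) plus one composed test class per concern that mixes that config with a `*TesterMixin`. The `MyUNet*` names, the init/input values, and the expected gradient-checkpointing module set are placeholders, and the mixin behaviour is assumed to match how it is exercised in the hunks above. The relative imports assume the file sits next to the ones touched here (tests/models/unets/).

    import torch
    from diffusers import UNet2DModel
    from diffusers.utils.testing_utils import floats_tensor, torch_device

    from ..testing_utils import BaseModelTesterConfig, TrainingTesterMixin


    class MyUNetTesterConfig(BaseModelTesterConfig):
        """Configuration only: which model to build and which dummy inputs to feed it."""

        @property
        def model_class(self):
            return UNet2DModel

        @property
        def output_shape(self):
            return (3, 16, 16)

        def get_init_dict(self):
            # keep the model tiny so the shared mixin tests stay fast
            return {
                "block_out_channels": (4, 8),
                "norm_num_groups": 4,
                "down_block_types": ("DownBlock2D", "DownBlock2D"),
                "up_block_types": ("UpBlock2D", "UpBlock2D"),
                "in_channels": 3,
                "out_channels": 3,
                "layers_per_block": 1,
                "sample_size": 16,
            }

        def get_dummy_inputs(self):
            return {
                "sample": floats_tensor((4, 3, 16, 16)).to(torch_device),
                "timestep": torch.tensor([10]).to(torch_device),
            }


    # One composed class per concern; it carries a body when it needs to extend a
    # mixin test, e.g. to pin which modules are expected to get gradient checkpointing
    # (the set below is assumed for this placeholder block layout).
    class TestMyUNetTraining(MyUNetTesterConfig, TrainingTesterMixin):
        def test_gradient_checkpointing_is_applied(self):
            super().test_gradient_checkpointing_is_applied(
                expected_set={"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
            )

Compositions whose body is only `pass` (for example `class TestMyUNetLoRA(MyUNetTesterConfig, LoraTesterMixin): pass`) rely entirely on the mixin's stock tests; those are the classes [PATCH 13/13] removes.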