diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 1b1a51d1e26f..cd0040bddc34 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -465,7 +465,8 @@ def _accepts_norm_num_groups(model_class): def test_forward_with_norm_groups(self): if not self._accepts_norm_num_groups(self.model_class): pytest.skip(f"Test not supported for {self.model_class.__name__}") - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["norm_num_groups"] = 16 init_dict["block_out_channels"] = (16, 32) @@ -480,9 +481,9 @@ def test_forward_with_norm_groups(self): if isinstance(output, dict): output = output.to_tuple()[0] - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" class ModelTesterMixin: diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 7036bb16203d..6842b9ee30d2 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -287,8 +287,9 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5): f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}" ) - image = model(**self.get_dummy_inputs(), return_dict=False)[0] - new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] + inputs_dict = self.get_dummy_inputs() + image = model(**inputs_dict, return_dict=False)[0] + new_image = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") @@ -308,8 +309,9 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): new_model.to(torch_device) - image = model(**self.get_dummy_inputs(), return_dict=False)[0] - new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] + inputs_dict = self.get_dummy_inputs() + image = model(**inputs_dict, return_dict=False)[0] + new_image = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") @@ -337,8 +339,9 @@ def test_determinism(self, atol=1e-5, rtol=0): model.to(torch_device) model.eval() - first = model(**self.get_dummy_inputs(), return_dict=False)[0] - second = model(**self.get_dummy_inputs(), return_dict=False)[0] + inputs_dict = self.get_dummy_inputs() + first = model(**inputs_dict, return_dict=False)[0] + second = model(**inputs_dict, return_dict=False)[0] first_flat = first.flatten() second_flat = second.flatten() @@ -395,8 +398,9 @@ def recursive_check(tuple_object, dict_object): model.to(torch_device) model.eval() - outputs_dict = model(**self.get_dummy_inputs()) - outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False) + inputs_dict = self.get_dummy_inputs() + outputs_dict = model(**inputs_dict) + outputs_tuple = model(**inputs_dict, return_dict=False) recursive_check(outputs_tuple, outputs_dict) @@ -523,8 +527,10 @@ def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0): new_model = new_model.to(torch_device) torch.manual_seed(0) - inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new, return_dict=False)[0] + # Re-create inputs only if they contain a 
generator (which needs to be reset) + if "generator" in inputs_dict: + inputs_dict = self.get_dummy_inputs() + new_output = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close( base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load" @@ -563,8 +569,10 @@ def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0): new_model = new_model.to(torch_device) torch.manual_seed(0) - inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new, return_dict=False)[0] + # Re-create inputs only if they contain a generator (which needs to be reset) + if "generator" in inputs_dict: + inputs_dict = self.get_dummy_inputs() + new_output = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close( base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load" @@ -614,8 +622,10 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt model_parallel = model_parallel.to(torch_device) torch.manual_seed(0) - inputs_dict_parallel = self.get_dummy_inputs() - output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0] + # Re-create inputs only if they contain a generator (which needs to be reset) + if "generator" in inputs_dict: + inputs_dict = self.get_dummy_inputs() + output_parallel = model_parallel(**inputs_dict, return_dict=False)[0] assert_tensors_close( base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading" diff --git a/tests/models/testing_utils/compile.py b/tests/models/testing_utils/compile.py index 998b88fb469e..4787d0742b18 100644 --- a/tests/models/testing_utils/compile.py +++ b/tests/models/testing_utils/compile.py @@ -92,9 +92,6 @@ def test_torch_compile_repeated_blocks(self, recompile_limit=1): model.eval() model.compile_repeated_blocks(fullgraph=True) - if self.model_class.__name__ == "UNet2DConditionModel": - recompile_limit = 2 - with ( torch._inductor.utils.fresh_inductor_cache(), torch._dynamo.config.patch(recompile_limit=recompile_limit), diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py index 994aaed55ca7..b1fa123ffaa2 100644 --- a/tests/models/testing_utils/lora.py +++ b/tests/models/testing_utils/lora.py @@ -15,6 +15,7 @@ import gc import json +import logging import os import re @@ -23,10 +24,12 @@ import torch import torch.nn as nn +from diffusers.utils import logging as diffusers_logging from diffusers.utils.import_utils import is_peft_available from diffusers.utils.testing_utils import check_if_dicts_are_equal from ...testing_utils import ( + CaptureLogger, assert_tensors_close, backend_empty_cache, is_lora, @@ -477,10 +480,7 @@ def test_enable_lora_hotswap_called_after_adapter_added_raises(self): with pytest.raises(RuntimeError, match=msg): model.enable_lora_hotswap(target_rank=32) - def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog): - # ensure that enable_lora_hotswap is called before loading the first adapter - import logging - + def test_enable_lora_hotswap_called_after_adapter_added_warning(self): lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) init_dict = self.get_init_dict() model = self.model_class(**init_dict).to(torch_device) @@ -488,21 +488,26 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog): msg = ( "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." 
) - with caplog.at_level(logging.WARNING): + + logger = diffusers_logging.get_logger("diffusers.loaders.peft") + logger.setLevel(logging.WARNING) + with CaptureLogger(logger) as cap_logger: model.enable_lora_hotswap(target_rank=32, check_compiled="warn") - assert any(msg in record.message for record in caplog.records) - def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog): - # check possibility to ignore the error/warning - import logging + assert msg in str(cap_logger.out), f"Expected warning not found. Captured: {cap_logger.out}" + def test_enable_lora_hotswap_called_after_adapter_added_ignore(self): lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) init_dict = self.get_init_dict() model = self.model_class(**init_dict).to(torch_device) model.add_adapter(lora_config) - with caplog.at_level(logging.WARNING): + + logger = diffusers_logging.get_logger("diffusers.loaders.peft") + logger.setLevel(logging.WARNING) + with CaptureLogger(logger) as cap_logger: model.enable_lora_hotswap(target_rank=32, check_compiled="ignore") - assert len(caplog.records) == 0 + + assert cap_logger.out == "", f"Expected no warnings but found: {cap_logger.out}" def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): # check that wrong argument value raises an error @@ -515,9 +520,6 @@ def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog): - # check the error and log - import logging - # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers target_modules0 = ["to_q"] target_modules1 = ["to_q", "to_k"] diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py index bac017e7e7d3..1ee6c4768c27 100644 --- a/tests/models/unets/test_models_unet_1d.py +++ b/tests/models/unets/test_models_unet_1d.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import pytest import torch @@ -26,64 +24,39 @@ slow, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin - +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, +) -class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet1DModel - main_input_name = "sample" - @property - def dummy_input(self): - batch_size = 4 - num_features = 14 - seq_len = 16 +_LAYERWISE_CASTING_XFAIL_REASON = ( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." 
+) - noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) - time_step = torch.tensor([10] * batch_size).to(torch_device) - return {"sample": noise, "timestep": time_step} +class UNet1DTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet1DModel testing (standard variant).""" @property - def input_shape(self): - return (4, 14, 16) + def model_class(self): + return UNet1DModel @property def output_shape(self): - return (4, 14, 16) - - @unittest.skip("Test not supported.") - def test_ema_training(self): - pass - - @unittest.skip("Test not supported.") - def test_training(self): - pass + return (14, 16) - @unittest.skip("Test not supported.") - def test_layerwise_casting_training(self): - pass - - def test_determinism(self): - super().test_determinism() - - def test_outputs_equivalence(self): - super().test_outputs_equivalence() - - def test_from_save_pretrained(self): - super().test_from_save_pretrained() - - def test_from_save_pretrained_variant(self): - super().test_from_save_pretrained_variant() - - def test_model_from_pretrained(self): - super().test_model_from_pretrained() - - def test_output(self): - super().test_output() + @property + def main_input_name(self): + return "sample" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "block_out_channels": (8, 8, 16, 16), "in_channels": 14, "out_channels": 14, @@ -97,18 +70,40 @@ def prepare_init_args_and_inputs_for_common(self): "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"), "act_fn": "swish", } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + return { + "sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device), + "timestep": torch.tensor([10] * batch_size).to(torch_device), + } + + +class TestUNet1D(UNet1DTesterConfig, ModelTesterMixin, UNetTesterMixin): + @pytest.mark.skip("Not implemented yet for this UNet") + def test_forward_with_norm_groups(self): + pass + + +class TestUNet1DMemory(UNet1DTesterConfig, MemoryTesterMixin): + @pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON) + def test_layerwise_casting_memory(self): + super().test_layerwise_casting_memory() + + +class TestUNet1DHubLoading(UNet1DTesterConfig): def test_from_pretrained_hub(self): model, loading_info = UNet1DModel.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" ) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) + assert model is not None + assert len(loading_info["missing_keys"]) == 0 model.to(torch_device) - image = model(**self.dummy_input) + image = model(**self.get_dummy_inputs()) assert image is not None, "Make sure output is not None" @@ -131,12 +126,7 @@ def test_output_pretrained(self): # fmt: off expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348]) # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) - - @unittest.skip("Test not supported.") - def test_forward_with_norm_groups(self): - # Not implemented yet for this UNet - pass + assert torch.allclose(output_slice, expected_output_slice, rtol=1e-3) @slow def test_unet_1d_maestro(self): @@ -157,98 +147,29 @@ def test_unet_1d_maestro(self): assert (output_sum - 224.0896).abs() < 0.5 assert (output_max - 0.0607).abs() < 4e-4 - 
@pytest.mark.xfail( - reason=( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." - ), - ) - def test_layerwise_casting_inference(self): - super().test_layerwise_casting_inference() - - @pytest.mark.xfail( - reason=( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." - ), - ) - def test_layerwise_casting_memory(self): - pass - - -class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet1DModel - main_input_name = "sample" - @property - def dummy_input(self): - batch_size = 4 - num_features = 14 - seq_len = 16 +# ============================================================================= +# UNet1D RL (Value Function) Model Tests +# ============================================================================= - noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) - time_step = torch.tensor([10] * batch_size).to(torch_device) - return {"sample": noise, "timestep": time_step} +class UNet1DRLTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet1DModel testing (RL value function variant).""" @property - def input_shape(self): - return (4, 14, 16) + def model_class(self): + return UNet1DModel @property def output_shape(self): - return (4, 14, 1) - - def test_determinism(self): - super().test_determinism() - - def test_outputs_equivalence(self): - super().test_outputs_equivalence() + return (1,) - def test_from_save_pretrained(self): - super().test_from_save_pretrained() - - def test_from_save_pretrained_variant(self): - super().test_from_save_pretrained_variant() - - def test_model_from_pretrained(self): - super().test_model_from_pretrained() - - def test_output(self): - # UNetRL is a value-function is different output shape - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1)) - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - @unittest.skip("Test not supported.") - def test_ema_training(self): - pass - - @unittest.skip("Test not supported.") - def test_training(self): - pass - - @unittest.skip("Test not supported.") - def test_layerwise_casting_training(self): - pass + @property + def main_input_name(self): + return "sample" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "in_channels": 14, "out_channels": 14, "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"], @@ -264,18 +185,54 @@ def prepare_init_args_and_inputs_for_common(self): "time_embedding_type": "positional", "act_fn": "mish", } - inputs_dict = 
self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + return { + "sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device), + "timestep": torch.tensor([10] * batch_size).to(torch_device), + } + + +class TestUNet1DRL(UNet1DRLTesterConfig, ModelTesterMixin, UNetTesterMixin): + @pytest.mark.skip("Not implemented yet for this UNet") + def test_forward_with_norm_groups(self): + pass + + @torch.no_grad() + def test_output(self): + # UNetRL is a value-function with different output shape (batch, 1) + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + inputs_dict = self.get_dummy_inputs() + output = model(**inputs_dict, return_dict=False)[0] + + assert output is not None + expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1)) + assert output.shape == expected_shape, "Input and output shapes do not match" + + +class TestUNet1DRLMemory(UNet1DRLTesterConfig, MemoryTesterMixin): + @pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON) + def test_layerwise_casting_memory(self): + super().test_layerwise_casting_memory() + + +class TestUNet1DRLHubLoading(UNet1DRLTesterConfig): def test_from_pretrained_hub(self): value_function, vf_loading_info = UNet1DModel.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" ) - self.assertIsNotNone(value_function) - self.assertEqual(len(vf_loading_info["missing_keys"]), 0) + assert value_function is not None + assert len(vf_loading_info["missing_keys"]) == 0 value_function.to(torch_device) - image = value_function(**self.dummy_input) + image = value_function(**self.get_dummy_inputs()) assert image is not None, "Make sure output is not None" @@ -299,31 +256,4 @@ def test_output_pretrained(self): # fmt: off expected_output_slice = torch.tensor([165.25] * seq_len) # fmt: on - self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) - - @unittest.skip("Test not supported.") - def test_forward_with_norm_groups(self): - # Not implemented yet for this UNet - pass - - @pytest.mark.xfail( - reason=( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." - ), - ) - def test_layerwise_casting_inference(self): - pass - - @pytest.mark.xfail( - reason=( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." 
- ), - ) - def test_layerwise_casting_memory(self): - pass + assert torch.allclose(output, expected_output_slice, rtol=1e-3) diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index e289f44303f2..3d24a1f05e00 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -15,12 +15,11 @@ import gc import math -import unittest +import pytest import torch from diffusers import UNet2DModel -from diffusers.utils import logging from ...testing_utils import ( backend_empty_cache, @@ -31,39 +30,40 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin - +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TrainingTesterMixin, +) -logger = logging.get_logger(__name__) enable_full_determinism() -class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DModel - main_input_name = "sample" - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) +# ============================================================================= +# Standard UNet2D Model Tests +# ============================================================================= - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - return {"sample": noise, "timestep": time_step} +class UNet2DTesterConfig(BaseModelTesterConfig): + """Base configuration for standard UNet2DModel testing.""" @property - def input_shape(self): - return (3, 32, 32) + def model_class(self): + return UNet2DModel @property def output_shape(self): return (3, 32, 32) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + @property + def main_input_name(self): + return "sample" + + def get_init_dict(self): + return { "block_out_channels": (4, 8), "norm_num_groups": 2, "down_block_types": ("DownBlock2D", "AttnDownBlock2D"), @@ -74,11 +74,22 @@ def prepare_init_args_and_inputs_for_common(self): "layers_per_block": 2, "sample_size": 32, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + return { + "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device), + "timestep": torch.tensor([10]).to(torch_device), + } + + +class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin): def test_mid_block_attn_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["add_attention"] = True init_dict["attn_norm_num_groups"] = 4 @@ -87,13 +98,11 @@ def test_mid_block_attn_groups(self): model.to(torch_device) model.eval() - self.assertIsNotNone( - model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not." + assert model.mid_block.attentions[0].group_norm is not None, ( + "Mid block Attention group norm should exist but does not." ) - self.assertEqual( - model.mid_block.attentions[0].group_norm.num_groups, - init_dict["attn_norm_num_groups"], - "Mid block Attention group norm does not have the expected number of groups.", + assert model.mid_block.attentions[0].group_norm.num_groups == init_dict["attn_norm_num_groups"], ( + "Mid block Attention group norm does not have the expected number of groups." 
) with torch.no_grad(): @@ -102,13 +111,15 @@ def test_mid_block_attn_groups(self): if isinstance(output, dict): output = output.to_tuple()[0] - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_mid_block_none(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + mid_none_init_dict = self.get_init_dict() + mid_none_inputs_dict = self.get_dummy_inputs() mid_none_init_dict["mid_block_type"] = None model = self.model_class(**init_dict) @@ -119,7 +130,7 @@ def test_mid_block_none(self): mid_none_model.to(torch_device) mid_none_model.eval() - self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.") + assert mid_none_model.mid_block is None, "Mid block should not exist." with torch.no_grad(): output = model(**inputs_dict) @@ -133,8 +144,10 @@ def test_mid_block_none(self): if isinstance(mid_none_output, dict): mid_none_output = mid_none_output.to_tuple()[0] - self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.") + assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different." + +class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = { "AttnUpBlock2D", @@ -143,41 +156,32 @@ def test_gradient_checkpointing_is_applied(self): "UpBlock2D", "DownBlock2D", } - # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` - attention_head_dim = 8 - block_out_channels = (16, 32) + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels - ) +# ============================================================================= +# UNet2D LDM Model Tests +# ============================================================================= -class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DModel - main_input_name = "sample" - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - return {"sample": noise, "timestep": time_step} +class UNet2DLDMTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet2DModel LDM variant testing.""" @property - def input_shape(self): - return (4, 32, 32) + def model_class(self): + return UNet2DModel @property def output_shape(self): return (4, 32, 32) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + @property + def main_input_name(self): + return "sample" + + def get_init_dict(self): + return { "sample_size": 32, "in_channels": 4, "out_channels": 4, @@ -187,17 +191,34 @@ def prepare_init_args_and_inputs_for_common(self): "down_block_types": ("DownBlock2D", "DownBlock2D"), "up_block_types": ("UpBlock2D", "UpBlock2D"), } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self): + batch_size = 4 + 
num_channels = 4 + sizes = (32, 32) + + return { + "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device), + "timestep": torch.tensor([10]).to(torch_device), + } + + +class TestUNet2DLDMTraining(UNet2DLDMTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"} + # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig): def test_from_pretrained_hub(self): model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) + assert model is not None + assert len(loading_info["missing_keys"]) == 0 model.to(torch_device) - image = model(**self.dummy_input).sample + image = model(**self.get_dummy_inputs()).sample assert image is not None, "Make sure output is not None" @@ -205,7 +226,7 @@ def test_from_pretrained_hub(self): def test_from_pretrained_accelerate(self): model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) model.to(torch_device) - image = model(**self.dummy_input).sample + image = model(**self.get_dummy_inputs()).sample assert image is not None, "Make sure output is not None" @@ -265,44 +286,31 @@ def test_output_pretrained(self): expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) # fmt: on - self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3)) + assert torch_all_close(output_slice, expected_output_slice, rtol=1e-3) - def test_gradient_checkpointing_is_applied(self): - expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"} - # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` - attention_head_dim = 32 - block_out_channels = (32, 64) - - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels - ) +# ============================================================================= +# NCSN++ Model Tests +# ============================================================================= -class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DModel - main_input_name = "sample" +class NCSNppTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet2DModel NCSN++ variant testing.""" @property - def dummy_input(self, sizes=(32, 32)): - batch_size = 4 - num_channels = 3 - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device) - - return {"sample": noise, "timestep": time_step} + def model_class(self): + return UNet2DModel @property - def input_shape(self): + def output_shape(self): return (3, 32, 32) @property - def output_shape(self): - return (3, 32, 32) + def main_input_name(self): + return "sample" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "block_out_channels": [32, 64, 64, 64], "in_channels": 3, "layers_per_block": 1, @@ -324,17 +332,71 @@ def prepare_init_args_and_inputs_for_common(self): "SkipUpBlock2D", ], } - inputs_dict = 
self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + return { + "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device), + "timestep": torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device), + } + + +class TestNCSNpp(NCSNppTesterConfig, ModelTesterMixin, UNetTesterMixin): + @pytest.mark.skip("Test not supported.") + def test_forward_with_norm_groups(self): + pass + + @pytest.mark.skip( + "To make layerwise casting work with this model, we will have to update the implementation. " + "Due to potentially low usage, we don't support it here." + ) + def test_keep_in_fp32_modules(self): + pass + + @pytest.mark.skip( + "To make layerwise casting work with this model, we will have to update the implementation. " + "Due to potentially low usage, we don't support it here." + ) + def test_from_save_pretrained_dtype_inference(self): + pass + + +class TestNCSNppMemory(NCSNppTesterConfig, MemoryTesterMixin): + @pytest.mark.skip( + "To make layerwise casting work with this model, we will have to update the implementation. " + "Due to potentially low usage, we don't support it here." + ) + def test_layerwise_casting_memory(self): + pass + + @pytest.mark.skip( + "To make layerwise casting work with this model, we will have to update the implementation. " + "Due to potentially low usage, we don't support it here." + ) + def test_layerwise_casting_training(self): + pass + + +class TestNCSNppTraining(NCSNppTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "UNetMidBlock2D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestNCSNppHubLoading(NCSNppTesterConfig): @slow def test_from_pretrained_hub(self): model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) + assert model is not None + assert len(loading_info["missing_keys"]) == 0 model.to(torch_device) - inputs = self.dummy_input + inputs = self.get_dummy_inputs() noise = floats_tensor((4, 3) + (256, 256)).to(torch_device) inputs["sample"] = noise image = model(**inputs) @@ -361,7 +423,7 @@ def test_output_pretrained_ve_mid(self): expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056]) # fmt: on - self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) + assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2) def test_output_pretrained_ve_large(self): model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update") @@ -382,35 +444,4 @@ def test_output_pretrained_ve_large(self): expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256]) # fmt: on - self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) - - @unittest.skip("Test not supported.") - def test_forward_with_norm_groups(self): - # not required for this model - pass - - def test_gradient_checkpointing_is_applied(self): - expected_set = { - "UNetMidBlock2D", - } - - block_out_channels = (32, 64, 64, 64) - - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, block_out_channels=block_out_channels - ) - - def test_effective_gradient_checkpointing(self): - 
super().test_effective_gradient_checkpointing(skip={"time_proj.weight"}) - - @unittest.skip( - "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." - ) - def test_layerwise_casting_inference(self): - pass - - @unittest.skip( - "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." - ) - def test_layerwise_casting_memory(self): - pass + assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 4dbb8ca7c075..a7293208d370 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -20,6 +20,7 @@ import unittest from collections import OrderedDict +import pytest import torch from huggingface_hub import snapshot_download from parameterized import parameterized @@ -52,17 +53,24 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ( +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + IPAdapterTesterMixin, LoraHotSwappingForModelTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, ModelTesterMixin, TorchCompileTesterMixin, - UNetTesterMixin, + TrainingTesterMixin, ) if is_peft_available(): from peft import LoraConfig - from peft.tuners.tuners_utils import BaseTunerLayer + + from ..testing_utils.lora import check_if_lora_correctly_set logger = logging.get_logger(__name__) @@ -82,16 +90,6 @@ def get_unet_lora_config(): return unet_lora_config -def check_if_lora_correctly_set(model) -> bool: - """ - Checks if the LoRA layers are correctly set with peft - """ - for module in model.modules(): - if isinstance(module, BaseTunerLayer): - return True - return False - - def create_ip_adapter_state_dict(model): # "ip_adapter" (cross-attention weights) ip_cross_attn_state_dict = {} @@ -354,34 +352,28 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): return custom_diffusion_attn_procs -class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DConditionModel - main_input_name = "sample" - # We override the items here because the unet under consideration is small. 
- model_split_percents = [0.5, 0.34, 0.4] +class UNet2DConditionTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet2DConditionModel testing.""" @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (16, 16) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + def model_class(self): + return UNet2DConditionModel @property - def input_shape(self): + def output_shape(self) -> tuple[int, int, int]: return (4, 16, 16) @property - def output_shape(self): - return (4, 16, 16) + def model_split_percents(self) -> list[float]: + return [0.5, 0.34, 0.4] + + @property + def main_input_name(self) -> str: + return "sample" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + """Return UNet2D model initialization arguments.""" + return { "block_out_channels": (4, 8), "norm_num_groups": 4, "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), @@ -393,26 +385,24 @@ def prepare_init_args_and_inputs_for_common(self): "layers_per_block": 1, "sample_size": 16, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_enable_works(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + """Return dummy inputs for UNet2D model.""" + batch_size = 4 + num_channels = 4 + sizes = (16, 16) - model.enable_xformers_memory_efficient_attention() + return { + "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device), + "timestep": torch.tensor([10]).to(torch_device), + "encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device), + } - assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ - == "XFormersAttnProcessor" - ), "xformers is not enabled" +class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin): def test_model_with_attention_head_dim_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -427,12 +417,13 @@ def test_model_with_attention_head_dim_tuple(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_model_with_use_linear_projection(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["use_linear_projection"] = True @@ -446,12 +437,13 @@ def test_model_with_use_linear_projection(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - 
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_model_with_cross_attention_dim_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["cross_attention_dim"] = (8, 8) @@ -465,12 +457,13 @@ def test_model_with_cross_attention_dim_tuple(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_model_with_simple_projection(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() batch_size, _, _, sample_size = inputs_dict["sample"].shape @@ -489,12 +482,13 @@ def test_model_with_simple_projection(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_model_with_class_embeddings_concat(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() batch_size, _, _, sample_size = inputs_dict["sample"].shape @@ -514,12 +508,287 @@ def test_model_with_class_embeddings_concat(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None + expected_shape = inputs_dict["sample"].shape + assert output.shape == expected_shape, "Input and output shapes do not match" + + # see diffusers.models.attention_processor::Attention#prepare_attention_mask + # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks. + # since the use-case (somebody passes in a too-short cross-attn mask) is pretty small, + # maybe it's fine that this only works for the unclip use-case. + @mark.skip( + reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length." + ) + def test_model_xattn_padding(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)}) + model.to(torch_device) + model.eval() + + cond = inputs_dict["encoder_hidden_states"] + with torch.no_grad(): + full_cond_out = model(**inputs_dict).sample + assert full_cond_out is not None + + batch, tokens, _ = cond.shape + keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool) + keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample + assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result" + + trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool) + trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample + assert trunc_mask_out.allclose(keeplast_out), ( + "a mask with fewer tokens than condition, will be padded with 'keep' tokens. 
a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." + ) + + def test_pickle(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + init_dict["block_out_channels"] = (16, 32) + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + sample = model(**inputs_dict).sample + + sample_copy = copy.copy(sample) + + assert (sample - sample_copy).abs().max() < 1e-4 + + def test_asymmetrical_unet(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + # Add asymmetry to configs + init_dict["transformer_layers_per_block"] = [[3, 2], 1] + init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1] + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + output = model(**inputs_dict).sample expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + # Check if input and output shapes are the same + assert output.shape == expected_shape, "Input and output shapes do not match" + + +class TestUNet2DConditionHubLoading(UNet2DConditionTesterConfig): + """Hub checkpoint loading tests for UNet2DConditionModel.""" + + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), + ] + ) + @require_torch_accelerator + def test_load_sharded_checkpoint_from_hub(self, repo_id, variant): + inputs_dict = self.get_dummy_inputs() + loaded_model = self.model_class.from_pretrained(repo_id, variant=variant) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), + ] + ) + @require_torch_accelerator + def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant): + inputs_dict = self.get_dummy_inputs() + loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + def test_load_sharded_checkpoint_from_hub_local(self): + inputs_dict = self.get_dummy_inputs() + ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") + loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + def test_load_sharded_checkpoint_from_hub_local_subfolder(self): + inputs_dict = self.get_dummy_inputs() + ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") + loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", 
"fp16"), + ] + ) + def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant): + inputs_dict = self.get_dummy_inputs() + loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto") + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), + ] + ) + def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant): + inputs_dict = self.get_dummy_inputs() + loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto") + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + def test_load_sharded_checkpoint_device_map_from_hub_local(self): + inputs_dict = self.get_dummy_inputs() + ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") + loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto") + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + @require_torch_accelerator + def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self): + inputs_dict = self.get_dummy_inputs() + ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") + loaded_model = self.model_class.from_pretrained( + ckpt_path, local_files_only=True, subfolder="unet", device_map="auto" + ) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + + +class TestUNet2DConditionLoRA(UNet2DConditionTesterConfig, LoraTesterMixin): + """LoRA adapter tests for UNet2DConditionModel.""" + + @require_peft_backend + def test_load_attn_procs_raise_warning(self): + """Test that deprecated load_attn_procs method raises FutureWarning.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + model.to(torch_device) + + # forward pass without LoRA + with torch.no_grad(): + non_lora_sample = model(**inputs_dict).sample + + unet_lora_config = get_unet_lora_config() + model.add_adapter(unet_lora_config) + + assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." + + # forward pass with LoRA + with torch.no_grad(): + lora_sample_1 = model(**inputs_dict).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + model.unload_lora() + + with pytest.warns(FutureWarning, match="Using the `load_attn_procs\\(\\)` method has been deprecated"): + model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + # import to still check for the rest of the stuff. + assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." + + with torch.no_grad(): + lora_sample_2 = model(**inputs_dict).sample + + assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), ( + "LoRA injected UNet should produce different results." + ) + assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), ( + "Loading from a saved checkpoint should produce identical results." 
+ ) + + @require_peft_backend + def test_save_attn_procs_raise_warning(self): + """Test that deprecated save_attn_procs method raises FutureWarning.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.to(torch_device) + + unet_lora_config = get_unet_lora_config() + model.add_adapter(unet_lora_config) + + assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." + + with tempfile.TemporaryDirectory() as tmpdirname: + with pytest.warns(FutureWarning, match="Using the `save_attn_procs\\(\\)` method has been deprecated"): + model.save_attn_procs(os.path.join(tmpdirname)) + + +class TestUNet2DConditionMemory(UNet2DConditionTesterConfig, MemoryTesterMixin): + """Memory optimization tests for UNet2DConditionModel.""" + + +class TestUNet2DConditionTraining(UNet2DConditionTesterConfig, TrainingTesterMixin): + """Training tests for UNet2DConditionModel.""" + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "CrossAttnUpBlock2D", + "CrossAttnDownBlock2D", + "UNetMidBlock2DCrossAttn", + "UpBlock2D", + "Transformer2DModel", + "DownBlock2D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin): + """Attention processor tests for UNet2DConditionModel.""" + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_enable_works(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + + model.enable_xformers_memory_efficient_attention() + + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersAttnProcessor" + ), "xformers is not enabled" def test_model_attention_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -544,7 +813,7 @@ def test_model_attention_slicing(self): assert output is not None def test_model_sliceable_head_dim(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -562,21 +831,6 @@ def check_sliceable_dim_attr(module: torch.nn.Module): for module in model.children(): check_sliceable_dim_attr(module) - def test_gradient_checkpointing_is_applied(self): - expected_set = { - "CrossAttnUpBlock2D", - "CrossAttnDownBlock2D", - "UNetMidBlock2DCrossAttn", - "UpBlock2D", - "Transformer2DModel", - "DownBlock2D", - } - attention_head_dim = (8, 16) - block_out_channels = (16, 32) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels - ) - def test_special_attn_proc(self): class AttnEasyProc(torch.nn.Module): def __init__(self, num): @@ -618,7 +872,8 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma return hidden_states # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -645,7 +900,8 @@ 
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma ] ) def test_model_xattn_mask(self, mask_dtype): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)}) model.to(torch_device) @@ -675,39 +931,13 @@ def test_model_xattn_mask(self, mask_dtype): "masking the last token from our cond should be equivalent to truncating that token out of the condition" ) - # see diffusers.models.attention_processor::Attention#prepare_attention_mask - # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks. - # since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric. - # maybe it's fine that this only works for the unclip use-case. - @mark.skip( - reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length." - ) - def test_model_xattn_padding(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)}) - model.to(torch_device) - model.eval() - - cond = inputs_dict["encoder_hidden_states"] - with torch.no_grad(): - full_cond_out = model(**inputs_dict).sample - assert full_cond_out is not None - batch, tokens, _ = cond.shape - keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool) - keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample - assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result" - - trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool) - trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample - assert trunc_mask_out.allclose(keeplast_out), ( - "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." 
- ) +class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig): + """Custom Diffusion processor tests for UNet2DConditionModel.""" def test_custom_diffusion_processors(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -733,8 +963,8 @@ def test_custom_diffusion_processors(self): assert (sample1 - sample2).abs().max() < 3e-3 def test_custom_diffusion_save_load(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -754,7 +984,7 @@ def test_custom_diffusion_save_load(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")) torch.manual_seed(0) new_model = self.model_class(**init_dict) new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin") @@ -773,8 +1003,8 @@ def test_custom_diffusion_save_load(self): reason="XFormers attention is only available with CUDA and `xformers` installed", ) def test_custom_diffusion_xformers_on_off(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -798,41 +1028,28 @@ def test_custom_diffusion_xformers_on_off(self): assert (sample - on_sample).abs().max() < 1e-4 assert (sample - off_sample).abs().max() < 1e-4 - def test_pickle(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["block_out_channels"] = (16, 32) - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - sample = model(**inputs_dict).sample - sample_copy = copy.copy(sample) +class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterMixin): + """IP Adapter tests for UNet2DConditionModel.""" - assert (sample - sample_copy).abs().max() < 1e-4 - - def test_asymmetrical_unet(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - # Add asymmetry to configs - init_dict["transformer_layers_per_block"] = [[3, 2], 1] - init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1] - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) + @property + def ip_adapter_processor_cls(self): + return (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0) - output = model(**inputs_dict).sample - expected_shape = inputs_dict["sample"].shape + def create_ip_adapter_state_dict(self, model): + return create_ip_adapter_state_dict(model) - # Check if input and output shapes are the same - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def modify_inputs_for_ip_adapter(self, model, 
inputs_dict): + batch_size = inputs_dict["encoder_hidden_states"].shape[0] + # for ip-adapter image_embeds has shape [batch_size, num_image, embed_dim] + cross_attention_dim = getattr(model.config, "cross_attention_dim", 8) + image_embeds = floats_tensor((batch_size, 1, cross_attention_dim)).to(torch_device) + inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]} + return inputs_dict def test_ip_adapter(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -905,7 +1122,8 @@ def test_ip_adapter(self): assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) def test_ip_adapter_plus(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) @@ -977,185 +1195,16 @@ def test_ip_adapter_plus(self): assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4) assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) - @parameterized.expand( - [ - ("hf-internal-testing/unet2d-sharded-dummy", None), - ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), - ] - ) - @require_torch_accelerator - def test_load_sharded_checkpoint_from_hub(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained(repo_id, variant=variant) - loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @parameterized.expand( - [ - ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), - ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), - ] - ) - @require_torch_accelerator - def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant) - loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - def test_load_sharded_checkpoint_from_hub_local(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") - loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True) - loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - def test_load_sharded_checkpoint_from_hub_local_subfolder(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") - loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True) - loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - @parameterized.expand( - [ - ("hf-internal-testing/unet2d-sharded-dummy", None), - ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", 
"fp16"), - ] - ) - def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto") - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - @parameterized.expand( - [ - ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), - ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), - ] - ) - def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto") - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - def test_load_sharded_checkpoint_device_map_from_hub_local(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") - loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto") - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_torch_accelerator - def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") - loaded_model = self.model_class.from_pretrained( - ckpt_path, local_files_only=True, subfolder="unet", device_map="auto" - ) - new_output = loaded_model(**inputs_dict) - - assert loaded_model - assert new_output.sample.shape == (4, 4, 16, 16) - - @require_peft_backend - def test_load_attn_procs_raise_warning(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - # forward pass without LoRA - with torch.no_grad(): - non_lora_sample = model(**inputs_dict).sample - - unet_lora_config = get_unet_lora_config() - model.add_adapter(unet_lora_config) - - assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." - - # forward pass with LoRA - with torch.no_grad(): - lora_sample_1 = model(**inputs_dict).sample - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) - model.unload_lora() - - with self.assertWarns(FutureWarning) as warning: - model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - warning_message = str(warning.warnings[0].message) - assert "Using the `load_attn_procs()` method has been deprecated" in warning_message - - # import to still check for the rest of the stuff. - assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." - - with torch.no_grad(): - lora_sample_2 = model(**inputs_dict).sample - - assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), ( - "LoRA injected UNet should produce different results." - ) - assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), ( - "Loading from a saved checkpoint should produce identical results." 
- ) - - @require_peft_backend - def test_save_attn_procs_raise_warning(self): - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - unet_lora_config = get_unet_lora_config() - model.add_adapter(unet_lora_config) - - assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." - - with tempfile.TemporaryDirectory() as tmpdirname: - with self.assertWarns(FutureWarning) as warning: - model.save_attn_procs(tmpdirname) - - warning_message = str(warning.warnings[0].message) - assert "Using the `save_attn_procs()` method has been deprecated" in warning_message - - -class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = UNet2DConditionModel - def prepare_init_args_and_inputs_for_common(self): - return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common() +class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for UNet2DConditionModel.""" + def test_torch_compile_repeated_blocks(self): + return super().test_torch_compile_repeated_blocks(recompile_limit=2) -class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): - model_class = UNet2DConditionModel - def prepare_init_args_and_inputs_for_common(self): - return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common() +class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin): + """LoRA hot-swapping tests for UNet2DConditionModel.""" @slow diff --git a/tests/models/unets/test_models_unet_3d_condition.py b/tests/models/unets/test_models_unet_3d_condition.py index f73e3461c38e..264c67223fb3 100644 --- a/tests/models/unets/test_models_unet_3d_condition.py +++ b/tests/models/unets/test_models_unet_3d_condition.py @@ -18,47 +18,44 @@ import numpy as np import torch -from diffusers.models import ModelMixin, UNet3DConditionModel -from diffusers.utils import logging +from diffusers import UNet3DConditionModel from diffusers.utils.import_utils import is_xformers_available -from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ...testing_utils import ( + enable_full_determinism, + floats_tensor, + skip_mps, + torch_device, +) +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + ModelTesterMixin, +) enable_full_determinism() -logger = logging.get_logger(__name__) - @skip_mps -class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet3DConditionModel - main_input_name = "sample" +class UNet3DConditionTesterConfig(BaseModelTesterConfig): + """Base configuration for UNet3DConditionModel testing.""" @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - num_frames = 4 - sizes = (16, 16) - - noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + def model_class(self): + return UNet3DConditionModel @property - def input_shape(self): + def output_shape(self): return (4, 4, 16, 16) @property - def output_shape(self): - return (4, 4, 16, 16) 
+ def main_input_name(self): + return "sample" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "block_out_channels": (4, 8), "norm_num_groups": 4, "down_block_types": ( @@ -73,27 +70,25 @@ def prepare_init_args_and_inputs_for_common(self): "layers_per_block": 1, "sample_size": 16, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_enable_works(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + def get_dummy_inputs(self): + batch_size = 4 + num_channels = 4 + num_frames = 4 + sizes = (16, 16) - model.enable_xformers_memory_efficient_attention() + return { + "sample": floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device), + "timestep": torch.tensor([10]).to(torch_device), + "encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device), + } - assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ - == "XFormersAttnProcessor" - ), "xformers is not enabled" +class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin): # Overriding to set `norm_num_groups` needs to be different for this model. def test_forward_with_norm_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (32, 64) init_dict["norm_num_groups"] = 32 @@ -107,39 +102,74 @@ def test_forward_with_norm_groups(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" # Overriding since the UNet3D outputs a different structure. 
+ @torch.no_grad() def test_determinism(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() - with torch.no_grad(): - # Warmup pass when using mps (see #372) - if torch_device == "mps" and isinstance(model, ModelMixin): - model(**self.dummy_input) + inputs_dict = self.get_dummy_inputs() - first = model(**inputs_dict) - if isinstance(first, dict): - first = first.sample + first = model(**inputs_dict) + if isinstance(first, dict): + first = first.sample - second = model(**inputs_dict) - if isinstance(second, dict): - second = second.sample + second = model(**inputs_dict) + if isinstance(second, dict): + second = second.sample out_1 = first.cpu().numpy() out_2 = second.cpu().numpy() out_1 = out_1[~np.isnan(out_1)] out_2 = out_2[~np.isnan(out_2)] max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) + assert max_diff <= 1e-5 + + def test_feed_forward_chunking(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + init_dict["block_out_channels"] = (32, 64) + init_dict["norm_num_groups"] = 32 + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict)[0] + + model.enable_forward_chunking() + with torch.no_grad(): + output_2 = model(**inputs_dict)[0] + + assert output.shape == output_2.shape, "Shape doesn't match" + assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2 + + +class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterMixin): + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_enable_works(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + + model.enable_xformers_memory_efficient_attention() + + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersAttnProcessor" + ), "xformers is not enabled" def test_model_attention_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = 8 @@ -162,22 +192,3 @@ def test_model_attention_slicing(self): with torch.no_grad(): output = model(**inputs_dict) assert output is not None - - def test_feed_forward_chunking(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - init_dict["block_out_channels"] = (32, 64) - init_dict["norm_num_groups"] = 32 - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict)[0] - - model.enable_forward_chunking() - with torch.no_grad(): - output_2 = model(**inputs_dict)[0] - - self.assertEqual(output.shape, output_2.shape, "Shape doesn't match") - assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2 diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 40773536df70..43ac56f7f54c 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -13,59 +13,42 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest - import numpy as np +import pytest import torch from torch import nn from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel -from diffusers.utils import logging from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin - +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + BaseModelTesterConfig, + ModelTesterMixin, + TrainingTesterMixin, +) -logger = logging.get_logger(__name__) enable_full_determinism() -class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNetControlNetXSModel - main_input_name = "sample" +class UNetControlNetXSTesterConfig(BaseModelTesterConfig): + """Base configuration for UNetControlNetXSModel testing.""" @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (16, 16) - conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device) - conditioning_scale = 1 - - return { - "sample": noise, - "timestep": time_step, - "encoder_hidden_states": encoder_hidden_states, - "controlnet_cond": controlnet_cond, - "conditioning_scale": conditioning_scale, - } + def model_class(self): + return UNetControlNetXSModel @property - def input_shape(self): + def output_shape(self): return (4, 16, 16) @property - def output_shape(self): - return (4, 16, 16) + def main_input_name(self): + return "sample" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "sample_size": 16, "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), @@ -80,11 +63,23 @@ def prepare_init_args_and_inputs_for_common(self): "ctrl_max_norm_num_groups": 2, "ctrl_conditioning_embedding_out_channels": (2, 2), } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + + def get_dummy_inputs(self): + batch_size = 4 + num_channels = 4 + sizes = (16, 16) + conditioning_image_size = (3, 32, 32) + + return { + "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device), + "timestep": torch.tensor([10]).to(torch_device), + "encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device), + "controlnet_cond": floats_tensor((batch_size, *conditioning_image_size)).to(torch_device), + "conditioning_scale": 1, + } def get_dummy_unet(self): - """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter""" + """Build the underlying UNet for tests that construct UNetControlNetXSModel from UNet + Adapter.""" return UNet2DConditionModel( block_out_channels=(4, 8), layers_per_block=2, @@ -99,10 +94,16 @@ def get_dummy_unet(self): ) def get_dummy_controlnet_from_unet(self, unet, **kwargs): - """For some tests we also need the underlying ControlNetXS-Adapter. 
For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter""" - # size_ratio and conditioning_embedding_out_channels chosen to keep model small + """Build the ControlNetXS-Adapter from a UNet.""" return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs) + +class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetTesterMixin): + @pytest.mark.skip("Test not supported.") + def test_forward_with_norm_groups(self): + # UNetControlNetXSModel only supports SD/SDXL with norm_num_groups=32 + pass + def test_from_unet(self): unet = self.get_dummy_unet() controlnet = self.get_dummy_controlnet_from_unet(unet) @@ -115,7 +116,7 @@ def assert_equal_weights(module, weight_dict_prefix): assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value) # # check unet - # everything expect down,mid,up blocks + # everything except down,mid,up blocks modules_from_unet = [ "time_embedding", "conv_in", @@ -152,7 +153,7 @@ def assert_equal_weights(module, weight_dict_prefix): assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers") # # check controlnet - # everything expect down,mid,up blocks + # everything except down,mid,up blocks modules_from_controlnet = { "controlnet_cond_embedding": "controlnet_cond_embedding", "conv_in": "ctrl_conv_in", @@ -193,12 +194,12 @@ def assert_unfrozen(module): for p in module.parameters(): assert p.requires_grad - init_dict, _ = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() model = UNetControlNetXSModel(**init_dict) model.freeze_unet_params() # # check unet - # everything expect down,mid,up blocks + # everything except down,mid,up blocks modules_from_unet = [ model.base_time_embedding, model.base_conv_in, @@ -236,7 +237,7 @@ def assert_unfrozen(module): assert_frozen(u.upsamplers) # # check controlnet - # everything expect down,mid,up blocks + # everything except down,mid,up blocks modules_from_controlnet = [ model.controlnet_cond_embedding, model.ctrl_conv_in, @@ -267,16 +268,6 @@ def assert_unfrozen(module): for u in model.up_blocks: assert_unfrozen(u.ctrl_to_base) - def test_gradient_checkpointing_is_applied(self): - expected_set = { - "Transformer2DModel", - "UNetMidBlock2DCrossAttn", - "ControlNetXSCrossAttnDownBlock2D", - "ControlNetXSCrossAttnMidBlock2D", - "ControlNetXSCrossAttnUpBlock2D", - } - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - @is_flaky def test_forward_no_control(self): unet = self.get_dummy_unet() @@ -287,7 +278,7 @@ def test_forward_no_control(self): unet = unet.to(torch_device) model = model.to(torch_device) - input_ = self.dummy_input + input_ = self.get_dummy_inputs() control_specific_input = ["controlnet_cond", "conditioning_scale"] input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input} @@ -312,7 +303,7 @@ def test_time_embedding_mixing(self): model = model.to(torch_device) model_mix_time = model_mix_time.to(torch_device) - input_ = self.dummy_input + input_ = self.get_dummy_inputs() with torch.no_grad(): output = model(**input_).sample @@ -320,7 +311,14 @@ def test_time_embedding_mixing(self): assert output.shape == output_mix_time.shape - @unittest.skip("Test not supported.") - def test_forward_with_norm_groups(self): - # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups. 
- pass + +class TestUNetControlNetXSTraining(UNetControlNetXSTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "Transformer2DModel", + "UNetMidBlock2DCrossAttn", + "ControlNetXSCrossAttnDownBlock2D", + "ControlNetXSCrossAttnMidBlock2D", + "ControlNetXSCrossAttnUpBlock2D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/unets/test_models_unet_spatiotemporal.py b/tests/models/unets/test_models_unet_spatiotemporal.py index 7df868c9e95b..1951d2c0f326 100644 --- a/tests/models/unets/test_models_unet_spatiotemporal.py +++ b/tests/models/unets/test_models_unet_spatiotemporal.py @@ -16,10 +16,10 @@ import copy import unittest +import pytest import torch from diffusers import UNetSpatioTemporalConditionModel -from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from ...testing_utils import ( @@ -28,45 +28,34 @@ skip_mps, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin - +from ..test_modeling_common import UNetTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + ModelTesterMixin, + TrainingTesterMixin, +) -logger = logging.get_logger(__name__) enable_full_determinism() @skip_mps -class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNetSpatioTemporalConditionModel - main_input_name = "sample" +class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig): + """Base configuration for UNetSpatioTemporalConditionModel testing.""" @property - def dummy_input(self): - batch_size = 2 - num_frames = 2 - num_channels = 4 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device) - - return { - "sample": noise, - "timestep": time_step, - "encoder_hidden_states": encoder_hidden_states, - "added_time_ids": self._get_add_time_ids(), - } - - @property - def input_shape(self): - return (2, 2, 4, 32, 32) + def model_class(self): + return UNetSpatioTemporalConditionModel @property def output_shape(self): return (4, 32, 32) + @property + def main_input_name(self): + return "sample" + @property def fps(self): return 6 @@ -83,8 +72,8 @@ def noise_aug_strength(self): def addition_time_embed_dim(self): return 32 - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "block_out_channels": (32, 64), "down_block_types": ( "CrossAttnDownBlockSpatioTemporal", @@ -103,8 +92,23 @@ def prepare_init_args_and_inputs_for_common(self): "projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3, "addition_time_embed_dim": self.addition_time_embed_dim, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + + def get_dummy_inputs(self): + batch_size = 2 + num_frames = 2 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device) + + return { + "sample": noise, + "timestep": time_step, + "encoder_hidden_states": encoder_hidden_states, + "added_time_ids": self._get_add_time_ids(), + } def _get_add_time_ids(self, do_classifier_free_guidance=True): add_time_ids = [self.fps, 
self.motion_bucket_id, self.noise_aug_strength] @@ -124,43 +128,15 @@ def _get_add_time_ids(self, do_classifier_free_guidance=True): return add_time_ids - @unittest.skip("Number of Norm Groups is not configurable") - def test_forward_with_norm_groups(self): - pass - - @unittest.skip("Deprecated functionality") - def test_model_attention_slicing(self): - pass - - @unittest.skip("Not supported") - def test_model_with_use_linear_projection(self): - pass - - @unittest.skip("Not supported") - def test_model_with_simple_projection(self): - pass - @unittest.skip("Not supported") - def test_model_with_class_embeddings_concat(self): +class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, UNetTesterMixin): + @pytest.mark.skip("Number of Norm Groups is not configurable") + def test_forward_with_norm_groups(self): pass - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_enable_works(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - - model.enable_xformers_memory_efficient_attention() - - assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ - == "XFormersAttnProcessor" - ), "xformers is not enabled" - def test_model_with_num_attention_heads_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["num_attention_heads"] = (8, 16) model = self.model_class(**init_dict) @@ -173,12 +149,13 @@ def test_model_with_num_attention_heads_tuple(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == expected_shape, "Input and output shapes do not match" def test_model_with_cross_attention_dim_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["cross_attention_dim"] = (32, 32) @@ -192,27 +169,13 @@ def test_model_with_cross_attention_dim_tuple(self): if isinstance(output, dict): output = output.sample - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_gradient_checkpointing_is_applied(self): - expected_set = { - "TransformerSpatioTemporalModel", - "CrossAttnDownBlockSpatioTemporal", - "DownBlockSpatioTemporal", - "UpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - "UNetMidBlockSpatioTemporal", - } - num_attention_heads = (8, 16) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, num_attention_heads=num_attention_heads - ) + assert output.shape == expected_shape, "Input and output shapes do not match" def test_pickle(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["num_attention_heads"] = (8, 16) @@ -225,3 +188,33 @@ def test_pickle(self): sample_copy = copy.copy(sample) assert (sample - sample_copy).abs().max() < 1e-4 + + +class 
TestUNetSpatioTemporalAttention(UNetSpatioTemporalTesterConfig, AttentionTesterMixin): + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_enable_works(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + + model.enable_xformers_memory_efficient_attention() + + assert ( + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersAttnProcessor" + ), "xformers is not enabled" + + +class TestUNetSpatioTemporalTraining(UNetSpatioTemporalTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "TransformerSpatioTemporalModel", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "UNetMidBlockSpatioTemporal", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
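
Taken together, the hunks above converge on one layout: a small `*TesterConfig` class owns the model class plus `get_init_dict()` and `get_dummy_inputs()`, and each concern (common forward passes, attention, training, compile, LoRA) becomes its own pytest-style class that combines that config with a single mixin. The following is a minimal sketch of that layout, assuming only the `BaseModelTesterConfig` and mixin API visible in this diff; the `UNet2DConditionModel` init kwargs and dummy shapes below are illustrative placeholders, not the exact values this PR uses.

    import torch

    from diffusers import UNet2DConditionModel

    from ...testing_utils import floats_tensor, torch_device
    from ..test_modeling_common import UNetTesterMixin
    from ..testing_utils import (
        BaseModelTesterConfig,
        ModelTesterMixin,
        TrainingTesterMixin,
    )


    class MyUNetTesterConfig(BaseModelTesterConfig):
        """Shared configuration: model class, init kwargs and dummy inputs live in one place."""

        @property
        def model_class(self):
            return UNet2DConditionModel

        @property
        def main_input_name(self):
            return "sample"

        @property
        def output_shape(self):
            return (4, 4, 16, 16)

        def get_init_dict(self):
            # illustrative tiny config; small channels keep the shared tests fast
            return {
                "block_out_channels": (16, 32),
                "norm_num_groups": 16,
                "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
                "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
                "cross_attention_dim": 8,
                "attention_head_dim": (8, 16),
                "in_channels": 4,
                "out_channels": 4,
                "layers_per_block": 1,
                "sample_size": 16,
            }

        def get_dummy_inputs(self):
            batch_size = 4
            return {
                "sample": floats_tensor((batch_size, 4, 16, 16)).to(torch_device),
                "timestep": torch.tensor([10]).to(torch_device),
                "encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
            }


    # One test class per concern: the shared config plus exactly one mixin.
    class TestMyUNet(MyUNetTesterConfig, ModelTesterMixin, UNetTesterMixin):
        pass


    class TestMyUNetTraining(MyUNetTesterConfig, TrainingTesterMixin):
        pass

Per-model parameterization then happens by overriding the mixin test in the concrete class, as `TestUNet2DConditionModelCompile` does above by forwarding `recompile_limit=2` to `super().test_torch_compile_repeated_blocks()`.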