Skip to content

[tests] refactor UNet model tests to align with the new pattern#13153

Open
sayakpaul wants to merge 15 commits intomainfrom
unet-model-tests-refactor
Open

[tests] refactor UNet model tests to align with the new pattern#13153
sayakpaul wants to merge 15 commits intomainfrom
unet-model-tests-refactor

Conversation

@sayakpaul
Copy link
Member

What does this PR do?

Some comments in-line. I have run the tests locally and all of them pass.

@sayakpaul sayakpaul requested a review from DN6 February 16, 2026 10:16
sayakpaul and others added 5 commits February 16, 2026 16:05
Refactor UNet1D model tests to follow the modern testing pattern using
BaseModelTesterConfig and focused mixin classes (ModelTesterMixin,
MemoryTesterMixin, TrainingTesterMixin, LoraTesterMixin).

Both UNet1D standard and RL variants now have separate config classes
and dedicated test classes organized by concern (core, memory, training,
LoRA, hub loading).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Refactor UNet2D model tests (standard, LDM, NCSN++) to follow the
modern testing pattern. Each variant gets its own config class and
dedicated test classes organized by concern (core, memory, training,
LoRA, hub loading).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ing mixins

Refactor UNet3DConditionModel tests to follow the modern testing pattern
with separate classes for core, attention, memory, training, and LoRA.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ing mixins

Refactor UNetControlNetXSModel tests to follow the modern testing
pattern with separate classes for core, memory, training, and LoRA.
Specialized tests (from_unet, freeze_unet, forward_no_control,
time_embedding_mixing) remain in the core test class.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…sting mixins

Refactored the spatiotemporal UNet test file to follow the modern modular testing
pattern with BaseModelTesterConfig and focused test classes:

- UNetSpatioTemporalTesterConfig: Base configuration with model setup
- TestUNetSpatioTemporal: Core model tests (ModelTesterMixin, UNetTesterMixin)
- TestUNetSpatioTemporalAttention: Attention-related tests (AttentionTesterMixin)
- TestUNetSpatioTemporalMemory: Memory/offloading tests (MemoryTesterMixin)
- TestUNetSpatioTemporalTraining: Training tests (TrainingTesterMixin)
- TestUNetSpatioTemporalLoRA: LoRA adapter tests (LoraTesterMixin)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Comment on lines -290 to +292
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
inputs_dict = self.get_dummy_inputs()
image = model(**inputs_dict, return_dict=False)[0]
new_image = new_model(**inputs_dict, return_dict=False)[0]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To ensure reproducibility.

Comment on lines -95 to -97
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed as we pass recompile_limit explicitly now.

with pytest.raises(RuntimeError, match=msg):
model.enable_lora_hotswap(target_rank=32)

def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was needed because caplog doesn't capture these warnings properly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant