From 1645458a83dfd2aefd6d891adcb5633c3a5e030d Mon Sep 17 00:00:00 2001 From: chhayankjain Date: Tue, 12 May 2026 06:39:50 +0530 Subject: [PATCH] fix: only instantiate CrossAttentionBlock when with_cross_attention=True Fixes #8845 TransformerBlock previously instantiated norm_cross_attn and cross_attn unconditionally, even when with_cross_attention=False. These unused modules registered dead parameters in model.parameters(), wasting memory. Wrapped both instantiations in `if with_cross_attention:` to match the existing guard in forward(). Added tests to verify the modules and their parameters are absent when disabled, present when enabled, and that the forward pass with a context tensor works correctly. Signed-off-by: chhayankjain --- monai/networks/blocks/transformerblock.py | 28 +++++++++++++------ .../networks/blocks/test_transformerblock.py | 28 +++++++++++++++++++ 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index b93d81bdef..0649512b56 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -46,11 +46,20 @@ def __init__( dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + causal (bool, optional): whether to apply causal masking in self-attention. Defaults to False. + sequence_length (int | None, optional): sequence length required for causal masking. Defaults to None. + with_cross_attention (bool, optional): whether to include cross-attention layers that attend to an + external context tensor. When False, norm_cross_attn and cross_attn are not instantiated. + Defaults to False. 
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). include_fc: whether to include the final linear layer. Default to True. use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + Raises: + ValueError: if dropout_rate is not in [0, 1]. + ValueError: if hidden_size is not divisible by num_heads. + """ super().__init__() @@ -78,15 +87,16 @@ def __init__( self.norm2 = nn.LayerNorm(hidden_size) self.with_cross_attention = with_cross_attention - self.norm_cross_attn = nn.LayerNorm(hidden_size) - self.cross_attn = CrossAttentionBlock( - hidden_size=hidden_size, - num_heads=num_heads, - dropout_rate=dropout_rate, - qkv_bias=qkv_bias, - causal=False, - use_flash_attention=use_flash_attention, - ) + if with_cross_attention: + self.norm_cross_attn = nn.LayerNorm(hidden_size) + self.cross_attn = CrossAttentionBlock( + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + causal=False, + use_flash_attention=use_flash_attention, + ) def forward( self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None diff --git a/tests/networks/blocks/test_transformerblock.py b/tests/networks/blocks/test_transformerblock.py index b977a38e73..5cf008d6aa 100644 --- a/tests/networks/blocks/test_transformerblock.py +++ b/tests/networks/blocks/test_transformerblock.py @@ -53,6 +53,34 @@ def test_ill_arg(self): with self.assertRaises(ValueError): TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4) + @skipUnless(has_einops, "Requires einops") + def test_cross_attention_params_not_registered_when_disabled(self): + block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False) + param_names = [name for name, _ in block.named_parameters()] + 
self.assertFalse(any("cross_attn" in n for n in param_names)) + self.assertFalse(any("norm_cross_attn" in n for n in param_names)) + self.assertFalse(hasattr(block, "cross_attn")) + self.assertFalse(hasattr(block, "norm_cross_attn")) + + @skipUnless(has_einops, "Requires einops") + def test_cross_attention_params_registered_when_enabled(self): + block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True) + self.assertTrue(hasattr(block, "cross_attn")) + self.assertTrue(hasattr(block, "norm_cross_attn")) + param_names = [name for name, _ in block.named_parameters()] + self.assertTrue(any("cross_attn" in n for n in param_names)) + self.assertTrue(any("norm_cross_attn" in n for n in param_names)) + + @skipUnless(has_einops, "Requires einops") + def test_cross_attention_forward_with_context(self): + hidden_size = 128 + block = TransformerBlock(hidden_size=hidden_size, mlp_dim=256, num_heads=4, with_cross_attention=True) + x = torch.randn(2, 16, hidden_size) + context = torch.randn(2, 8, hidden_size) + with eval_mode(block): + out = block(x, context=context) + self.assertEqual(out.shape, x.shape) + @skipUnless(has_einops, "Requires einops") def test_access_attn_matrix(self): # input format