diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py
index 4bf00f749f25..5b2b51202b27 100644
--- a/src/diffusers/models/transformers/transformer_ernie_image.py
+++ b/src/diffusers/models/transformers/transformer_ernie_image.py
@@ -127,8 +127,14 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
         query, key = query.to(dtype), key.to(dtype)
 
         # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
-        if attention_mask is not None and attention_mask.ndim == 2:
-            attention_mask = attention_mask[:, None, None, :]
+        if attention_mask is not None:
+            if attention_mask.ndim == 2:
+                attention_mask = attention_mask[:, None, None, :]
+
+            if attention_mask.ndim == 4:
+                # NPU kernels do not broadcast a [batch, 1, 1, seq_len] mask automatically;
+                # expand it to the full [batch, heads, seq_len, seq_len] shape before dispatch.
+                if attention_mask.device.type == "npu" and attention_mask.shape[1:3] == (1, 1):
+                    attention_mask = attention_mask.expand(-1, attn.heads, query.shape[1], -1)
 
         # Compute joint attention
         hidden_states = dispatch_attention_fn(