From fc7c419711789ad3b982f569c484bec1c1398bef Mon Sep 17 00:00:00 2001
From: HsiaWinter
Date: Tue, 14 Apr 2026 11:28:48 +0800
Subject: [PATCH 1/2] fix npu compatibility

---
 .../models/transformers/transformer_ernie_image.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py
index 09682a218d91..1a046a5672fc 100644
--- a/src/diffusers/models/transformers/transformer_ernie_image.py
+++ b/src/diffusers/models/transformers/transformer_ernie_image.py
@@ -127,8 +127,14 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
         query, key = query.to(dtype), key.to(dtype)
 
         # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
-        if attention_mask is not None and attention_mask.ndim == 2:
-            attention_mask = attention_mask[:, None, None, :]
+        if attention_mask is not None
+            if attention_mask.ndim == 2:
+                attention_mask = attention_mask[:, None, None, :]
+
+            if attention_mask.ndim == 4:
+                # NPU does not support automatic broadcasting for this type; the mask must be expanded.
+                if attention_mask.device.type == 'npu' and attention_mask.shape[1:3] == (1, 1):
+                    attention_mask = attention_mask.expand(-1, attn.heads, query.shape[1], -1)
 
         # Compute joint attention
         hidden_states = dispatch_attention_fn(

From 51c7ea7c8201903c78584ddc583631096f1bf6a7 Mon Sep 17 00:00:00 2001
From: HsiaWinter
Date: Tue, 14 Apr 2026 11:33:33 +0800
Subject: [PATCH 2/2] fix npu compatibility

---
 src/diffusers/models/transformers/transformer_ernie_image.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py
index 1a046a5672fc..90524e46b6ad 100644
--- a/src/diffusers/models/transformers/transformer_ernie_image.py
+++ b/src/diffusers/models/transformers/transformer_ernie_image.py
@@ -127,7 +127,7 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
         query, key = query.to(dtype), key.to(dtype)
 
         # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
-        if attention_mask is not None
+        if attention_mask is not None:
             if attention_mask.ndim == 2:
                 attention_mask = attention_mask[:, None, None, :]
 
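
For context, the following standalone sketch (not part of the patch series) illustrates the mask handling the series arrives at after both commits: a [batch, seq_len] padding mask is unsqueezed to [batch, 1, 1, seq_len], and on NPU the singleton head and query dimensions are expanded explicitly because the NPU attention kernel does not broadcast them automatically. The shapes, the heads value, and the True-means-attend mask convention are illustrative assumptions.

import torch

# Standalone illustration of the mask handling in the patches above; not part
# of the series. batch/heads/q_len/kv_len values are illustrative assumptions.
batch, heads, q_len, kv_len = 2, 8, 16, 16

# Padding mask as produced upstream: [batch, seq_len], True = attend (assumed).
attention_mask = torch.ones(batch, kv_len, dtype=torch.bool)

# From [batch, seq_len] to [batch, 1, 1, seq_len]
if attention_mask.ndim == 2:
    attention_mask = attention_mask[:, None, None, :]

# On NPU, the singleton head/query dims must be expanded explicitly because
# the attention kernel there does not broadcast them automatically.
if attention_mask.device.type == "npu" and attention_mask.shape[1:3] == (1, 1):
    attention_mask = attention_mask.expand(-1, heads, q_len, -1)

# On CPU/GPU the mask stays [batch, 1, 1, seq_len]; on NPU it becomes
# [batch, heads, q_len, seq_len].
print(attention_mask.shape)

Note that Tensor.expand returns a broadcast view without copying memory, so the explicit expansion only materializes the shape the NPU kernel expects; it does not duplicate the mask data.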