10 changes: 8 additions & 2 deletions src/diffusers/models/transformers/transformer_ernie_image.py
@@ -127,8 +127,14 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor
query, key = query.to(dtype), key.to(dtype)

# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
if attention_mask is not None and attention_mask.ndim == 2:
attention_mask = attention_mask[:, None, None, :]
if attention_mask is not None:
if attention_mask.ndim == 2:
attention_mask = attention_mask[:, None, None, :]

if attention_mask.ndim == 4:
# NPU does not support automatic broadcasting for this type; the mask must be expanded.
if attention_mask.device.type == 'npu' and attention_mask.shape[1:3] == (1, 1):
@yiyixuxu (Collaborator) commented on Apr 14, 2026:
Can we verify that this would also work if we explicitly set the backend to npu?

def _native_npu_attention(


When a mask of shape [batch, seq_len] or [batch, 1, 1, seq_len] is passed, the operator fails with an error similar to:
`get unsupported atten_mask shape, the shape is [B, 1, 1, S]`
Only shapes like [B, N, S, S], [B, 1, S, S], [1, 1, S, S], or [S, S] are accepted.

The _native_npu_attention function operates correctly because it uses _maybe_modify_attn_mask_npu to reshape the attention mask from [batch_size, seq_len_k] to [batch_size, 1, seq_len_q, seq_len_k], a format the NPU backend accepts.

Reference:
Ascend NPU fusion attention API:
https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
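For illustration, the shape handling in question can be sketched in pure Python (a minimal sketch of the unsqueeze-then-expand logic only; `expand_mask_shape` is a hypothetical helper written for this comment, not the actual `_maybe_modify_attn_mask_npu` implementation):

```python
def expand_mask_shape(mask_shape, heads, seq_len_q):
    """Compute the mask shape produced by the unsqueeze + expand fix:
    [B, S_k] -> [B, 1, 1, S_k] -> expand(-1, heads, seq_len_q, -1) -> [B, heads, S_q, S_k].

    Mimics torch.Tensor.expand semantics: -1 (here, a size kept as-is)
    leaves a dimension unchanged; only size-1 dimensions may be expanded.
    """
    if len(mask_shape) == 2:
        # [batch, seq_len_k] -> [batch, 1, 1, seq_len_k]
        mask_shape = (mask_shape[0], 1, 1, mask_shape[1])
    b, h, q, k = mask_shape
    # Expand the singleton head and query dims so the NPU kernel
    # receives an explicit [B, N, S, S]-style mask instead of
    # relying on broadcasting.
    return (b, heads if h == 1 else h, seq_len_q if q == 1 else q, k)
```

For example, a 2D padding mask for batch 2 and sequence length 77 with 8 heads becomes a full [2, 8, 77, 77] mask, matching the accepted [B, N, S, S] layout.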

attention_mask = attention_mask.expand(-1, attn.heads, query.shape[1], -1)

# Compute joint attention
hidden_states = dispatch_attention_fn(