Merged
48 commits
98f6c8c
draft:add neuron as a legit backend
JingyaHuang Mar 18, 2026
c58b8b8
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Mar 18, 2026
3367409
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Mar 19, 2026
0c51734
Merge branch 'main' into add-neuron-backend
JingyaHuang Mar 25, 2026
a76953c
feat: neuron-specific changes in the pipeline
JingyaHuang Mar 26, 2026
2480388
tests: eager tests
JingyaHuang Mar 27, 2026
929ab72
fix: style
JingyaHuang Apr 9, 2026
52cac76
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Apr 9, 2026
28a5086
Merge branch 'add-neuron-backend' of github.com:JingyaHuang/diffusers…
JingyaHuang Apr 9, 2026
68689e5
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Apr 10, 2026
da79308
Merge branch 'main' into add-neuron-backend
JingyaHuang Apr 10, 2026
3bb9c7c
fix:apr_02 beta
JingyaHuang Apr 10, 2026
c4facab
Merge branch 'add-neuron-backend' of github.com:JingyaHuang/diffusers…
JingyaHuang Apr 10, 2026
1eb5ff9
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Apr 13, 2026
0d927b3
Merge branch 'main' into add-neuron-backend
JingyaHuang Apr 22, 2026
d1a8911
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang May 4, 2026
c9e55ae
Merge branch 'main' into add-neuron-backend
JingyaHuang May 11, 2026
291171b
cleanup: remove tp part, for another pr
JingyaHuang May 11, 2026
f1caec0
fix: restore ring_anything to ContextParallelConfig after over-aggres…
JingyaHuang May 11, 2026
510a914
removal: style fix
JingyaHuang May 11, 2026
c312167
tests:sdxl + flux2
JingyaHuang May 12, 2026
1165f9f
tests: simplify
JingyaHuang May 13, 2026
5d13779
tests: simplify
JingyaHuang May 13, 2026
c007709
fix style
JingyaHuang May 13, 2026
c608bbb
Merge branch 'main' into add-neuron-backend
JingyaHuang May 13, 2026
fb5ea94
fix style
JingyaHuang May 13, 2026
9d93c68
fix style for doc-builder
JingyaHuang May 13, 2026
332b5a6
Merge branch 'main' into add-neuron-backend
JingyaHuang May 18, 2026
e8bf642
review: address comments
JingyaHuang May 18, 2026
1e25f3e
Merge branch 'main' into add-neuron-backend
JingyaHuang May 18, 2026
a3b6ccb
review:apply suggestion for the fix of index_for_timtestep
JingyaHuang May 20, 2026
86e550a
review: stronger guard on image slice
JingyaHuang May 20, 2026
c8e9716
Merge branch 'add-neuron-backend' of github.com:JingyaHuang/diffusers…
JingyaHuang May 20, 2026
45a8812
Apply suggestions from code review
JingyaHuang May 20, 2026
6c025c2
Merge branch 'main' into add-neuron-backend
JingyaHuang May 20, 2026
a18bfc5
Merge branch 'main' into add-neuron-backend
JingyaHuang May 21, 2026
1bf6b56
Merge branch 'main' into add-neuron-backend
JingyaHuang May 22, 2026
3144cfe
Merge branch 'main' into add-neuron-backend
JingyaHuang May 26, 2026
d5ee083
fix: when set_begin_index not implemented for scheduler
JingyaHuang May 26, 2026
688df52
Merge branch 'add-neuron-backend' of github.com:JingyaHuang/diffusers…
JingyaHuang May 26, 2026
4b3ff51
review:add maybe_adjust_dtype_for_device and apply to all models with…
JingyaHuang Jun 1, 2026
888a936
fix: dependency
JingyaHuang Jun 1, 2026
1c25326
Merge branch 'main' into add-neuron-backend
JingyaHuang Jun 1, 2026
2443920
Merge branch 'main' into add-neuron-backend
yiyixuxu Jun 2, 2026
e98f17e
Update src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_di…
JingyaHuang Jun 2, 2026
19702b2
Merge branch 'main' into add-neuron-backend
JingyaHuang Jun 2, 2026
ab82699
review: apply maybe_adjust_dtype_for_device in pixart pipe
JingyaHuang Jun 2, 2026
20af97b
review: update .ai/models.md
JingyaHuang Jun 2, 2026
12 changes: 6 additions & 6 deletions .ai/models.md
Original file line number Diff line number Diff line change
@@ -163,14 +163,14 @@ Boolean gate. If `False` (default), calling that method raises `ValueError`. All
 3. **Capability flags without matching implementation.** for example, `_supports_gradient_checkpointing = True` only takes effect if `forward` actually has `if self.gradient_checkpointing:` branches calling `self._gradient_checkpointing_func` on each block. Setting the flag without those branches means training code silently no-ops the checkpoint and runs a normal forward.
 4. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`.

-5. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
+5. **`torch.float64` anywhere in the model.** MPS, NPU, and Neuron backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
    - **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on.
-   - **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo:
+   - **Only if float32 visibly degrades output, use the `maybe_adjust_dtype_for_device` helper** from `diffusers.utils.torch_utils`. It centralizes the device-specific dtype downcast (float64→float32, int64→int32) for all restricted backends (mps, npu, neuron):
     ```python
-     is_mps = hidden_states.device.type == "mps"
-     is_npu = hidden_states.device.type == "npu"
-     freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+     from diffusers.utils.torch_utils import maybe_adjust_dtype_for_device
+
+     freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device)
     ```
-     See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py` for reference usages. Never leave an unconditional `torch.float64` in the model.
+     See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py`, and `pipeline_pixart_alpha.py` for reference usages. Never leave an unconditional `torch.float64` in the model.

 6. **Using `torch.empty`.** - Do not use `torch.empty` to initialize parameters. Use `torch.zeros` or `torch.ones`, instead.
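The body of `maybe_adjust_dtype_for_device` is not among the hunks shown on this page. As a rough sketch of the behaviour item 5 describes (float64→float32 and int64→int32 on mps, npu, and neuron, unchanged elsewhere), something like the following would do. The name `sketch_adjust_dtype_for_device`, the `RESTRICTED_DEVICE_TYPES` set, and the handling of string devices are assumptions made for illustration, not the code in `diffusers.utils.torch_utils`.

```python
import torch

# Assumption: the device types item 5 names as lacking 64-bit support.
RESTRICTED_DEVICE_TYPES = {"mps", "npu", "neuron"}

# Downcast table taken from the description: float64 -> float32, int64 -> int32.
_DOWNCAST = {torch.float64: torch.float32, torch.int64: torch.int32}


def sketch_adjust_dtype_for_device(dtype, device):
    """Hypothetical stand-in for the helper: return a dtype the device can handle."""
    # Accept torch.device objects as well as strings such as "neuron:0".
    if isinstance(device, torch.device):
        device_type = device.type
    else:
        device_type = str(device).split(":")[0]
    if device_type in RESTRICTED_DEVICE_TYPES:
        # Dtypes outside the table (e.g. float16) pass through unchanged.
        return _DOWNCAST.get(dtype, dtype)
    return dtype


print(sketch_adjust_dtype_for_device(torch.float64, "mps"))
print(sketch_adjust_dtype_for_device(torch.float64, torch.device("cpu")))
```

Keeping the device list in one place is what lets this PR add Neuron without touching every `is_mps` / `is_npu` check again.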
10 changes: 4 additions & 6 deletions src/diffusers/models/controlnets/controlnet.py
@@ -22,6 +22,7 @@
 from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import BaseOutput, apply_lora_scale, logging
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -675,12 +676,9 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
-            is_npu = sample.device.type == "npu"
-            if isinstance(timestep, float):
-                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-            else:
-                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+            dtype = maybe_adjust_dtype_for_device(
+                torch.float64 if isinstance(timestep, float) else torch.int64, sample.device
+            )
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
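The rewritten `dtype = ...` expression first picks the widest dtype for a Python scalar timestep and then lets the helper narrow it for the target device. A minimal standalone sketch of that two-step pattern is below. `pick_timestep_dtype` and the `RESTRICTED` set are hypothetical names, and the narrowing rule is inlined as an assumption rather than imported from diffusers.

```python
import torch

# Assumption: backends treated as lacking 64-bit dtypes.
RESTRICTED = {"mps", "npu", "neuron"}


def pick_timestep_dtype(timestep, device_type):
    """Hypothetical helper mirroring the hunk: widest dtype, narrowed on restricted devices."""
    # Step 1: float timesteps start as float64, everything else as int64.
    wide = torch.float64 if isinstance(timestep, float) else torch.int64
    # Step 2: narrow 64-bit to 32-bit where the backend requires it.
    if device_type in RESTRICTED:
        return {torch.float64: torch.float32, torch.int64: torch.int32}[wide]
    return wide


# On CPU the scalar keeps its full width; the same call with "neuron" narrows it.
t = torch.tensor([999], dtype=pick_timestep_dtype(999, "cpu"), device="cpu")
print(t.dtype, pick_timestep_dtype(0.5, "neuron"))
```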
10 changes: 4 additions & 6 deletions src/diffusers/models/controlnets/controlnet_sparsectrl.py
@@ -22,6 +22,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
 from ...utils import BaseOutput, logging
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -604,12 +605,9 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
-            is_npu = sample.device.type == "npu"
-            if isinstance(timestep, float):
-                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-            else:
-                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+            dtype = maybe_adjust_dtype_for_device(
+                torch.float64 if isinstance(timestep, float) else torch.int64, sample.device
+            )
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
10 changes: 4 additions & 6 deletions src/diffusers/models/controlnets/controlnet_union.py
@@ -19,6 +19,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import logging
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -620,12 +621,9 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
-            is_npu = sample.device.type == "npu"
-            if isinstance(timestep, float):
-                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-            else:
-                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+            dtype = maybe_adjust_dtype_for_device(
+                torch.float64 if isinstance(timestep, float) else torch.int64, sample.device
+            )
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
11 changes: 4 additions & 7 deletions src/diffusers/models/controlnets/controlnet_xs.py
@@ -20,7 +20,7 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import BaseOutput, logging
-from ...utils.torch_utils import apply_freeu
+from ...utils.torch_utils import apply_freeu, maybe_adjust_dtype_for_device
 from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -1014,12 +1014,9 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
-            is_npu = sample.device.type == "npu"
-            if isinstance(timestep, float):
-                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-            else:
-                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+            dtype = maybe_adjust_dtype_for_device(
+                torch.float64 if isinstance(timestep, float) else torch.int64, sample.device
+            )
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
3 changes: 2 additions & 1 deletion src/diffusers/models/embeddings.py
@@ -19,6 +19,7 @@
 from torch import nn

 from ..utils import deprecate
+from ..utils.torch_utils import maybe_adjust_dtype_for_device
 from .activations import FP32SiLU, get_activation
 from .attention_processor import Attention

@@ -346,7 +347,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin

     # Auto-detect appropriate dtype if not specified
     if dtype is None:
-        dtype = torch.float32 if pos.device.type == "mps" else torch.float64
+        dtype = maybe_adjust_dtype_for_device(torch.float64, pos.device)

     omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
     omega /= embed_dim / 2.0
9 changes: 3 additions & 6 deletions src/diffusers/models/transformers/transformer_anyflow.py
@@ -28,6 +28,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import apply_lora_scale, logging
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
@@ -41,9 +42,7 @@

 def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
     # MPS / NPU backends do not support complex128 / float64; fall back to float32 on those devices.
-    is_mps = hidden_states.device.type == "mps"
-    is_npu = hidden_states.device.type == "npu"
-    rotary_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+    rotary_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device)
     x_rotated = torch.view_as_complex(hidden_states.to(rotary_dtype).unflatten(3, (-1, 2)))
     x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
     return x_out.type_as(hidden_states)
@@ -341,9 +340,7 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor:
         if not is_compiling and self._freqs_cache is not None and self._freqs_cache[0] == cache_key:
             return self._freqs_cache[1]

-        is_mps = device.type == "mps"
-        is_npu = device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, device)

         h_dim = w_dim = 2 * (self.attention_head_dim // 6)
         t_dim = self.attention_head_dim - h_dim - w_dim
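For readers unfamiliar with the complex-number form of RoPE used in `apply_rotary_emb`, here is a small self-contained sketch of the same three tensor operations, run in float32 (the dtype restricted backends end up with). The function name `apply_rotary_emb_sketch`, the explicit `rotary_dtype` parameter, and the toy shapes are assumptions for illustration. The real function derives the dtype from the device instead.

```python
import torch


def apply_rotary_emb_sketch(hidden_states, freqs, rotary_dtype=torch.float32):
    """Rotate adjacent channel pairs of `hidden_states` by the complex phases in `freqs`."""
    # [B, H, S, D] -> [B, H, S, D/2, 2] -> complex [B, H, S, D/2]
    x = torch.view_as_complex(hidden_states.to(rotary_dtype).unflatten(3, (-1, 2)))
    # Complex multiplication applies the rotation; then back to real [B, H, S, D].
    out = torch.view_as_real(x * freqs).flatten(3, 4)
    return out.type_as(hidden_states)


B, H, S, D = 1, 2, 4, 8
x = torch.randn(B, H, S, D)
angles = torch.rand(S, D // 2)
# Unit-modulus complex64 phases, broadcast over batch and heads.
freqs = torch.polar(torch.ones_like(angles), angles)
y = apply_rotary_emb_sketch(x, freqs)
print(y.shape)
```

Because each phase has modulus one, the rotation preserves the norm of every channel pair, which is a quick sanity check when changing the compute dtype.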
9 changes: 3 additions & 6 deletions src/diffusers/models/transformers/transformer_anyflow_far.py
@@ -30,6 +30,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import BaseOutput, apply_lora_scale, logging
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
@@ -44,9 +45,7 @@
 # Copied from diffusers.models.transformers.transformer_anyflow.apply_rotary_emb
 def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
     # MPS / NPU backends do not support complex128 / float64; fall back to float32 on those devices.
-    is_mps = hidden_states.device.type == "mps"
-    is_npu = hidden_states.device.type == "npu"
-    rotary_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+    rotary_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device)
     x_rotated = torch.view_as_complex(hidden_states.to(rotary_dtype).unflatten(3, (-1, 2)))
     x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
     return x_out.type_as(hidden_states)
@@ -650,9 +649,7 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor:
         if not is_compiling and self._freqs_cache is not None and self._freqs_cache[0] == cache_key:
             return self._freqs_cache[1]

-        is_mps = device.type == "mps"
-        is_npu = device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, device)

         h_dim = w_dim = 2 * (self.attention_head_dim // 6)
         t_dim = self.attention_head_dim - h_dim - w_dim
8 changes: 3 additions & 5 deletions src/diffusers/models/transformers/transformer_bria.py
@@ -9,7 +9,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import apply_lora_scale, logging
-from ...utils.torch_utils import maybe_allow_in_graph
+from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
@@ -276,8 +276,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         cos_out = []
         sin_out = []
         pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        freqs_dtype = torch.float32 if is_mps else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device)
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
@@ -344,8 +343,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         cos_out = []
         sin_out = []
         pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        freqs_dtype = torch.float32 if is_mps else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device)
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
5 changes: 2 additions & 3 deletions src/diffusers/models/transformers/transformer_bria_fibo.py
@@ -25,7 +25,7 @@
     apply_lora_scale,
     logging,
 )
-from ...utils.torch_utils import maybe_allow_in_graph
+from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -222,8 +222,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         cos_out = []
         sin_out = []
         pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        freqs_dtype = torch.float32 if is_mps else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device)
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
6 changes: 2 additions & 4 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -23,7 +23,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import apply_lora_scale, logging
-from ...utils.torch_utils import maybe_allow_in_graph
+from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -503,9 +503,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         cos_out = []
         sin_out = []
         pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        is_npu = ids.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device)
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
5 changes: 2 additions & 3 deletions src/diffusers/models/transformers/transformer_flux2.py
@@ -23,6 +23,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import BaseOutput, apply_lora_scale, logging
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin
 from ..attention_dispatch import dispatch_attention_fn
@@ -959,9 +960,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         cos_out = []
         sin_out = []
         pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        is_npu = ids.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device)
         # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
         for i in range(len(self.axes_dim)):
             cos, sin = get_1d_rotary_pos_embed(
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
 from ...utils import apply_lora_scale, deprecate, logging
-from ...utils.torch_utils import maybe_allow_in_graph
+from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps

@@ -95,10 +95,7 @@ def forward(self, latent) -> torch.Tensor:
 def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
     assert dim % 2 == 0, "The dimension must be even."

-    is_mps = pos.device.type == "mps"
-    is_npu = pos.device.type == "npu"
-
-    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+    dtype = maybe_adjust_dtype_for_device(torch.float64, pos.device)

     scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
     omega = 1.0 / (theta**scale)
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import logging
-from ...utils.torch_utils import maybe_allow_in_graph
+from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
@@ -361,9 +361,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         cos_out = []
         sin_out = []
         pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        is_npu = ids.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device)
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
5 changes: 2 additions & 3 deletions src/diffusers/models/transformers/transformer_motif_video.py
@@ -22,6 +22,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_adjust_dtype_for_device
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
@@ -483,9 +484,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         grid = torch.stack(grid, dim=0)

         freqs = []
-        is_mps = hidden_states.device.type == "mps"
-        is_npu = hidden_states.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device)
         for i in range(3):
             freq = get_1d_rotary_pos_embed(
                 dim=self.rope_dim[i],