diff --git a/.ai/models.md b/.ai/models.md index 954cbd343781..6e37e742ae57 100644 --- a/.ai/models.md +++ b/.ai/models.md @@ -163,14 +163,14 @@ Boolean gate. If `False` (default), calling that method raises `ValueError`. All 3. **Capability flags without matching implementation.** for example, `_supports_gradient_checkpointing = True` only takes effect if `forward` actually has `if self.gradient_checkpointing:` branches calling `self._gradient_checkpointing_func` on each block. Setting the flag without those branches means training code silently no-ops the checkpoint and runs a normal forward. 4. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`. -5. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows: +5. **`torch.float64` anywhere in the model.** MPS, NPU, and Neuron backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows: - **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on. - - **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo: + - **Only if float32 visibly degrades output, use the `maybe_adjust_dtype_for_device` helper** from `diffusers.utils.torch_utils`. It centralizes the device-specific dtype downcast (float64→float32, int64→int32) for all restricted backends (mps, npu, neuron): ```python - is_mps = hidden_states.device.type == "mps" - is_npu = hidden_states.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + from diffusers.utils.torch_utils import maybe_adjust_dtype_for_device + + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device) ``` - See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py` for reference usages. Never leave an unconditional `torch.float64` in the model. + See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py`, and `pipeline_pixart_alpha.py` for reference usages. Never leave an unconditional `torch.float64` in the model. 6. **Using `torch.empty`.** - Do not use `torch.empty` to initialize parameters. Use `torch.zeros` or `torch.ones`, instead. \ No newline at end of file diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 8c2ff2fdd123..d2030f4e7044 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -22,6 +22,7 @@ from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import BaseOutput, apply_lora_scale, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -675,12 +676,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py index 715d9dad2c34..dda653ea7a50 100644 --- a/src/diffusers/models/controlnets/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py @@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin from ...utils import BaseOutput, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -604,12 +605,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py index 8e434ba1f250..8dfcb1795618 100644 --- a/src/diffusers/models/controlnets/controlnet_union.py +++ b/src/diffusers/models/controlnets/controlnet_union.py @@ -19,6 +19,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -620,12 +621,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py index 6221d4878de9..efc242f332b9 100644 --- a/src/diffusers/models/controlnets/controlnet_xs.py +++ b/src/diffusers/models/controlnets/controlnet_xs.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput, logging -from ...utils.torch_utils import apply_freeu +from ...utils.torch_utils import apply_freeu, maybe_adjust_dtype_for_device from ..attention import AttentionMixin from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -1014,12 +1014,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index bcd192c1f166..c5eaa746252e 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -19,6 +19,7 @@ from torch import nn from ..utils import deprecate +from ..utils.torch_utils import maybe_adjust_dtype_for_device from .activations import FP32SiLU, get_activation from .attention_processor import Attention @@ -346,7 +347,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin # Auto-detect appropriate dtype if not specified if dtype is None: - dtype = torch.float32 if pos.device.type == "mps" else torch.float64 + dtype = maybe_adjust_dtype_for_device(torch.float64, pos.device) omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype) omega /= embed_dim / 2.0 diff --git a/src/diffusers/models/transformers/transformer_anyflow.py b/src/diffusers/models/transformers/transformer_anyflow.py index 873e0b095b33..6b0872ffdb01 100644 --- a/src/diffusers/models/transformers/transformer_anyflow.py +++ b/src/diffusers/models/transformers/transformer_anyflow.py @@ -28,6 +28,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import apply_lora_scale, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed @@ -41,9 +42,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): # MPS / NPU backends do not support complex128 / float64; fall back to float32 on those devices. - is_mps = hidden_states.device.type == "mps" - is_npu = hidden_states.device.type == "npu" - rotary_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + rotary_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device) x_rotated = torch.view_as_complex(hidden_states.to(rotary_dtype).unflatten(3, (-1, 2))) x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) return x_out.type_as(hidden_states) @@ -341,9 +340,7 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor: if not is_compiling and self._freqs_cache is not None and self._freqs_cache[0] == cache_key: return self._freqs_cache[1] - is_mps = device.type == "mps" - is_npu = device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, device) h_dim = w_dim = 2 * (self.attention_head_dim // 6) t_dim = self.attention_head_dim - h_dim - w_dim diff --git a/src/diffusers/models/transformers/transformer_anyflow_far.py b/src/diffusers/models/transformers/transformer_anyflow_far.py index 55bf750dc656..4a6fc553279f 100644 --- a/src/diffusers/models/transformers/transformer_anyflow_far.py +++ b/src/diffusers/models/transformers/transformer_anyflow_far.py @@ -30,6 +30,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import BaseOutput, apply_lora_scale, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed @@ -44,9 +45,7 @@ # Copied from diffusers.models.transformers.transformer_anyflow.apply_rotary_emb def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): # MPS / NPU backends do not support complex128 / float64; fall back to float32 on those devices. - is_mps = hidden_states.device.type == "mps" - is_npu = hidden_states.device.type == "npu" - rotary_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + rotary_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device) x_rotated = torch.view_as_complex(hidden_states.to(rotary_dtype).unflatten(3, (-1, 2))) x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) return x_out.type_as(hidden_states) @@ -650,9 +649,7 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor: if not is_compiling and self._freqs_cache is not None and self._freqs_cache[0] == cache_key: return self._freqs_cache[1] - is_mps = device.type == "mps" - is_npu = device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, device) h_dim = w_dim = 2 * (self.attention_head_dim // 6) t_dim = self.attention_head_dim - h_dim - w_dim diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py index 8e79046508e9..ff4261343ab2 100644 --- a/src/diffusers/models/transformers/transformer_bria.py +++ b/src/diffusers/models/transformers/transformer_bria.py @@ -9,7 +9,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import apply_lora_scale, logging -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -276,8 +276,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - freqs_dtype = torch.float32 if is_mps else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], @@ -344,8 +343,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - freqs_dtype = torch.float32 if is_mps else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py index 31c826bbf6b2..7b4ac1a3bedf 100644 --- a/src/diffusers/models/transformers/transformer_bria_fibo.py +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -25,7 +25,7 @@ apply_lora_scale, logging, ) -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle @@ -222,8 +222,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - freqs_dtype = torch.float32 if is_mps else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 13177bc67878..94857dffacb2 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...utils import apply_lora_scale, logging -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -503,9 +503,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index e56f18f788e9..c3fa6ac141f3 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -23,6 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...utils import BaseOutput, apply_lora_scale, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn @@ -959,9 +960,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] for i in range(len(self.axes_dim)): cos, sin = get_1d_rotary_pos_embed( diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index b6c0e3533657..bd69d5de68ca 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -9,7 +9,7 @@ from ...models.modeling_outputs import Transformer2DModelOutput from ...models.modeling_utils import ModelMixin from ...utils import apply_lora_scale, deprecate, logging -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from ..attention import Attention from ..embeddings import TimestepEmbedding, Timesteps @@ -95,10 +95,7 @@ def forward(self, latent) -> torch.Tensor: def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0, "The dimension must be even." - is_mps = pos.device.type == "mps" - is_npu = pos.device.type == "npu" - - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + dtype = maybe_adjust_dtype_for_device(torch.float64, pos.device) scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim omega = 1.0 / (theta**scale) diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py index fe4713ea02db..7b842c42132d 100644 --- a/src/diffusers/models/transformers/transformer_longcat_image.py +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import logging -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -361,9 +361,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], diff --git a/src/diffusers/models/transformers/transformer_motif_video.py b/src/diffusers/models/transformers/transformer_motif_video.py index c0908f198f90..fb3ff0666f95 100644 --- a/src/diffusers/models/transformers/transformer_motif_video.py +++ b/src/diffusers/models/transformers/transformer_motif_video.py @@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -483,9 +484,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: grid = torch.stack(grid, dim=0) freqs = [] - is_mps = hidden_states.device.type == "mps" - is_npu = hidden_states.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device) for i in range(3): freq = get_1d_rotary_pos_embed( dim=self.rope_dim[i], diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py index 8280a4871f10..7a9df427e0b9 100644 --- a/src/diffusers/models/transformers/transformer_ovis_image.py +++ b/src/diffusers/models/transformers/transformer_ovis_image.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import logging -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -364,9 +364,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py index 3c8e8ae4e2c9..37553dd44c87 100644 --- a/src/diffusers/models/transformers/transformer_prx.py +++ b/src/diffusers/models/transformers/transformer_prx.py @@ -19,6 +19,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn from ..embeddings import get_timestep_embedding @@ -275,9 +276,7 @@ def __init__(self, dim: int, theta: int, axes_dim: list[int]): def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0 - is_mps = pos.device.type == "mps" - is_npu = pos.device.type == "npu" - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + dtype = maybe_adjust_dtype_for_device(torch.float64, pos.device) scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim omega = 1.0 / (theta**scale) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index deae25899475..38a41a3dc93f 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -26,6 +26,7 @@ deprecate, logging, ) +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..activations import get_activation from ..attention import AttentionMixin from ..attention_processor import ( @@ -853,12 +854,9 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 5006e48feb46..0d15e93da68f 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import UNet2DConditionLoadersMixin from ...utils import BaseOutput, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..activations import get_activation from ..attention import AttentionMixin from ..attention_processor import ( @@ -547,12 +548,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index 30fb46095326..9e7841f95e58 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -20,6 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import UNet2DConditionLoadersMixin from ...utils import logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..activations import get_activation from ..attention import Attention, AttentionMixin, FeedForward from ..attention_processor import ( @@ -499,12 +500,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timesteps, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timesteps, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 7c4201facacf..6904cc05f10c 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import BaseOutput, apply_lora_scale, deprecate, logging -from ...utils.torch_utils import apply_freeu +from ...utils.torch_utils import apply_freeu, maybe_adjust_dtype_for_device from ..attention import AttentionMixin, BasicTransformerBlock from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -1952,12 +1952,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index eddeb9826b0c..d38be0b0675f 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -6,6 +6,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import UNet2DConditionLoadersMixin from ...utils import BaseOutput, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttnProcessor from ..embeddings import TimestepEmbedding, Timesteps @@ -335,12 +336,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 604e51d88583..11eaeaca7fc0 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -33,7 +33,7 @@ logging, replace_example_docstring, ) -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import maybe_adjust_dtype_for_device, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -862,7 +862,8 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps - if XLA_AVAILABLE: + is_neuron_device = device.type == "neuron" + if XLA_AVAILABLE or is_neuron_device: timestep_device = "cpu" else: timestep_device = device @@ -912,12 +913,10 @@ def __call__( if not torch.is_tensor(current_timestep): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = latent_model_input.device.type == "mps" - is_npu = latent_model_input.device.type == "npu" if isinstance(current_timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + dtype = maybe_adjust_dtype_for_device(torch.float64, latent_model_input.device) else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device(torch.int64, latent_model_input.device) current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 8148fac123e0..d08b6c5a5973 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1095,7 +1095,8 @@ def __call__( ) # 4. Prepare timesteps - if XLA_AVAILABLE: + is_neuron_device = device.type == "neuron" + if XLA_AVAILABLE or is_neuron_device: timestep_device = "cpu" else: timestep_device = device @@ -1190,6 +1191,8 @@ def __call__( ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -1197,16 +1200,18 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds + # [Neuron] pre-cast timestep to float32 on device. Neuron XLA does not support + # int64 ops; the compiled UNet graph requires a float32 timestep input on-device. + t_unet = t.to(torch.float32).to(device) if is_neuron_device else t noise_pred = self.unet( latent_model_input, - t, + t_unet, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 4b41622b2a4a..3738821a96a1 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -113,6 +113,7 @@ is_timm_available, is_torch_available, is_torch_mlu_available, + is_torch_neuronx_available, is_torch_npu_available, is_torch_version, is_torch_xla_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 5323dfe5ec82..937e20d6b155 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -193,6 +193,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") _torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu") +_torch_neuronx_available, _torch_neuronx_version = _is_package_available("torch_neuronx") _transformers_available, _transformers_version = _is_package_available("transformers") _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") _kernels_available, _kernels_version = _is_package_available("kernels") @@ -250,6 +251,10 @@ def is_torch_mlu_available(): return _torch_mlu_available +def is_torch_neuronx_available(): + return _torch_neuronx_available + + def is_flax_available(): return _flax_available @@ -600,6 +605,11 @@ def is_av_available(): """ +TORCH_NEURONX_IMPORT_ERROR = """ +{0} requires the torch_neuronx library (AWS Neuron SDK) but it was not found in your environment. Please install it +following the AWS Neuron documentation: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/ +""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -630,6 +640,7 @@ def is_av_available(): ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)), ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)), ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)), + ("torch_neuronx", (is_torch_neuronx_available, TORCH_NEURONX_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 619a37034949..eefe52c477a6 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -46,6 +46,7 @@ is_peft_available, is_timm_available, is_torch_available, + is_torch_neuronx_available, is_torch_version, is_torchao_available, is_torchsde_available, @@ -113,6 +114,8 @@ torch_device = "cuda" elif torch.xpu.is_available(): torch_device = "xpu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + torch_device = torch.neuron.current_device() else: torch_device = "cpu" is_torch_higher_equal_than_1_12 = version.parse( diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index c314a8609bec..263334dce8cd 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -22,7 +22,13 @@ from typing import Callable, ParamSpec, TypeVar from . import logging -from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version +from .import_utils import ( + is_torch_available, + is_torch_mlu_available, + is_torch_neuronx_available, + is_torch_npu_available, + is_torch_version, +) T = TypeVar("T") @@ -33,12 +39,20 @@ import torch from torch.fft import fftn, fftshift, ifftn, ifftshift - BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True} + BACKEND_SUPPORTS_TRAINING = { + "cuda": True, + "xpu": True, + "cpu": True, + "mps": False, + "neuron": False, + "default": True, + } BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, "xpu": torch.xpu.empty_cache, "cpu": None, "mps": torch.mps.empty_cache, + "neuron": None, "default": None, } BACKEND_DEVICE_COUNT = { @@ -46,6 +60,7 @@ "xpu": torch.xpu.device_count, "cpu": lambda: 0, "mps": lambda: 0, + "neuron": lambda: getattr(getattr(torch, "neuron", None), "device_count", lambda: 0)(), "default": 0, } BACKEND_MANUAL_SEED = { @@ -53,6 +68,7 @@ "xpu": torch.xpu.manual_seed, "cpu": torch.manual_seed, "mps": torch.mps.manual_seed, + "neuron": torch.manual_seed, "default": torch.manual_seed, } BACKEND_RESET_PEAK_MEMORY_STATS = { @@ -60,6 +76,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } BACKEND_RESET_MAX_MEMORY_ALLOCATED = { @@ -67,6 +84,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } BACKEND_MAX_MEMORY_ALLOCATED = { @@ -74,6 +92,7 @@ "xpu": getattr(torch.xpu, "max_memory_allocated", None), "cpu": 0, "mps": 0, + "neuron": 0, "default": 0, } BACKEND_SYNCHRONIZE = { @@ -81,8 +100,15 @@ "xpu": getattr(torch.xpu, "synchronize", None), "cpu": None, "mps": None, + "neuron": getattr(getattr(torch, "neuron", None), "synchronize", None), "default": None, } + + _FP64_UNSUPPORTED_DEVICES = frozenset({"mps", "npu", "neuron"}) + _INT64_UNSUPPORTED_DEVICES = frozenset({"mps", "npu", "neuron"}) + _DTYPE_DOWNCAST = {torch.float64: torch.float32, torch.int64: torch.int32} + _DTYPE_UNSUPPORTED_DEVICES = {torch.float64: _FP64_UNSUPPORTED_DEVICES, torch.int64: _INT64_UNSUPPORTED_DEVICES} + logger = logging.get_logger(__name__) # pylint: disable=invalid-name try: @@ -149,6 +175,11 @@ def backend_supports_training(device: str): return BACKEND_SUPPORTS_TRAINING[device] +def maybe_adjust_dtype_for_device(dtype: "torch.dtype", device: "torch.device") -> "torch.dtype": + unsupported = _DTYPE_UNSUPPORTED_DEVICES.get(dtype) + return _DTYPE_DOWNCAST[dtype] if unsupported and device.type in unsupported else dtype + + def randn_tensor( shape: tuple | list, generator: list["torch.Generator"] | "torch.Generator" | None = None, @@ -169,11 +200,15 @@ def randn_tensor( layout = layout or torch.strided device = device or torch.device("cpu") + # Neuron (XLA) does not support creating random tensors directly on device; always use CPU + if device.type == "neuron": + rand_device = torch.device("cpu") + if generator is not None: gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type if gen_device_type != device.type and gen_device_type == "cpu": rand_device = "cpu" - if device.type != "mps": + if device.type not in ("mps", "neuron"): logger.info( f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" @@ -296,6 +331,8 @@ def get_device(): return "mps" elif is_torch_mlu_available(): return "mlu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + return "neuron" else: return "cpu" diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein.py b/tests/pipelines/flux2/test_pipeline_flux2_klein.py index 8ed9bf3d1e91..377f02dc9aa1 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein.py @@ -1,3 +1,5 @@ +import gc +import os import unittest import numpy as np @@ -11,8 +13,13 @@ Flux2KleinPipeline, Flux2Transformer2DModel, ) +from diffusers.utils.import_utils import is_torch_neuronx_available -from ...testing_utils import torch_device +from ...testing_utils import ( + backend_empty_cache, + require_torch_neuron, + torch_device, +) from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist @@ -181,3 +188,57 @@ def test_image_input(self): @unittest.skip("Needs to be revisited") def test_encode_prompt_works_in_isolation(self): pass + + +@require_torch_neuron +class Flux2KleinPipelineIntegrationTests(unittest.TestCase): + ckpt_id = "black-forest-labs/FLUX.2-klein-4B" + prompt = "A small cactus with a happy face in the Sahara desert." + + def setUp(self): + super().setUp() + self._saved_env = {} + if is_torch_neuronx_available(): + neff_cache_dir = "/tmp/neff_cache" + os.makedirs(neff_cache_dir, exist_ok=True) + for key in ("TORCH_NEURONX_NEFF_CACHE_DIR", "TORCH_NEURONX_ENABLE_NKI_SDPA"): + self._saved_env[key] = os.environ.get(key) + os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = neff_cache_dir + os.environ.setdefault("TORCH_NEURONX_ENABLE_NKI_SDPA", "0") + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + for key, original in self._saved_env.items(): + if original is None: + os.environ.pop(key, None) + else: + os.environ[key] = original + gc.collect() + backend_empty_cache(torch_device) + + def test_flux2_klein_inference_512(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = Flux2KleinPipeline.from_pretrained(self.ckpt_id, torch_dtype=torch.bfloat16) + pipe.to(torch_device) + if is_torch_neuronx_available(): + torch.neuron.synchronize() + pipe.set_progress_bar_config(disable=None) + + image = pipe( + prompt=self.prompt, + height=512, + width=512, + num_inference_steps=4, + guidance_scale=1.0, + generator=generator, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1] + self.assertEqual(image.shape, (1, 512, 512, 3)) + self.assertTrue(np.all((image >= 0.0) & (image <= 1.0)), "Pixel values must be in [0, 1]") + expected_slice = np.array([0.3652, 0.3574, 0.3633, 0.4102, 0.4062, 0.4043, 0.4453, 0.4355, 0.4570]) + self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 5e-2) diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index 037a9f44f31e..86fe673a8c7d 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -27,6 +27,7 @@ PixArtAlphaPipeline, PixArtTransformer2DModel, ) +from diffusers.utils.import_utils import is_torch_neuronx_available from ...testing_utils import ( backend_empty_cache, @@ -291,7 +292,9 @@ def test_pixart_1024(self): expected_slice = np.array([0.0742, 0.0835, 0.2114, 0.0295, 0.0784, 0.2361, 0.1738, 0.2251, 0.3589]) max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice) - self.assertLessEqual(max_diff, 1e-4) + # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA + atol = 1e-2 if is_torch_neuronx_available() else 1e-4 + self.assertLessEqual(max_diff, atol) def test_pixart_512(self): generator = torch.Generator("cpu").manual_seed(0) @@ -307,7 +310,9 @@ def test_pixart_512(self): expected_slice = np.array([0.3477, 0.3882, 0.4541, 0.3413, 0.3821, 0.4463, 0.4001, 0.4409, 0.4958]) max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice) - self.assertLessEqual(max_diff, 1e-4) + # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA + atol = 1e-2 if is_torch_neuronx_available() else 1e-4 + self.assertLessEqual(max_diff, atol) def test_pixart_1024_without_resolution_binning(self): generator = torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index b318a505e9db..c9afdc3209cd 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -15,6 +15,7 @@ import copy import gc +import os import tempfile import unittest @@ -24,6 +25,7 @@ from diffusers import ( AutoencoderKL, + AutoPipelineForText2Image, DDIMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, @@ -34,6 +36,7 @@ UNet2DConditionModel, UniPCMultistepScheduler, ) +from diffusers.utils.import_utils import is_torch_neuronx_available from ...testing_utils import ( backend_empty_cache, @@ -41,6 +44,7 @@ load_image, numpy_cosine_similarity_distance, require_torch_accelerator, + require_torch_neuron, slow, torch_device, ) @@ -974,3 +978,51 @@ def test_stable_diffusion_lcm(self): max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten()) assert max_diff < 1e-2 + + +@require_torch_neuron +class StableDiffusionXLTurboPipelineIntegrationTests(unittest.TestCase): + ckpt_id = "stabilityai/sdxl-turbo" + prompt = "A small cactus with a happy face in the Sahara desert." + + def setUp(self): + super().setUp() + self._saved_env = {} + if is_torch_neuronx_available(): + self._saved_env["TORCH_NEURONX_ENABLE_NKI_SDPA"] = os.environ.get("TORCH_NEURONX_ENABLE_NKI_SDPA") + os.environ.setdefault("TORCH_NEURONX_ENABLE_NKI_SDPA", "0") + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + for key, original in self._saved_env.items(): + if original is None: + os.environ.pop(key, None) + else: + os.environ[key] = original + gc.collect() + backend_empty_cache(torch_device) + + def test_sdxl_turbo_512(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = AutoPipelineForText2Image.from_pretrained(self.ckpt_id, torch_dtype=torch.float16, variant="fp16") + pipe.to(torch_device) + if is_torch_neuronx_available(): + torch.neuron.synchronize() + pipe.set_progress_bar_config(disable=None) + + image = pipe( + self.prompt, + num_inference_steps=1, + guidance_scale=0.0, + generator=generator, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1] + self.assertEqual(image.shape, (1, 512, 512, 3)) + self.assertTrue(np.all((image >= 0.0) & (image <= 1.0)), "Pixel values must be in [0, 1]") + expected_slice = np.array([0.3524, 0.3160, 0.3652, 0.3316, 0.3376, 0.3315, 0.3042, 0.3102, 0.3449]) + self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 5e-2) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 060f9ee0f882..6d6df8b24d1e 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -46,6 +46,7 @@ is_peft_available, is_timm_available, is_torch_available, + is_torch_neuronx_available, is_torch_version, is_torchao_available, is_torchsde_available, @@ -110,6 +111,8 @@ torch_device = "cuda" elif torch.xpu.is_available(): torch_device = "xpu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + torch_device = torch.neuron.current_device() else: torch_device = "cpu" is_torch_higher_equal_than_1_12 = version.parse( @@ -142,8 +145,8 @@ def assert_tensors_close( """ Assert that two tensors are close within tolerance. - Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected| - Provides concise, actionable error messages without dumping full tensors. + Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected| Provides concise, + actionable error messages without dumping full tensors. Args: actual: The actual tensor from the computation. @@ -337,8 +340,7 @@ def nightly(test_case): def is_torch_compile(test_case): """ Decorator marking a test as a torch.compile test. These tests can be filtered using: - pytest -m "not compile" to skip - pytest -m compile to run only these tests + pytest -m "not compile" to skip pytest -m compile to run only these tests """ return pytest.mark.compile(test_case) @@ -346,8 +348,7 @@ def is_torch_compile(test_case): def is_single_file(test_case): """ Decorator marking a test as a single file loading test. These tests can be filtered using: - pytest -m "not single_file" to skip - pytest -m single_file to run only these tests + pytest -m "not single_file" to skip pytest -m single_file to run only these tests """ return pytest.mark.single_file(test_case) @@ -355,8 +356,7 @@ def is_single_file(test_case): def is_lora(test_case): """ Decorator marking a test as a LoRA test. These tests can be filtered using: - pytest -m "not lora" to skip - pytest -m lora to run only these tests + pytest -m "not lora" to skip pytest -m lora to run only these tests """ return pytest.mark.lora(test_case) @@ -364,8 +364,7 @@ def is_lora(test_case): def is_ip_adapter(test_case): """ Decorator marking a test as an IP Adapter test. These tests can be filtered using: - pytest -m "not ip_adapter" to skip - pytest -m ip_adapter to run only these tests + pytest -m "not ip_adapter" to skip pytest -m ip_adapter to run only these tests """ return pytest.mark.ip_adapter(test_case) @@ -373,8 +372,7 @@ def is_ip_adapter(test_case): def is_training(test_case): """ Decorator marking a test as a training test. These tests can be filtered using: - pytest -m "not training" to skip - pytest -m training to run only these tests + pytest -m "not training" to skip pytest -m training to run only these tests """ return pytest.mark.training(test_case) @@ -382,8 +380,7 @@ def is_training(test_case): def is_attention(test_case): """ Decorator marking a test as an attention test. These tests can be filtered using: - pytest -m "not attention" to skip - pytest -m attention to run only these tests + pytest -m "not attention" to skip pytest -m attention to run only these tests """ return pytest.mark.attention(test_case) @@ -391,8 +388,7 @@ def is_attention(test_case): def is_memory(test_case): """ Decorator marking a test as a memory optimization test. These tests can be filtered using: - pytest -m "not memory" to skip - pytest -m memory to run only these tests + pytest -m "not memory" to skip pytest -m memory to run only these tests """ return pytest.mark.memory(test_case) @@ -400,8 +396,7 @@ def is_memory(test_case): def is_cpu_offload(test_case): """ Decorator marking a test as a CPU offload test. These tests can be filtered using: - pytest -m "not cpu_offload" to skip - pytest -m cpu_offload to run only these tests + pytest -m "not cpu_offload" to skip pytest -m cpu_offload to run only these tests """ return pytest.mark.cpu_offload(test_case) @@ -409,8 +404,7 @@ def is_cpu_offload(test_case): def is_group_offload(test_case): """ Decorator marking a test as a group offload test. These tests can be filtered using: - pytest -m "not group_offload" to skip - pytest -m group_offload to run only these tests + pytest -m "not group_offload" to skip pytest -m group_offload to run only these tests """ return pytest.mark.group_offload(test_case) @@ -418,8 +412,7 @@ def is_group_offload(test_case): def is_quantization(test_case): """ Decorator marking a test as a quantization test. These tests can be filtered using: - pytest -m "not quantization" to skip - pytest -m quantization to run only these tests + pytest -m "not quantization" to skip pytest -m quantization to run only these tests """ return pytest.mark.quantization(test_case) @@ -427,8 +420,7 @@ def is_quantization(test_case): def is_bitsandbytes(test_case): """ Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using: - pytest -m "not bitsandbytes" to skip - pytest -m bitsandbytes to run only these tests + pytest -m "not bitsandbytes" to skip pytest -m bitsandbytes to run only these tests """ return pytest.mark.bitsandbytes(test_case) @@ -436,8 +428,7 @@ def is_bitsandbytes(test_case): def is_quanto(test_case): """ Decorator marking a test as a Quanto quantization test. These tests can be filtered using: - pytest -m "not quanto" to skip - pytest -m quanto to run only these tests + pytest -m "not quanto" to skip pytest -m quanto to run only these tests """ return pytest.mark.quanto(test_case) @@ -445,8 +436,7 @@ def is_quanto(test_case): def is_torchao(test_case): """ Decorator marking a test as a TorchAO quantization test. These tests can be filtered using: - pytest -m "not torchao" to skip - pytest -m torchao to run only these tests + pytest -m "not torchao" to skip pytest -m torchao to run only these tests """ return pytest.mark.torchao(test_case) @@ -454,8 +444,7 @@ def is_torchao(test_case): def is_gguf(test_case): """ Decorator marking a test as a GGUF quantization test. These tests can be filtered using: - pytest -m "not gguf" to skip - pytest -m gguf to run only these tests + pytest -m "not gguf" to skip pytest -m gguf to run only these tests """ return pytest.mark.gguf(test_case) @@ -463,8 +452,7 @@ def is_gguf(test_case): def is_modelopt(test_case): """ Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using: - pytest -m "not modelopt" to skip - pytest -m modelopt to run only these tests + pytest -m "not modelopt" to skip pytest -m modelopt to run only these tests """ return pytest.mark.modelopt(test_case) @@ -472,8 +460,7 @@ def is_modelopt(test_case): def is_context_parallel(test_case): """ Decorator marking a test as a context parallel inference test. These tests can be filtered using: - pytest -m "not context_parallel" to skip - pytest -m context_parallel to run only these tests + pytest -m "not context_parallel" to skip pytest -m context_parallel to run only these tests """ return pytest.mark.context_parallel(test_case) @@ -481,8 +468,7 @@ def is_context_parallel(test_case): def is_cache(test_case): """ Decorator marking a test as a cache test. These tests can be filtered using: - pytest -m "not cache" to skip - pytest -m cache to run only these tests + pytest -m "not cache" to skip pytest -m cache to run only these tests """ return pytest.mark.cache(test_case) @@ -552,6 +538,14 @@ def require_torch_accelerator(test_case): return pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")(test_case) +def require_torch_neuron(test_case): + """Decorator marking a test that requires a Neuron device (Trainium/Inferentia).""" + return pytest.mark.skipif( + not (is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available()), + reason="test requires Neuron device", + )(test_case) + + def require_torch_multi_gpu(test_case): """ Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without @@ -1327,7 +1321,7 @@ class CaptureLogger: Example: ```python >>> from diffusers import logging - >>> from diffusers..testing_utils import CaptureLogger + >>> from diffusers.utils.testing_utils import CaptureLogger >>> msg = "Testing 1, 2, 3" >>> logging.set_verbosity_info() @@ -1435,6 +1429,15 @@ def _is_torch_fp64_available(device): # Behaviour flags BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True} + # Neuron device key: torch.neuron.current_device() returns an int (e.g. 0). + # We capture it once at import time if torch_neuronx is available so we can add it + # to all dispatch tables using the same key that torch_device is set to. + _neuron_device = ( + torch.neuron.current_device() + if (is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available()) + else None + ) + # Function definitions BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, @@ -1486,13 +1489,20 @@ def _is_torch_fp64_available(device): "default": None, } + if _neuron_device is not None: + BACKEND_EMPTY_CACHE[_neuron_device] = None + BACKEND_DEVICE_COUNT[_neuron_device] = torch.neuron.device_count + BACKEND_MANUAL_SEED[_neuron_device] = torch.manual_seed + BACKEND_RESET_PEAK_MEMORY_STATS[_neuron_device] = None + BACKEND_RESET_MAX_MEMORY_ALLOCATED[_neuron_device] = None + BACKEND_MAX_MEMORY_ALLOCATED[_neuron_device] = 0 + BACKEND_SYNCHRONIZE[_neuron_device] = torch.neuron.synchronize + BACKEND_SUPPORTS_TRAINING[_neuron_device] = False + # This dispatches a defined function according to the accelerator from the function definitions. def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs): - if device not in dispatch_table: - return dispatch_table["default"](*args, **kwargs) - - fn = dispatch_table[device] + fn = dispatch_table[device] if device in dispatch_table else dispatch_table["default"] # Some device agnostic functions return values. Need to guard against 'None' instead at # user level