From 98f6c8c257c4f2e00b209c188fc993feb93e3fc7 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 18 Mar 2026 11:15:03 +0000 Subject: [PATCH 01/24] draft:add neuron as a legit backend --- src/diffusers/pipelines/pipeline_utils.py | 58 ++++++++++++++++++++++- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 5 ++ src/diffusers/utils/torch_utils.py | 25 ++++++++-- 4 files changed, 85 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d675f1de04a7..44fe8367636d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import types from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints import httpx import numpy as np @@ -68,6 +68,7 @@ is_transformers_version, logging, numpy_to_pil, + requires_backends, ) from ..utils.distributed_utils import is_torch_dist_rank_zero from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card @@ -2248,6 +2249,61 @@ def _is_pipeline_device_mapped(self): return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1 + def enable_neuron_compile( + self, + model_names: Optional[List[str]] = None, + cache_dir: Optional[str] = None, + fullgraph: bool = True, + ) -> None: + """ + Compiles the pipeline's nn.Module components with ``torch.compile(backend="neuron")``, + enabling whole-graph NEFF compilation for AWS Trainium/Inferentia. + + The first forward call per component triggers neuronx-cc compilation (slow). + Use ``neuron_warmup()`` to trigger this explicitly before timed inference. + + Args: + model_names (`List[str]`, *optional*): + Component names to compile. Defaults to all nn.Module components. + cache_dir (`str`, *optional*): + Path to persist compiled NEFFs across runs via ``TORCH_NEURONX_NEFF_CACHE_DIR``. + Skips recompilation on subsequent runs. + fullgraph (`bool`, defaults to `True`): + Disallow graph breaks (required for full-graph fusion). + """ + requires_backends(self, "torch_neuronx") + import torch_neuronx # noqa: F401 — registers neuron backend + + if cache_dir is not None: + os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir + + if model_names is None: + model_names = [ + name for name, comp in self.components.items() if isinstance(comp, torch.nn.Module) + ] + + for name in model_names: + component = getattr(self, name, None) + if isinstance(component, torch.nn.Module) and not is_compiled_module(component): + logger.info(f"Compiling {name} with backend='neuron'") + setattr(self, name, torch.compile(component, backend="neuron", fullgraph=fullgraph)) + + def neuron_warmup(self, *args, **kwargs) -> None: + """ + Runs a single dummy forward pass through the pipeline to trigger neuronx-cc + compilation for all components (static-shape NEFF compilation). + + This is equivalent to calling ``__call__`` with the same shapes but discards + the output. After warmup, subsequent calls reuse the compiled NEFFs and run fast. + + Pass the same arguments you would use for real inference (height, width, + num_inference_steps, batch_size, etc.) so that the compiled shapes match. + """ + logger.info("Running Neuron warmup forward pass to trigger NEFF compilation...") + with torch.no_grad(): + self(*args, **kwargs) + logger.info("Neuron warmup complete.") + class StableDiffusionMixin: r""" diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 23d7ac7c6c2d..8a86cf4f4151 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -110,6 +110,7 @@ is_timm_available, is_torch_available, is_torch_mlu_available, + is_torch_neuronx_available, is_torch_npu_available, is_torch_version, is_torch_xla_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 551fa358a28d..e23fccc1a374 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -193,6 +193,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") _torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu") +_torch_neuronx_available, _torch_neuronx_version = _is_package_available("torch_neuronx") _transformers_available, _transformers_version = _is_package_available("transformers") _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") _kernels_available, _kernels_version = _is_package_available("kernels") @@ -249,6 +250,10 @@ def is_torch_mlu_available(): return _torch_mlu_available +def is_torch_neuronx_available(): + return _torch_neuronx_available + + def is_flax_available(): return _flax_available diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 7f4cb3e12766..88b53e2b5b16 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -21,19 +21,26 @@ import os from . import logging -from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version +from .import_utils import ( + is_torch_available, + is_torch_mlu_available, + is_torch_neuronx_available, + is_torch_npu_available, + is_torch_version, +) if is_torch_available(): import torch from torch.fft import fftn, fftshift, ifftn, ifftshift - BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True} + BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "neuron": False, "default": True} BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, "xpu": torch.xpu.empty_cache, "cpu": None, "mps": torch.mps.empty_cache, + "neuron": None, "default": None, } BACKEND_DEVICE_COUNT = { @@ -41,6 +48,7 @@ "xpu": torch.xpu.device_count, "cpu": lambda: 0, "mps": lambda: 0, + "neuron": lambda: getattr(getattr(torch, "neuron", None), "device_count", lambda: 0)(), "default": 0, } BACKEND_MANUAL_SEED = { @@ -48,6 +56,7 @@ "xpu": torch.xpu.manual_seed, "cpu": torch.manual_seed, "mps": torch.mps.manual_seed, + "neuron": torch.manual_seed, "default": torch.manual_seed, } BACKEND_RESET_PEAK_MEMORY_STATS = { @@ -55,6 +64,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } BACKEND_RESET_MAX_MEMORY_ALLOCATED = { @@ -62,6 +72,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } BACKEND_MAX_MEMORY_ALLOCATED = { @@ -69,6 +80,7 @@ "xpu": getattr(torch.xpu, "max_memory_allocated", None), "cpu": 0, "mps": 0, + "neuron": 0, "default": 0, } BACKEND_SYNCHRONIZE = { @@ -76,6 +88,7 @@ "xpu": getattr(torch.xpu, "synchronize", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -164,11 +177,15 @@ def randn_tensor( layout = layout or torch.strided device = device or torch.device("cpu") + # Neuron (XLA) does not support creating random tensors directly on device; always use CPU + if device.type == "neuron": + rand_device = torch.device("cpu") + if generator is not None: gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type if gen_device_type != device.type and gen_device_type == "cpu": rand_device = "cpu" - if device != "mps": + if device.type not in ("mps", "neuron"): logger.info( f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" @@ -289,6 +306,8 @@ def get_device(): return "mps" elif is_torch_mlu_available(): return "mlu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + return "neuron" else: return "cpu" From a76953cf34aecd0efda8e364798102b3c71a0db2 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Thu, 26 Mar 2026 11:57:09 +0000 Subject: [PATCH 02/24] feat: neuron-specific changes in the pipeline --- .../models/unets/unet_2d_condition.py | 5 ++-- src/diffusers/pipelines/pipeline_utils.py | 2 ++ .../pipeline_stable_diffusion_xl.py | 26 ++++++++++++++++--- src/diffusers/utils/import_utils.py | 5 ++++ src/diffusers/utils/torch_utils.py | 2 +- 5 files changed, 33 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index deae25899475..b533bef35414 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -855,10 +855,11 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" is_npu = sample.device.type == "npu" + is_neuron = sample.device.type == "neuron" if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 44fe8367636d..5b329f46e2aa 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -2273,6 +2273,7 @@ def enable_neuron_compile( """ requires_backends(self, "torch_neuronx") import torch_neuronx # noqa: F401 — registers neuron backend + from torch_neuronx.neuron_dynamo_backend import set_model_name if cache_dir is not None: os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir @@ -2286,6 +2287,7 @@ def enable_neuron_compile( component = getattr(self, name, None) if isinstance(component, torch.nn.Module) and not is_compiled_module(component): logger.info(f"Compiling {name} with backend='neuron'") + set_model_name(name) setattr(self, name, torch.compile(component, backend="neuron", fullgraph=fullgraph)) def neuron_warmup(self, *args, **kwargs) -> None: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 2f6b105702e8..fdda2547f09e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1092,7 +1092,11 @@ def __call__( ) # 4. Prepare timesteps - if XLA_AVAILABLE: + # Keep timesteps on CPU for XLA (TPU) and Neuron: both use lazy/XLA execution where + # dynamic-shape ops like .nonzero() and .item() inside scheduler.index_for_timestep() + # are incompatible with static-graph compilation. + is_neuron_device = hasattr(device, "type") and device.type == "neuron" + if XLA_AVAILABLE or is_neuron_device: timestep_device = "cpu" else: timestep_device = device @@ -1195,15 +1199,23 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # For Neuron: scale_model_input on CPU to avoid XLA ops outside the compiled UNet region. + # index_for_timestep() uses .nonzero()/.item() which are incompatible with static graphs. + if is_neuron_device: + latent_model_input = self.scheduler.scale_model_input(latent_model_input.to("cpu"), t).to(device) + else: + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds + # For Neuron: pre-cast timestep to float32 on device. Neuron XLA does not support + # int64 ops; the compiled UNet graph requires a float32 timestep input on-device. + t_unet = t.to(torch.float32).to(device) if is_neuron_device else t noise_pred = self.unet( latent_model_input, - t, + t_unet, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, @@ -1222,7 +1234,13 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # For Neuron: scheduler.step on CPU to keep scheduler arithmetic off the XLA device. + if is_neuron_device: + latents = self.scheduler.step( + noise_pred.to("cpu"), t, latents.to("cpu"), **extra_step_kwargs, return_dict=False + )[0].to(device) + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index e23fccc1a374..2ce989626b3d 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -584,6 +584,10 @@ def is_av_available(): """ +TORCH_NEURONX_IMPORT_ERROR = """ +{0} requires the torch_neuronx library (AWS Neuron SDK) but it was not found in your environment. Please install it following the AWS Neuron documentation: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/ +""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -614,6 +618,7 @@ def is_av_available(): ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)), ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)), ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)), + ("torch_neuronx", (is_torch_neuronx_available, TORCH_NEURONX_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 7d16c8556689..55fee1d3249e 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -93,7 +93,7 @@ "xpu": getattr(torch.xpu, "synchronize", None), "cpu": None, "mps": None, - "neuron": None, + "neuron": lambda: getattr(getattr(torch, "neuron", None), "synchronize", lambda: None)(), "default": None, } logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 2480388fb12a423527d491ab5211c058b07b3262 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Fri, 27 Mar 2026 17:55:26 +0000 Subject: [PATCH 03/24] tests: eager tests --- src/diffusers/pipelines/pipeline_utils.py | 58 --------------------- src/diffusers/utils/testing_utils.py | 3 ++ tests/pipelines/pixart_alpha/test_pixart.py | 10 +++- tests/testing_utils.py | 26 +++++++-- 4 files changed, 33 insertions(+), 64 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 5b329f46e2aa..bbee2189c22f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -2249,64 +2249,6 @@ def _is_pipeline_device_mapped(self): return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1 - def enable_neuron_compile( - self, - model_names: Optional[List[str]] = None, - cache_dir: Optional[str] = None, - fullgraph: bool = True, - ) -> None: - """ - Compiles the pipeline's nn.Module components with ``torch.compile(backend="neuron")``, - enabling whole-graph NEFF compilation for AWS Trainium/Inferentia. - - The first forward call per component triggers neuronx-cc compilation (slow). - Use ``neuron_warmup()`` to trigger this explicitly before timed inference. - - Args: - model_names (`List[str]`, *optional*): - Component names to compile. Defaults to all nn.Module components. - cache_dir (`str`, *optional*): - Path to persist compiled NEFFs across runs via ``TORCH_NEURONX_NEFF_CACHE_DIR``. - Skips recompilation on subsequent runs. - fullgraph (`bool`, defaults to `True`): - Disallow graph breaks (required for full-graph fusion). - """ - requires_backends(self, "torch_neuronx") - import torch_neuronx # noqa: F401 — registers neuron backend - from torch_neuronx.neuron_dynamo_backend import set_model_name - - if cache_dir is not None: - os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir - - if model_names is None: - model_names = [ - name for name, comp in self.components.items() if isinstance(comp, torch.nn.Module) - ] - - for name in model_names: - component = getattr(self, name, None) - if isinstance(component, torch.nn.Module) and not is_compiled_module(component): - logger.info(f"Compiling {name} with backend='neuron'") - set_model_name(name) - setattr(self, name, torch.compile(component, backend="neuron", fullgraph=fullgraph)) - - def neuron_warmup(self, *args, **kwargs) -> None: - """ - Runs a single dummy forward pass through the pipeline to trigger neuronx-cc - compilation for all components (static-shape NEFF compilation). - - This is equivalent to calling ``__call__`` with the same shapes but discards - the output. After warmup, subsequent calls reuse the compiled NEFFs and run fast. - - Pass the same arguments you would use for real inference (height, width, - num_inference_steps, batch_size, etc.) so that the compiled shapes match. - """ - logger.info("Running Neuron warmup forward pass to trigger NEFF compilation...") - with torch.no_grad(): - self(*args, **kwargs) - logger.info("Neuron warmup complete.") - - class StableDiffusionMixin: r""" Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 619a37034949..eefe52c477a6 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -46,6 +46,7 @@ is_peft_available, is_timm_available, is_torch_available, + is_torch_neuronx_available, is_torch_version, is_torchao_available, is_torchsde_available, @@ -113,6 +114,8 @@ torch_device = "cuda" elif torch.xpu.is_available(): torch_device = "xpu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + torch_device = torch.neuron.current_device() else: torch_device = "cpu" is_torch_higher_equal_than_1_12 = version.parse( diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index 037a9f44f31e..0aa6812c6b25 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -28,6 +28,8 @@ PixArtTransformer2DModel, ) +from diffusers.utils.import_utils import is_torch_neuronx_available + from ...testing_utils import ( backend_empty_cache, enable_full_determinism, @@ -291,7 +293,9 @@ def test_pixart_1024(self): expected_slice = np.array([0.0742, 0.0835, 0.2114, 0.0295, 0.0784, 0.2361, 0.1738, 0.2251, 0.3589]) max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice) - self.assertLessEqual(max_diff, 1e-4) + # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA + atol = 1e-2 if is_torch_neuronx_available() else 1e-4 + self.assertLessEqual(max_diff, atol) def test_pixart_512(self): generator = torch.Generator("cpu").manual_seed(0) @@ -307,7 +311,9 @@ def test_pixart_512(self): expected_slice = np.array([0.3477, 0.3882, 0.4541, 0.3413, 0.3821, 0.4463, 0.4001, 0.4409, 0.4958]) max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice) - self.assertLessEqual(max_diff, 1e-4) + # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA + atol = 1e-2 if is_torch_neuronx_available() else 1e-4 + self.assertLessEqual(max_diff, atol) def test_pixart_1024_without_resolution_binning(self): generator = torch.manual_seed(0) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 53c1b8aa26ce..778381cf31e0 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -45,6 +45,7 @@ is_peft_available, is_timm_available, is_torch_available, + is_torch_neuronx_available, is_torch_version, is_torchao_available, is_torchsde_available, @@ -109,6 +110,8 @@ torch_device = "cuda" elif torch.xpu.is_available(): torch_device = "xpu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + torch_device = torch.neuron.current_device() else: torch_device = "cpu" is_torch_higher_equal_than_1_12 = version.parse( @@ -1427,6 +1430,15 @@ def _is_torch_fp64_available(device): # Behaviour flags BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True} + # Neuron device key: torch.neuron.current_device() returns an int (e.g. 0). + # We capture it once at import time if torch_neuronx is available so we can add it + # to all dispatch tables using the same key that torch_device is set to. + _neuron_device = ( + torch.neuron.current_device() + if (is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available()) + else None + ) + # Function definitions BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, @@ -1478,13 +1490,19 @@ def _is_torch_fp64_available(device): "default": None, } + if _neuron_device is not None: + BACKEND_EMPTY_CACHE[_neuron_device] = None + BACKEND_DEVICE_COUNT[_neuron_device] = torch.neuron.device_count + BACKEND_MANUAL_SEED[_neuron_device] = torch.manual_seed + BACKEND_RESET_PEAK_MEMORY_STATS[_neuron_device] = None + BACKEND_RESET_MAX_MEMORY_ALLOCATED[_neuron_device] = None + BACKEND_MAX_MEMORY_ALLOCATED[_neuron_device] = 0 + BACKEND_SYNCHRONIZE[_neuron_device] = torch.neuron.synchronize + # This dispatches a defined function according to the accelerator from the function definitions. def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs): - if device not in dispatch_table: - return dispatch_table["default"](*args, **kwargs) - - fn = dispatch_table[device] + fn = dispatch_table[device] if device in dispatch_table else dispatch_table["default"] # Some device agnostic functions return values. Need to guard against 'None' instead at # user level From 929ab7288fff62db2e5c32ec7e857eb3e9280f70 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Thu, 9 Apr 2026 16:37:56 +0000 Subject: [PATCH 04/24] fix: style --- .../train_instruct_pix2pix_sdxl.py | 8 ++- src/diffusers/loaders/peft.py | 2 +- src/diffusers/models/_modeling_parallel.py | 63 ++++++++++++++++++- src/diffusers/pipelines/pipeline_utils.py | 4 +- src/diffusers/utils/torch_utils.py | 9 ++- tests/pipelines/pixart_alpha/test_pixart.py | 1 - 6 files changed, 78 insertions(+), 9 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 5df0e22fe1cc..ce146c895686 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -85,9 +85,11 @@ def log_validation(pipeline, args, accelerator, generator, global_step, is_final os.makedirs(val_save_dir) original_image = ( - lambda image_url_or_path: load_image(image_url_or_path) - if urlparse(image_url_or_path).scheme - else Image.open(image_url_or_path).convert("RGB") + lambda image_url_or_path: ( + load_image(image_url_or_path) + if urlparse(image_url_or_path).scheme + else Image.open(image_url_or_path).convert("RGB") + ) )(args.val_image_url_or_path) if torch.backends.mps.is_available(): diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index daa078bc25d5..68d9104e028d 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -46,7 +46,7 @@ logger = logging.get_logger(__name__) _SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( - lambda: (lambda model_cls, weights: weights), + lambda: lambda model_cls, weights: weights, { "UNet2DConditionModel": _maybe_expand_lora_scales, "UNetMotionModel": _maybe_expand_lora_scales, diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 8573c01ca4c7..e673980dbf44 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -35,7 +35,6 @@ # - Unified Attention # - More dispatcher attention backends # - CFG/Data Parallel -# - Tensor Parallel @dataclass @@ -142,6 +141,63 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() +@dataclass +class TensorParallelConfig: + """ + Configuration for tensor parallelism. + + Tensor parallelism shards weight matrices (column-wise and row-wise) across devices. + Each device computes a partial result; an AllReduce/AllGather at layer boundaries + reconstructs the full output. Uses ``torch.distributed.tensor.parallelize_module`` + with ``ColwiseParallel`` / ``RowwiseParallel`` sharding styles. + + On Neuron, use the ``_pre_shard_and_tp`` workaround from + ``transformer_flux2_neuron_tp`` to avoid the NRT consecutive-reduce-scatter bug + on large tensors (>= 5120x5120). + + Args: + tp_degree (`int`, defaults to `1`): + Number of devices to shard across. Must be a divisor of the number of + attention heads (and FFN hidden dimensions) of the model being parallelised. + mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): + A custom device mesh to use. If provided, ``tp_degree`` is inferred from + ``mesh.size()`` and the argument is ignored. Useful when combining TP with + other parallelism strategies (e.g. CP) that share the same mesh. + """ + + tp_degree: int = 1 + mesh: torch.distributed.device_mesh.DeviceMesh | None = None + + _rank: int = None + _world_size: int = None + _device: torch.device = None + _mesh: torch.distributed.device_mesh.DeviceMesh = None + + def __post_init__(self): + if self.tp_degree < 1: + raise ValueError("`tp_degree` must be >= 1.") + + def setup( + self, + rank: int, + world_size: int, + device: torch.device, + mesh: torch.distributed.device_mesh.DeviceMesh | None = None, + ): + self._rank = rank + self._world_size = world_size + self._device = device + if mesh is not None: + self._mesh = mesh + elif self.mesh is not None: + self._mesh = self.mesh + else: + from torch.distributed.device_mesh import init_device_mesh + + device_type = str(device).split(":")[0] + self._mesh = init_device_mesh(device_type, (self.tp_degree,), mesh_dim_names=("tp",)) + + @dataclass class ParallelConfig: """ @@ -150,9 +206,12 @@ class ParallelConfig: Args: context_parallel_config (`ContextParallelConfig`, *optional*): Configuration for context parallelism. + tensor_parallel_config (`TensorParallelConfig`, *optional*): + Configuration for tensor parallelism. """ context_parallel_config: ContextParallelConfig | None = None + tensor_parallel_config: TensorParallelConfig | None = None _rank: int = None _world_size: int = None @@ -173,6 +232,8 @@ def setup( self._mesh = mesh if self.context_parallel_config is not None: self.context_parallel_config.setup(rank, world_size, device, mesh) + if self.tensor_parallel_config is not None: + self.tensor_parallel_config.setup(rank, world_size, device, mesh) @dataclass(frozen=True) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index bbee2189c22f..d675f1de04a7 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import types from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints +from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints import httpx import numpy as np @@ -68,7 +68,6 @@ is_transformers_version, logging, numpy_to_pil, - requires_backends, ) from ..utils.distributed_utils import is_torch_dist_rank_zero from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card @@ -2249,6 +2248,7 @@ def _is_pipeline_device_mapped(self): return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1 + class StableDiffusionMixin: r""" Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 55fee1d3249e..e99719625df6 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -39,7 +39,14 @@ import torch from torch.fft import fftn, fftshift, ifftn, ifftshift - BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "neuron": False, "default": True} + BACKEND_SUPPORTS_TRAINING = { + "cuda": True, + "xpu": True, + "cpu": True, + "mps": False, + "neuron": False, + "default": True, + } BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, "xpu": torch.xpu.empty_cache, diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index 0aa6812c6b25..86fe673a8c7d 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -27,7 +27,6 @@ PixArtAlphaPipeline, PixArtTransformer2DModel, ) - from diffusers.utils.import_utils import is_torch_neuronx_available from ...testing_utils import ( From 3bb9c7c3fc8483228d93c6bf9e16b2905712d17b Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Fri, 10 Apr 2026 15:35:35 +0000 Subject: [PATCH 05/24] fix:apr_02 beta --- src/diffusers/models/transformers/transformer_flux2.py | 3 ++- .../pipelines/pixart_alpha/pipeline_pixart_alpha.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 5c90f3a46a98..43d36d6476af 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -961,7 +961,8 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: pos = ids.float() is_mps = ids.device.type == "mps" is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + is_neuron = ids.device.type == "neuron" + freqs_dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] for i in range(len(self.axes_dim)): cos, sin = get_1d_rotary_pos_embed( diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 604e51d88583..bda4e40f3768 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -29,6 +29,7 @@ deprecate, is_bs4_available, is_ftfy_available, + is_torch_neuronx_available, is_torch_xla_available, logging, replace_example_docstring, @@ -862,7 +863,7 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps - if XLA_AVAILABLE: + if XLA_AVAILABLE or is_torch_neuronx_available(): timestep_device = "cpu" else: timestep_device = device @@ -914,10 +915,11 @@ def __call__( # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" is_npu = latent_model_input.device.type == "npu" + is_neuron = latent_model_input.device.type == "neuron" if isinstance(current_timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) From 291171b3bb38afc763f13a25689f9526063ba770 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Mon, 11 May 2026 14:36:35 +0000 Subject: [PATCH 06/24] cleanup: remove tp part, for another pr --- src/diffusers/models/_modeling_parallel.py | 75 +--------------------- 1 file changed, 1 insertion(+), 74 deletions(-) diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 63bec62ad5b8..8573c01ca4c7 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -35,6 +35,7 @@ # - Unified Attention # - More dispatcher attention backends # - CFG/Data Parallel +# - Tensor Parallel @dataclass @@ -63,9 +64,6 @@ class ContextParallelConfig: Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and `ring_degree` must be 1. - ring_anything (`bool`, *optional*, defaults to `False`): - Whether to enable "Ring Anything" mode, which supports arbitrary sequence lengths. When enabled, - `ring_degree` must be greater than 1 and `ulysses_degree` must be 1. mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of creating a new one. This is useful when combining context parallelism with other parallelism strategies @@ -84,8 +82,6 @@ class ContextParallelConfig: # Whether to enable ulysses anything attention to support # any sequence lengths and any head numbers. ulysses_anything: bool = False - # Whether to enable ring anything attention to support any sequence lengths. - ring_anything: bool = False _rank: int = None _world_size: int = None @@ -118,13 +114,6 @@ def __post_init__(self): raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.") if self.ring_degree > 1: raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.") - if self.ring_anything: - if self.ring_degree == 1: - raise ValueError("ring_degree must be greater than 1 for ring_anything to be enabled.") - if self.ulysses_degree > 1: - raise ValueError("ring_anything cannot be enabled when ulysses_degree > 1.") - if self.ulysses_anything and self.ring_anything: - raise ValueError("ulysses_anything and ring_anything cannot both be enabled.") @property def mesh_shape(self) -> tuple[int, int]: @@ -153,63 +142,6 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() -@dataclass -class TensorParallelConfig: - """ - Configuration for tensor parallelism. - - Tensor parallelism shards weight matrices (column-wise and row-wise) across devices. - Each device computes a partial result; an AllReduce/AllGather at layer boundaries - reconstructs the full output. Uses ``torch.distributed.tensor.parallelize_module`` - with ``ColwiseParallel`` / ``RowwiseParallel`` sharding styles. - - On Neuron, use the ``_pre_shard_and_tp`` workaround from - ``transformer_flux2_neuron_tp`` to avoid the NRT consecutive-reduce-scatter bug - on large tensors (>= 5120x5120). - - Args: - tp_degree (`int`, defaults to `1`): - Number of devices to shard across. Must be a divisor of the number of - attention heads (and FFN hidden dimensions) of the model being parallelised. - mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): - A custom device mesh to use. If provided, ``tp_degree`` is inferred from - ``mesh.size()`` and the argument is ignored. Useful when combining TP with - other parallelism strategies (e.g. CP) that share the same mesh. - """ - - tp_degree: int = 1 - mesh: torch.distributed.device_mesh.DeviceMesh | None = None - - _rank: int = None - _world_size: int = None - _device: torch.device = None - _mesh: torch.distributed.device_mesh.DeviceMesh = None - - def __post_init__(self): - if self.tp_degree < 1: - raise ValueError("`tp_degree` must be >= 1.") - - def setup( - self, - rank: int, - world_size: int, - device: torch.device, - mesh: torch.distributed.device_mesh.DeviceMesh | None = None, - ): - self._rank = rank - self._world_size = world_size - self._device = device - if mesh is not None: - self._mesh = mesh - elif self.mesh is not None: - self._mesh = self.mesh - else: - from torch.distributed.device_mesh import init_device_mesh - - device_type = str(device).split(":")[0] - self._mesh = init_device_mesh(device_type, (self.tp_degree,), mesh_dim_names=("tp",)) - - @dataclass class ParallelConfig: """ @@ -218,12 +150,9 @@ class ParallelConfig: Args: context_parallel_config (`ContextParallelConfig`, *optional*): Configuration for context parallelism. - tensor_parallel_config (`TensorParallelConfig`, *optional*): - Configuration for tensor parallelism. """ context_parallel_config: ContextParallelConfig | None = None - tensor_parallel_config: TensorParallelConfig | None = None _rank: int = None _world_size: int = None @@ -244,8 +173,6 @@ def setup( self._mesh = mesh if self.context_parallel_config is not None: self.context_parallel_config.setup(rank, world_size, device, mesh) - if self.tensor_parallel_config is not None: - self.tensor_parallel_config.setup(rank, world_size, device, mesh) @dataclass(frozen=True) From f1caec0b36aa5916910bd038adc9fd0a7ebaec8b Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Mon, 11 May 2026 16:16:09 +0000 Subject: [PATCH 07/24] fix: restore ring_anything to ContextParallelConfig after over-aggressive TP cleanup The previous cleanup commit removed TensorParallelConfig (correct) but also accidentally removed ring_anything from ContextParallelConfig (incorrect). ring_anything is a context-parallel feature referenced in context_parallel.py and must remain in the config. Co-Authored-By: Claude Sonnet 4.6 --- src/diffusers/models/_modeling_parallel.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 8573c01ca4c7..56e1eced9eef 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -64,6 +64,9 @@ class ContextParallelConfig: Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and `ring_degree` must be 1. + ring_anything (`bool`, *optional*, defaults to `False`): + Whether to enable "Ring Anything" mode, which supports arbitrary sequence lengths. When enabled, + `ring_degree` must be greater than 1 and `ulysses_degree` must be 1. mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of creating a new one. This is useful when combining context parallelism with other parallelism strategies @@ -82,6 +85,8 @@ class ContextParallelConfig: # Whether to enable ulysses anything attention to support # any sequence lengths and any head numbers. ulysses_anything: bool = False + # Whether to enable ring anything attention to support any sequence lengths. + ring_anything: bool = False _rank: int = None _world_size: int = None @@ -114,6 +119,13 @@ def __post_init__(self): raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.") if self.ring_degree > 1: raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.") + if self.ring_anything: + if self.ring_degree == 1: + raise ValueError("ring_degree must be greater than 1 for ring_anything to be enabled.") + if self.ulysses_degree > 1: + raise ValueError("ring_anything cannot be enabled when ulysses_degree > 1.") + if self.ulysses_anything and self.ring_anything: + raise ValueError("ulysses_anything and ring_anything cannot both be enabled.") @property def mesh_shape(self) -> tuple[int, int]: From 510a914ae2c0c2db978725ed2f9417eed60397c1 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Mon, 11 May 2026 16:46:07 +0000 Subject: [PATCH 08/24] removal: style fix --- examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py | 8 +++----- src/diffusers/loaders/peft.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 5ff552276dc3..4b74e3b61607 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -85,11 +85,9 @@ def log_validation(pipeline, args, accelerator, generator, global_step, is_final os.makedirs(val_save_dir) original_image = ( - lambda image_url_or_path: ( - load_image(image_url_or_path) - if urlparse(image_url_or_path).scheme - else Image.open(image_url_or_path).convert("RGB") - ) + lambda image_url_or_path: load_image(image_url_or_path) + if urlparse(image_url_or_path).scheme + else Image.open(image_url_or_path).convert("RGB") )(args.val_image_url_or_path) if torch.backends.mps.is_available(): diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 68d9104e028d..daa078bc25d5 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -46,7 +46,7 @@ logger = logging.get_logger(__name__) _SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( - lambda: lambda model_cls, weights: weights, + lambda: (lambda model_cls, weights: weights), { "UNet2DConditionModel": _maybe_expand_lora_scales, "UNetMotionModel": _maybe_expand_lora_scales, From c3121670b0c390a5a14a50e691fb140610477980 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Tue, 12 May 2026 16:33:33 +0000 Subject: [PATCH 09/24] tests:sdxl + flux2 --- .../flux2/test_pipeline_flux2_klein.py | 76 ++++++++++++++++++- .../test_stable_diffusion_xl.py | 53 +++++++++++++ 2 files changed, 128 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein.py b/tests/pipelines/flux2/test_pipeline_flux2_klein.py index 8ed9bf3d1e91..cda202fa74ab 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein.py @@ -1,3 +1,4 @@ +import gc import unittest import numpy as np @@ -11,8 +12,14 @@ Flux2KleinPipeline, Flux2Transformer2DModel, ) +from diffusers.utils.import_utils import is_torch_neuronx_available -from ...testing_utils import torch_device +from ...testing_utils import ( + backend_empty_cache, + require_torch_accelerator, + slow, + torch_device, +) from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist @@ -181,3 +188,70 @@ def test_image_input(self): @unittest.skip("Needs to be revisited") def test_encode_prompt_works_in_isolation(self): pass + + +@slow +@require_torch_accelerator +class Flux2KleinPipelineNeuronTests(unittest.TestCase): + ckpt_id = "black-forest-labs/FLUX.2-klein-4B" + prompt = "A small cactus with a happy face in the Sahara desert." + # Reuse the shared NEFF cache so per-op kernels (e.g. dtype casts) are not + # recompiled from scratch on each test run. TORCH_NEURONX_FALLBACK_ONLY_FOR_UNIMPLEMENTED_OPS=1 + # is set system-wide, so a fresh compilation failure raises an error instead of + # silently falling back to CPU. + neff_cache_dir = "/tmp/neff_cache" + + def setUp(self): + super().setUp() + import os + + os.makedirs(self.neff_cache_dir, exist_ok=True) + os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = self.neff_cache_dir + # The Qwen3 text-encoder attention ops accumulate in the XLA lazy graph and + # are compiled together when the transformer triggers its first graph flush + # (ids.float() in Flux2PosEmbed). The NKI SDPA kernel selected for those + # attention ops fails with NCC_INLA001 (NKI version mismatch). Disabling it + # falls back to the standard SDPA decomposition which compiles correctly. + os.environ.setdefault("TORCH_NEURONX_ENABLE_NKI_SDPA", "0") + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_flux2_klein_inference_512(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = Flux2KleinPipeline.from_pretrained(self.ckpt_id, torch_dtype=torch.bfloat16) + pipe.to(torch_device) + if is_torch_neuronx_available(): + # Flush pending lazy XLA parameter-copy ops so they don't pile up and + # trigger a batch compilation on the first inference call (NCC_IDRV017). + torch.neuron.synchronize() + pipe.set_progress_bar_config(disable=None) + + image = pipe( + prompt=self.prompt, + height=512, + width=512, + num_inference_steps=4, + guidance_scale=1.0, + generator=generator, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1] + self.assertEqual(image.shape, (1, 512, 512, 3)) + + # Verify outputs are valid pixel values + self.assertTrue(np.all((image >= 0.0) & (image <= 1.0)), "Pixel values must be in [0, 1]") + # Verify the image is non-trivial (not blank or saturated) + self.assertGreater(image_slice.std(), 0.01, "Output image should have meaningful variance") + + # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA. + # Use a wider tolerance when running on Neuron vs. a reference CUDA run. + atol = 1e-2 if is_torch_neuronx_available() else 1e-4 + _ = atol # atol is used when comparing against a reference slice; add it here once available. + diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index b318a505e9db..60629a3649cc 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -23,6 +23,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import ( + AutoPipelineForText2Image, AutoencoderKL, DDIMScheduler, DPMSolverMultistepScheduler, @@ -34,6 +35,7 @@ UNet2DConditionModel, UniPCMultistepScheduler, ) +from diffusers.utils.import_utils import is_torch_neuronx_available from ...testing_utils import ( backend_empty_cache, @@ -974,3 +976,54 @@ def test_stable_diffusion_lcm(self): max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten()) assert max_diff < 1e-2 + + +@slow +@require_torch_accelerator +class StableDiffusionXLTurboPipelineNeuronTests(unittest.TestCase): + ckpt_id = "stabilityai/sdxl-turbo" + prompt = "A small cactus with a happy face in the Sahara desert." + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_sdxl_turbo_512(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = AutoPipelineForText2Image.from_pretrained( + self.ckpt_id, torch_dtype=torch.float16, variant="fp16" + ) + pipe.to(torch_device) + if is_torch_neuronx_available(): + # Flush pending lazy XLA parameter-copy ops so they don't pile up and + # trigger a batch compilation on the first inference call (NCC_IDRV017). + torch.neuron.synchronize() + pipe.set_progress_bar_config(disable=None) + + image = pipe( + self.prompt, + num_inference_steps=1, + guidance_scale=0.0, + generator=generator, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1] + self.assertEqual(image.shape, (1, 512, 512, 3)) + + # Verify outputs are valid pixel values + self.assertTrue(np.all((image >= 0.0) & (image <= 1.0)), "Pixel values must be in [0, 1]") + # Verify the image is non-trivial (not blank or saturated) + self.assertGreater(image_slice.std(), 0.01, "Output image should have meaningful variance") + + # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA. + # Use a wider tolerance when running on Neuron vs. a reference CUDA run. + atol = 1e-2 if is_torch_neuronx_available() else 1e-4 + _ = atol # atol is used when comparing against a reference slice; add it here once available. From 1165f9fcc2db94ed536c53ea7f9b0770cf838567 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 13 May 2026 11:31:40 +0000 Subject: [PATCH 10/24] tests: simplify --- .../pipeline_stable_diffusion_xl.py | 5 +-- .../flux2/test_pipeline_flux2_klein.py | 32 +++++-------------- .../test_stable_diffusion_xl.py | 14 +++----- 3 files changed, 14 insertions(+), 37 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index fdda2547f09e..9e12e9459369 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1092,9 +1092,6 @@ def __call__( ) # 4. Prepare timesteps - # Keep timesteps on CPU for XLA (TPU) and Neuron: both use lazy/XLA execution where - # dynamic-shape ops like .nonzero() and .item() inside scheduler.index_for_timestep() - # are incompatible with static-graph compilation. is_neuron_device = hasattr(device, "type") and device.type == "neuron" if XLA_AVAILABLE or is_neuron_device: timestep_device = "cpu" @@ -1210,7 +1207,7 @@ def __call__( added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds - # For Neuron: pre-cast timestep to float32 on device. Neuron XLA does not support + # [Neuron] pre-cast timestep to float32 on device. Neuron XLA does not support # int64 ops; the compiled UNet graph requires a float32 timestep input on-device. t_unet = t.to(torch.float32).to(device) if is_neuron_device else t noise_pred = self.unet( diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein.py b/tests/pipelines/flux2/test_pipeline_flux2_klein.py index cda202fa74ab..264da46e1d27 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein.py @@ -1,4 +1,5 @@ import gc +import os import unittest import numpy as np @@ -192,27 +193,17 @@ def test_encode_prompt_works_in_isolation(self): @slow @require_torch_accelerator -class Flux2KleinPipelineNeuronTests(unittest.TestCase): +class Flux2KleinPipelineIntegrationTests(unittest.TestCase): ckpt_id = "black-forest-labs/FLUX.2-klein-4B" prompt = "A small cactus with a happy face in the Sahara desert." - # Reuse the shared NEFF cache so per-op kernels (e.g. dtype casts) are not - # recompiled from scratch on each test run. TORCH_NEURONX_FALLBACK_ONLY_FOR_UNIMPLEMENTED_OPS=1 - # is set system-wide, so a fresh compilation failure raises an error instead of - # silently falling back to CPU. - neff_cache_dir = "/tmp/neff_cache" def setUp(self): super().setUp() - import os - - os.makedirs(self.neff_cache_dir, exist_ok=True) - os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = self.neff_cache_dir - # The Qwen3 text-encoder attention ops accumulate in the XLA lazy graph and - # are compiled together when the transformer triggers its first graph flush - # (ids.float() in Flux2PosEmbed). The NKI SDPA kernel selected for those - # attention ops fails with NCC_INLA001 (NKI version mismatch). Disabling it - # falls back to the standard SDPA decomposition which compiles correctly. - os.environ.setdefault("TORCH_NEURONX_ENABLE_NKI_SDPA", "0") + if is_torch_neuronx_available(): + neff_cache_dir = "/tmp/neff_cache" + os.makedirs(neff_cache_dir, exist_ok=True) + os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = neff_cache_dir + os.environ.setdefault("TORCH_NEURONX_ENABLE_NKI_SDPA", "0") gc.collect() backend_empty_cache(torch_device) @@ -227,8 +218,6 @@ def test_flux2_klein_inference_512(self): pipe = Flux2KleinPipeline.from_pretrained(self.ckpt_id, torch_dtype=torch.bfloat16) pipe.to(torch_device) if is_torch_neuronx_available(): - # Flush pending lazy XLA parameter-copy ops so they don't pile up and - # trigger a batch compilation on the first inference call (NCC_IDRV017). torch.neuron.synchronize() pipe.set_progress_bar_config(disable=None) @@ -244,14 +233,9 @@ def test_flux2_klein_inference_512(self): image_slice = image[0, -3:, -3:, -1] self.assertEqual(image.shape, (1, 512, 512, 3)) - - # Verify outputs are valid pixel values self.assertTrue(np.all((image >= 0.0) & (image <= 1.0)), "Pixel values must be in [0, 1]") - # Verify the image is non-trivial (not blank or saturated) self.assertGreater(image_slice.std(), 0.01, "Output image should have meaningful variance") - # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA. - # Use a wider tolerance when running on Neuron vs. a reference CUDA run. atol = 1e-2 if is_torch_neuronx_available() else 1e-4 - _ = atol # atol is used when comparing against a reference slice; add it here once available. + _ = atol diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 60629a3649cc..5d875765a76c 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -15,6 +15,7 @@ import copy import gc +import os import tempfile import unittest @@ -980,12 +981,14 @@ def test_stable_diffusion_lcm(self): @slow @require_torch_accelerator -class StableDiffusionXLTurboPipelineNeuronTests(unittest.TestCase): +class StableDiffusionXLTurboPipelineIntegrationTests(unittest.TestCase): ckpt_id = "stabilityai/sdxl-turbo" prompt = "A small cactus with a happy face in the Sahara desert." def setUp(self): super().setUp() + if is_torch_neuronx_available(): + os.environ.setdefault("TORCH_NEURONX_ENABLE_NKI_SDPA", "0") gc.collect() backend_empty_cache(torch_device) @@ -1002,8 +1005,6 @@ def test_sdxl_turbo_512(self): ) pipe.to(torch_device) if is_torch_neuronx_available(): - # Flush pending lazy XLA parameter-copy ops so they don't pile up and - # trigger a batch compilation on the first inference call (NCC_IDRV017). torch.neuron.synchronize() pipe.set_progress_bar_config(disable=None) @@ -1017,13 +1018,8 @@ def test_sdxl_turbo_512(self): image_slice = image[0, -3:, -3:, -1] self.assertEqual(image.shape, (1, 512, 512, 3)) - - # Verify outputs are valid pixel values self.assertTrue(np.all((image >= 0.0) & (image <= 1.0)), "Pixel values must be in [0, 1]") - # Verify the image is non-trivial (not blank or saturated) self.assertGreater(image_slice.std(), 0.01, "Output image should have meaningful variance") - # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA. - # Use a wider tolerance when running on Neuron vs. a reference CUDA run. atol = 1e-2 if is_torch_neuronx_available() else 1e-4 - _ = atol # atol is used when comparing against a reference slice; add it here once available. + _ = atol From 5d13779c0250bc28594ba0868dd1d805ab3d708a Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 13 May 2026 11:44:58 +0000 Subject: [PATCH 11/24] tests: simplify --- src/diffusers/utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index af15e3d374b0..fdf5710ebad3 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -100,7 +100,7 @@ "xpu": getattr(torch.xpu, "synchronize", None), "cpu": None, "mps": None, - "neuron": lambda: getattr(getattr(torch, "neuron", None), "synchronize", lambda: None)(), + "neuron": getattr(getattr(torch, "neuron", None), "synchronize", None), "default": None, } logger = logging.get_logger(__name__) # pylint: disable=invalid-name From c0077092b194c37940b7fc32140802ee071140e1 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 13 May 2026 11:50:33 +0000 Subject: [PATCH 12/24] fix style --- tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 5d875765a76c..f03c58876bd8 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -24,8 +24,8 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import ( - AutoPipelineForText2Image, AutoencoderKL, + AutoPipelineForText2Image, DDIMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, From fb5ea9429447a30c3c2656fa6aa34a2961b95b45 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 13 May 2026 12:03:14 +0000 Subject: [PATCH 13/24] fix style --- tests/pipelines/flux2/test_pipeline_flux2_klein.py | 1 - .../pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein.py b/tests/pipelines/flux2/test_pipeline_flux2_klein.py index 264da46e1d27..bcaf10555fa0 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein.py @@ -238,4 +238,3 @@ def test_flux2_klein_inference_512(self): atol = 1e-2 if is_torch_neuronx_available() else 1e-4 _ = atol - diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index f03c58876bd8..786f8a412d08 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -1000,9 +1000,7 @@ def tearDown(self): def test_sdxl_turbo_512(self): generator = torch.Generator("cpu").manual_seed(0) - pipe = AutoPipelineForText2Image.from_pretrained( - self.ckpt_id, torch_dtype=torch.float16, variant="fp16" - ) + pipe = AutoPipelineForText2Image.from_pretrained(self.ckpt_id, torch_dtype=torch.float16, variant="fp16") pipe.to(torch_device) if is_torch_neuronx_available(): torch.neuron.synchronize() From 9d93c6833a962c437f6c043e5d39897ad13ddfbd Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 13 May 2026 12:30:47 +0000 Subject: [PATCH 14/24] fix style for doc-builder --- src/diffusers/utils/import_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index bd00876603bd..937e20d6b155 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -606,7 +606,8 @@ def is_av_available(): TORCH_NEURONX_IMPORT_ERROR = """ -{0} requires the torch_neuronx library (AWS Neuron SDK) but it was not found in your environment. Please install it following the AWS Neuron documentation: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/ +{0} requires the torch_neuronx library (AWS Neuron SDK) but it was not found in your environment. Please install it +following the AWS Neuron documentation: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/ """ BACKENDS_MAPPING = OrderedDict( From e8bf642113d5275a16780625998db350f805c2fc Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Mon, 18 May 2026 13:58:36 +0000 Subject: [PATCH 15/24] review: address comments --- .../pixart_alpha/pipeline_pixart_alpha.py | 4 +- .../flux2/test_pipeline_flux2_klein.py | 17 +++-- .../test_stable_diffusion_xl.py | 14 ++-- tests/testing_utils.py | 66 ++++++++----------- 4 files changed, 50 insertions(+), 51 deletions(-) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index bda4e40f3768..299265345e2a 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -29,7 +29,6 @@ deprecate, is_bs4_available, is_ftfy_available, - is_torch_neuronx_available, is_torch_xla_available, logging, replace_example_docstring, @@ -863,7 +862,8 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps - if XLA_AVAILABLE or is_torch_neuronx_available(): + is_neuron_device = hasattr(device, "type") and device.type == "neuron" + if XLA_AVAILABLE or is_neuron_device: timestep_device = "cpu" else: timestep_device = device diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein.py b/tests/pipelines/flux2/test_pipeline_flux2_klein.py index bcaf10555fa0..7499cb03967c 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein.py @@ -17,8 +17,7 @@ from ...testing_utils import ( backend_empty_cache, - require_torch_accelerator, - slow, + require_torch_neuron, torch_device, ) from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist @@ -191,17 +190,19 @@ def test_encode_prompt_works_in_isolation(self): pass -@slow -@require_torch_accelerator +@require_torch_neuron class Flux2KleinPipelineIntegrationTests(unittest.TestCase): ckpt_id = "black-forest-labs/FLUX.2-klein-4B" prompt = "A small cactus with a happy face in the Sahara desert." def setUp(self): super().setUp() + self._saved_env = {} if is_torch_neuronx_available(): neff_cache_dir = "/tmp/neff_cache" os.makedirs(neff_cache_dir, exist_ok=True) + for key in ("TORCH_NEURONX_NEFF_CACHE_DIR", "TORCH_NEURONX_ENABLE_NKI_SDPA"): + self._saved_env[key] = os.environ.get(key) os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = neff_cache_dir os.environ.setdefault("TORCH_NEURONX_ENABLE_NKI_SDPA", "0") gc.collect() @@ -209,6 +210,11 @@ def setUp(self): def tearDown(self): super().tearDown() + for key, original in self._saved_env.items(): + if original is None: + os.environ.pop(key, None) + else: + os.environ[key] = original gc.collect() backend_empty_cache(torch_device) @@ -235,6 +241,3 @@ def test_flux2_klein_inference_512(self): self.assertEqual(image.shape, (1, 512, 512, 3)) self.assertTrue(np.all((image >= 0.0) & (image <= 1.0)), "Pixel values must be in [0, 1]") self.assertGreater(image_slice.std(), 0.01, "Output image should have meaningful variance") - - atol = 1e-2 if is_torch_neuronx_available() else 1e-4 - _ = atol diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 786f8a412d08..0b00731bd695 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -44,6 +44,7 @@ load_image, numpy_cosine_similarity_distance, require_torch_accelerator, + require_torch_neuron, slow, torch_device, ) @@ -979,21 +980,27 @@ def test_stable_diffusion_lcm(self): assert max_diff < 1e-2 -@slow -@require_torch_accelerator +@require_torch_neuron class StableDiffusionXLTurboPipelineIntegrationTests(unittest.TestCase): ckpt_id = "stabilityai/sdxl-turbo" prompt = "A small cactus with a happy face in the Sahara desert." def setUp(self): super().setUp() + self._saved_env = {} if is_torch_neuronx_available(): + self._saved_env["TORCH_NEURONX_ENABLE_NKI_SDPA"] = os.environ.get("TORCH_NEURONX_ENABLE_NKI_SDPA") os.environ.setdefault("TORCH_NEURONX_ENABLE_NKI_SDPA", "0") gc.collect() backend_empty_cache(torch_device) def tearDown(self): super().tearDown() + for key, original in self._saved_env.items(): + if original is None: + os.environ.pop(key, None) + else: + os.environ[key] = original gc.collect() backend_empty_cache(torch_device) @@ -1018,6 +1025,3 @@ def test_sdxl_turbo_512(self): self.assertEqual(image.shape, (1, 512, 512, 3)) self.assertTrue(np.all((image >= 0.0) & (image <= 1.0)), "Pixel values must be in [0, 1]") self.assertGreater(image_slice.std(), 0.01, "Output image should have meaningful variance") - - atol = 1e-2 if is_torch_neuronx_available() else 1e-4 - _ = atol diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 0f88074ca50b..6d6df8b24d1e 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -145,8 +145,8 @@ def assert_tensors_close( """ Assert that two tensors are close within tolerance. - Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected| - Provides concise, actionable error messages without dumping full tensors. + Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected| Provides concise, + actionable error messages without dumping full tensors. Args: actual: The actual tensor from the computation. @@ -340,8 +340,7 @@ def nightly(test_case): def is_torch_compile(test_case): """ Decorator marking a test as a torch.compile test. These tests can be filtered using: - pytest -m "not compile" to skip - pytest -m compile to run only these tests + pytest -m "not compile" to skip pytest -m compile to run only these tests """ return pytest.mark.compile(test_case) @@ -349,8 +348,7 @@ def is_torch_compile(test_case): def is_single_file(test_case): """ Decorator marking a test as a single file loading test. These tests can be filtered using: - pytest -m "not single_file" to skip - pytest -m single_file to run only these tests + pytest -m "not single_file" to skip pytest -m single_file to run only these tests """ return pytest.mark.single_file(test_case) @@ -358,8 +356,7 @@ def is_single_file(test_case): def is_lora(test_case): """ Decorator marking a test as a LoRA test. These tests can be filtered using: - pytest -m "not lora" to skip - pytest -m lora to run only these tests + pytest -m "not lora" to skip pytest -m lora to run only these tests """ return pytest.mark.lora(test_case) @@ -367,8 +364,7 @@ def is_lora(test_case): def is_ip_adapter(test_case): """ Decorator marking a test as an IP Adapter test. These tests can be filtered using: - pytest -m "not ip_adapter" to skip - pytest -m ip_adapter to run only these tests + pytest -m "not ip_adapter" to skip pytest -m ip_adapter to run only these tests """ return pytest.mark.ip_adapter(test_case) @@ -376,8 +372,7 @@ def is_ip_adapter(test_case): def is_training(test_case): """ Decorator marking a test as a training test. These tests can be filtered using: - pytest -m "not training" to skip - pytest -m training to run only these tests + pytest -m "not training" to skip pytest -m training to run only these tests """ return pytest.mark.training(test_case) @@ -385,8 +380,7 @@ def is_training(test_case): def is_attention(test_case): """ Decorator marking a test as an attention test. These tests can be filtered using: - pytest -m "not attention" to skip - pytest -m attention to run only these tests + pytest -m "not attention" to skip pytest -m attention to run only these tests """ return pytest.mark.attention(test_case) @@ -394,8 +388,7 @@ def is_attention(test_case): def is_memory(test_case): """ Decorator marking a test as a memory optimization test. These tests can be filtered using: - pytest -m "not memory" to skip - pytest -m memory to run only these tests + pytest -m "not memory" to skip pytest -m memory to run only these tests """ return pytest.mark.memory(test_case) @@ -403,8 +396,7 @@ def is_memory(test_case): def is_cpu_offload(test_case): """ Decorator marking a test as a CPU offload test. These tests can be filtered using: - pytest -m "not cpu_offload" to skip - pytest -m cpu_offload to run only these tests + pytest -m "not cpu_offload" to skip pytest -m cpu_offload to run only these tests """ return pytest.mark.cpu_offload(test_case) @@ -412,8 +404,7 @@ def is_cpu_offload(test_case): def is_group_offload(test_case): """ Decorator marking a test as a group offload test. These tests can be filtered using: - pytest -m "not group_offload" to skip - pytest -m group_offload to run only these tests + pytest -m "not group_offload" to skip pytest -m group_offload to run only these tests """ return pytest.mark.group_offload(test_case) @@ -421,8 +412,7 @@ def is_group_offload(test_case): def is_quantization(test_case): """ Decorator marking a test as a quantization test. These tests can be filtered using: - pytest -m "not quantization" to skip - pytest -m quantization to run only these tests + pytest -m "not quantization" to skip pytest -m quantization to run only these tests """ return pytest.mark.quantization(test_case) @@ -430,8 +420,7 @@ def is_quantization(test_case): def is_bitsandbytes(test_case): """ Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using: - pytest -m "not bitsandbytes" to skip - pytest -m bitsandbytes to run only these tests + pytest -m "not bitsandbytes" to skip pytest -m bitsandbytes to run only these tests """ return pytest.mark.bitsandbytes(test_case) @@ -439,8 +428,7 @@ def is_bitsandbytes(test_case): def is_quanto(test_case): """ Decorator marking a test as a Quanto quantization test. These tests can be filtered using: - pytest -m "not quanto" to skip - pytest -m quanto to run only these tests + pytest -m "not quanto" to skip pytest -m quanto to run only these tests """ return pytest.mark.quanto(test_case) @@ -448,8 +436,7 @@ def is_quanto(test_case): def is_torchao(test_case): """ Decorator marking a test as a TorchAO quantization test. These tests can be filtered using: - pytest -m "not torchao" to skip - pytest -m torchao to run only these tests + pytest -m "not torchao" to skip pytest -m torchao to run only these tests """ return pytest.mark.torchao(test_case) @@ -457,8 +444,7 @@ def is_torchao(test_case): def is_gguf(test_case): """ Decorator marking a test as a GGUF quantization test. These tests can be filtered using: - pytest -m "not gguf" to skip - pytest -m gguf to run only these tests + pytest -m "not gguf" to skip pytest -m gguf to run only these tests """ return pytest.mark.gguf(test_case) @@ -466,8 +452,7 @@ def is_gguf(test_case): def is_modelopt(test_case): """ Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using: - pytest -m "not modelopt" to skip - pytest -m modelopt to run only these tests + pytest -m "not modelopt" to skip pytest -m modelopt to run only these tests """ return pytest.mark.modelopt(test_case) @@ -475,8 +460,7 @@ def is_modelopt(test_case): def is_context_parallel(test_case): """ Decorator marking a test as a context parallel inference test. These tests can be filtered using: - pytest -m "not context_parallel" to skip - pytest -m context_parallel to run only these tests + pytest -m "not context_parallel" to skip pytest -m context_parallel to run only these tests """ return pytest.mark.context_parallel(test_case) @@ -484,8 +468,7 @@ def is_context_parallel(test_case): def is_cache(test_case): """ Decorator marking a test as a cache test. These tests can be filtered using: - pytest -m "not cache" to skip - pytest -m cache to run only these tests + pytest -m "not cache" to skip pytest -m cache to run only these tests """ return pytest.mark.cache(test_case) @@ -555,6 +538,14 @@ def require_torch_accelerator(test_case): return pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")(test_case) +def require_torch_neuron(test_case): + """Decorator marking a test that requires a Neuron device (Trainium/Inferentia).""" + return pytest.mark.skipif( + not (is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available()), + reason="test requires Neuron device", + )(test_case) + + def require_torch_multi_gpu(test_case): """ Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without @@ -1330,7 +1321,7 @@ class CaptureLogger: Example: ```python >>> from diffusers import logging - >>> from diffusers..testing_utils import CaptureLogger + >>> from diffusers.utils.testing_utils import CaptureLogger >>> msg = "Testing 1, 2, 3" >>> logging.set_verbosity_info() @@ -1506,6 +1497,7 @@ def _is_torch_fp64_available(device): BACKEND_RESET_MAX_MEMORY_ALLOCATED[_neuron_device] = None BACKEND_MAX_MEMORY_ALLOCATED[_neuron_device] = 0 BACKEND_SYNCHRONIZE[_neuron_device] = torch.neuron.synchronize + BACKEND_SUPPORTS_TRAINING[_neuron_device] = False # This dispatches a defined function according to the accelerator from the function definitions. From a3b6ccb4755abb1759d14c0797ef733935675ced Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 20 May 2026 15:26:03 +0000 Subject: [PATCH 16/24] review:apply suggestion for the fix of index_for_timtestep --- .../pipeline_stable_diffusion_xl.py | 17 +++-------------- .../test_stable_diffusion_xl.py | 3 ++- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 9e12e9459369..adc29fae53bc 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1188,6 +1188,7 @@ def __call__( ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) + self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -1195,13 +1196,7 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - - # For Neuron: scale_model_input on CPU to avoid XLA ops outside the compiled UNet region. - # index_for_timestep() uses .nonzero()/.item() which are incompatible with static graphs. - if is_neuron_device: - latent_model_input = self.scheduler.scale_model_input(latent_model_input.to("cpu"), t).to(device) - else: - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} @@ -1231,13 +1226,7 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - # For Neuron: scheduler.step on CPU to keep scheduler arithmetic off the XLA device. - if is_neuron_device: - latents = self.scheduler.step( - noise_pred.to("cpu"), t, latents.to("cpu"), **extra_step_kwargs, return_dict=False - )[0].to(device) - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 0b00731bd695..c9afdc3209cd 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -1024,4 +1024,5 @@ def test_sdxl_turbo_512(self): image_slice = image[0, -3:, -3:, -1] self.assertEqual(image.shape, (1, 512, 512, 3)) self.assertTrue(np.all((image >= 0.0) & (image <= 1.0)), "Pixel values must be in [0, 1]") - self.assertGreater(image_slice.std(), 0.01, "Output image should have meaningful variance") + expected_slice = np.array([0.3524, 0.3160, 0.3652, 0.3316, 0.3376, 0.3315, 0.3042, 0.3102, 0.3449]) + self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 5e-2) From 86e550a906435d6a660039dd22b5eab331cc4af6 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Wed, 20 May 2026 15:26:48 +0000 Subject: [PATCH 17/24] review: stronger guard on image slice --- tests/pipelines/flux2/test_pipeline_flux2_klein.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein.py b/tests/pipelines/flux2/test_pipeline_flux2_klein.py index 7499cb03967c..377f02dc9aa1 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein.py @@ -240,4 +240,5 @@ def test_flux2_klein_inference_512(self): image_slice = image[0, -3:, -3:, -1] self.assertEqual(image.shape, (1, 512, 512, 3)) self.assertTrue(np.all((image >= 0.0) & (image <= 1.0)), "Pixel values must be in [0, 1]") - self.assertGreater(image_slice.std(), 0.01, "Output image should have meaningful variance") + expected_slice = np.array([0.3652, 0.3574, 0.3633, 0.4102, 0.4062, 0.4043, 0.4453, 0.4355, 0.4570]) + self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 5e-2) From 45a881246788a1dd3097f5aff736bac9e1848aff Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Wed, 20 May 2026 17:31:49 +0200 Subject: [PATCH 18/24] Apply suggestions from code review Co-authored-by: YiYi Xu --- src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 299265345e2a..0fa44a15b53a 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -862,7 +862,7 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps - is_neuron_device = hasattr(device, "type") and device.type == "neuron" + is_neuron_device = device.type == "neuron" if XLA_AVAILABLE or is_neuron_device: timestep_device = "cpu" else: From d5ee083f773be2d8c241198ae20f40629663fa1b Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Tue, 26 May 2026 14:42:03 +0000 Subject: [PATCH 19/24] fix: when set_begin_index not implemented for scheduler --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index adc29fae53bc..bcf663a6105f 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1188,7 +1188,8 @@ def __call__( ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) - self.scheduler.set_begin_index(0) + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: From 4b3ff51030cf55720a307728c3063a00f4b4cad7 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Mon, 1 Jun 2026 14:57:45 +0000 Subject: [PATCH 20/24] review:add maybe_adjust_dtype_for_device and apply to all models with downcasting needs --- src/diffusers/models/controlnets/controlnet.py | 10 ++++------ .../models/controlnets/controlnet_sparsectrl.py | 10 ++++------ src/diffusers/models/controlnets/controlnet_union.py | 10 ++++------ src/diffusers/models/controlnets/controlnet_xs.py | 11 ++++------- src/diffusers/models/embeddings.py | 3 ++- .../models/transformers/transformer_anyflow.py | 9 +++------ .../models/transformers/transformer_anyflow_far.py | 9 +++------ src/diffusers/models/transformers/transformer_bria.py | 8 +++----- .../models/transformers/transformer_bria_fibo.py | 5 ++--- src/diffusers/models/transformers/transformer_flux.py | 6 ++---- .../models/transformers/transformer_flux2.py | 6 ++---- .../models/transformers/transformer_hidream_image.py | 7 ++----- .../models/transformers/transformer_longcat_image.py | 6 ++---- .../models/transformers/transformer_motif_video.py | 5 ++--- .../models/transformers/transformer_ovis_image.py | 6 ++---- src/diffusers/models/transformers/transformer_prx.py | 5 ++--- src/diffusers/models/unets/unet_2d_condition.py | 11 ++++------- src/diffusers/models/unets/unet_3d_condition.py | 10 ++++------ src/diffusers/models/unets/unet_i2vgen_xl.py | 10 ++++------ src/diffusers/models/unets/unet_motion_model.py | 11 ++++------- .../models/unets/unet_spatio_temporal_condition.py | 10 ++++------ src/diffusers/utils/torch_utils.py | 11 +++++++++++ 22 files changed, 74 insertions(+), 105 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 8c2ff2fdd123..d2030f4e7044 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -22,6 +22,7 @@ from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import BaseOutput, apply_lora_scale, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -675,12 +676,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py index 715d9dad2c34..dda653ea7a50 100644 --- a/src/diffusers/models/controlnets/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py @@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin from ...utils import BaseOutput, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -604,12 +605,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py index 8e434ba1f250..8dfcb1795618 100644 --- a/src/diffusers/models/controlnets/controlnet_union.py +++ b/src/diffusers/models/controlnets/controlnet_union.py @@ -19,6 +19,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -620,12 +621,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py index 6221d4878de9..efc242f332b9 100644 --- a/src/diffusers/models/controlnets/controlnet_xs.py +++ b/src/diffusers/models/controlnets/controlnet_xs.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput, logging -from ...utils.torch_utils import apply_freeu +from ...utils.torch_utils import apply_freeu, maybe_adjust_dtype_for_device from ..attention import AttentionMixin from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -1014,12 +1014,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index bcd192c1f166..c5eaa746252e 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -19,6 +19,7 @@ from torch import nn from ..utils import deprecate +from ..utils.torch_utils import maybe_adjust_dtype_for_device from .activations import FP32SiLU, get_activation from .attention_processor import Attention @@ -346,7 +347,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin # Auto-detect appropriate dtype if not specified if dtype is None: - dtype = torch.float32 if pos.device.type == "mps" else torch.float64 + dtype = maybe_adjust_dtype_for_device(torch.float64, pos.device) omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype) omega /= embed_dim / 2.0 diff --git a/src/diffusers/models/transformers/transformer_anyflow.py b/src/diffusers/models/transformers/transformer_anyflow.py index 2ac554419e5e..231c2113d7e6 100644 --- a/src/diffusers/models/transformers/transformer_anyflow.py +++ b/src/diffusers/models/transformers/transformer_anyflow.py @@ -28,6 +28,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import apply_lora_scale, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed @@ -41,9 +42,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): # MPS / NPU backends do not support complex128 / float64; fall back to float32 on those devices. - is_mps = hidden_states.device.type == "mps" - is_npu = hidden_states.device.type == "npu" - rotary_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + rotary_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device) x_rotated = torch.view_as_complex(hidden_states.to(rotary_dtype).unflatten(3, (-1, 2))) x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) return x_out.type_as(hidden_states) @@ -338,9 +337,7 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor: if self._freqs_cache is not None and self._freqs_cache[0] == cache_key: return self._freqs_cache[1] - is_mps = device.type == "mps" - is_npu = device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, device) h_dim = w_dim = 2 * (self.attention_head_dim // 6) t_dim = self.attention_head_dim - h_dim - w_dim diff --git a/src/diffusers/models/transformers/transformer_anyflow_far.py b/src/diffusers/models/transformers/transformer_anyflow_far.py index a40e2fafcb61..78865dfab731 100644 --- a/src/diffusers/models/transformers/transformer_anyflow_far.py +++ b/src/diffusers/models/transformers/transformer_anyflow_far.py @@ -30,6 +30,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import BaseOutput, apply_lora_scale, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed @@ -44,9 +45,7 @@ # Copied from diffusers.models.transformers.transformer_anyflow.apply_rotary_emb def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): # MPS / NPU backends do not support complex128 / float64; fall back to float32 on those devices. - is_mps = hidden_states.device.type == "mps" - is_npu = hidden_states.device.type == "npu" - rotary_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + rotary_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device) x_rotated = torch.view_as_complex(hidden_states.to(rotary_dtype).unflatten(3, (-1, 2))) x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) return x_out.type_as(hidden_states) @@ -647,9 +646,7 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor: if self._freqs_cache is not None and self._freqs_cache[0] == cache_key: return self._freqs_cache[1] - is_mps = device.type == "mps" - is_npu = device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, device) h_dim = w_dim = 2 * (self.attention_head_dim // 6) t_dim = self.attention_head_dim - h_dim - w_dim diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py index 8e79046508e9..ff4261343ab2 100644 --- a/src/diffusers/models/transformers/transformer_bria.py +++ b/src/diffusers/models/transformers/transformer_bria.py @@ -9,7 +9,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import apply_lora_scale, logging -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -276,8 +276,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - freqs_dtype = torch.float32 if is_mps else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], @@ -344,8 +343,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - freqs_dtype = torch.float32 if is_mps else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py index 31c826bbf6b2..7b4ac1a3bedf 100644 --- a/src/diffusers/models/transformers/transformer_bria_fibo.py +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -25,7 +25,7 @@ apply_lora_scale, logging, ) -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle @@ -222,8 +222,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - freqs_dtype = torch.float32 if is_mps else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 13177bc67878..94857dffacb2 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...utils import apply_lora_scale, logging -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -503,9 +503,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index 595dcc7fe74a..c3fa6ac141f3 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -23,6 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...utils import BaseOutput, apply_lora_scale, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn @@ -959,10 +960,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - is_neuron = ids.device.type == "neuron" - freqs_dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] for i in range(len(self.axes_dim)): cos, sin = get_1d_rotary_pos_embed( diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index b6c0e3533657..bd69d5de68ca 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -9,7 +9,7 @@ from ...models.modeling_outputs import Transformer2DModelOutput from ...models.modeling_utils import ModelMixin from ...utils import apply_lora_scale, deprecate, logging -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from ..attention import Attention from ..embeddings import TimestepEmbedding, Timesteps @@ -95,10 +95,7 @@ def forward(self, latent) -> torch.Tensor: def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0, "The dimension must be even." - is_mps = pos.device.type == "mps" - is_npu = pos.device.type == "npu" - - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + dtype = maybe_adjust_dtype_for_device(torch.float64, pos.device) scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim omega = 1.0 / (theta**scale) diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py index fe4713ea02db..7b842c42132d 100644 --- a/src/diffusers/models/transformers/transformer_longcat_image.py +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import logging -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -361,9 +361,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], diff --git a/src/diffusers/models/transformers/transformer_motif_video.py b/src/diffusers/models/transformers/transformer_motif_video.py index c0908f198f90..fb3ff0666f95 100644 --- a/src/diffusers/models/transformers/transformer_motif_video.py +++ b/src/diffusers/models/transformers/transformer_motif_video.py @@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -483,9 +484,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: grid = torch.stack(grid, dim=0) freqs = [] - is_mps = hidden_states.device.type == "mps" - is_npu = hidden_states.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device) for i in range(3): freq = get_1d_rotary_pos_embed( dim=self.rope_dim[i], diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py index 8280a4871f10..7a9df427e0b9 100644 --- a/src/diffusers/models/transformers/transformer_ovis_image.py +++ b/src/diffusers/models/transformers/transformer_ovis_image.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import logging -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import maybe_adjust_dtype_for_device, maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -364,9 +364,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: cos_out = [] sin_out = [] pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, ids.device) for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py index 3c8e8ae4e2c9..37553dd44c87 100644 --- a/src/diffusers/models/transformers/transformer_prx.py +++ b/src/diffusers/models/transformers/transformer_prx.py @@ -19,6 +19,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn from ..embeddings import get_timestep_embedding @@ -275,9 +276,7 @@ def __init__(self, dim: int, theta: int, axes_dim: list[int]): def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0 - is_mps = pos.device.type == "mps" - is_npu = pos.device.type == "npu" - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + dtype = maybe_adjust_dtype_for_device(torch.float64, pos.device) scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim omega = 1.0 / (theta**scale) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index b533bef35414..38a41a3dc93f 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -26,6 +26,7 @@ deprecate, logging, ) +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..activations import get_activation from ..attention import AttentionMixin from ..attention_processor import ( @@ -853,13 +854,9 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - is_neuron = sample.device.type == "neuron" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 5006e48feb46..0d15e93da68f 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import UNet2DConditionLoadersMixin from ...utils import BaseOutput, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..activations import get_activation from ..attention import AttentionMixin from ..attention_processor import ( @@ -547,12 +548,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index 30fb46095326..9e7841f95e58 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -20,6 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import UNet2DConditionLoadersMixin from ...utils import logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..activations import get_activation from ..attention import Attention, AttentionMixin, FeedForward from ..attention_processor import ( @@ -499,12 +500,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timesteps, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timesteps, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 7c4201facacf..6904cc05f10c 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import BaseOutput, apply_lora_scale, deprecate, logging -from ...utils.torch_utils import apply_freeu +from ...utils.torch_utils import apply_freeu, maybe_adjust_dtype_for_device from ..attention import AttentionMixin, BasicTransformerBlock from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -1952,12 +1952,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index eddeb9826b0c..d38be0b0675f 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -6,6 +6,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import UNet2DConditionLoadersMixin from ...utils import BaseOutput, logging +from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttnProcessor from ..embeddings import TimestepEmbedding, Timesteps @@ -335,12 +336,9 @@ def forward( if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = maybe_adjust_dtype_for_device( + torch.float64 if isinstance(timestep, float) else torch.int64, sample.device + ) timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index fdf5710ebad3..8c641ab460ae 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -169,6 +169,17 @@ def backend_supports_training(device: str): return BACKEND_SUPPORTS_TRAINING[device] +_FP64_UNSUPPORTED_DEVICES = frozenset({"mps", "npu", "neuron"}) +_INT64_UNSUPPORTED_DEVICES = frozenset({"mps", "npu", "neuron"}) +_DTYPE_DOWNCAST = {torch.float64: torch.float32, torch.int64: torch.int32} +_DTYPE_UNSUPPORTED_DEVICES = {torch.float64: _FP64_UNSUPPORTED_DEVICES, torch.int64: _INT64_UNSUPPORTED_DEVICES} + + +def maybe_adjust_dtype_for_device(dtype: "torch.dtype", device: "torch.device") -> "torch.dtype": + unsupported = _DTYPE_UNSUPPORTED_DEVICES.get(dtype) + return _DTYPE_DOWNCAST[dtype] if unsupported and device.type in unsupported else dtype + + def randn_tensor( shape: tuple | list, generator: list["torch.Generator"] | "torch.Generator" | None = None, From 888a936a48bef3bc105aa37f8fd21170fa1a2056 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Mon, 1 Jun 2026 15:00:53 +0000 Subject: [PATCH 21/24] fix: dependency --- src/diffusers/utils/torch_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 8c641ab460ae..263334dce8cd 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -103,6 +103,12 @@ "neuron": getattr(getattr(torch, "neuron", None), "synchronize", None), "default": None, } + + _FP64_UNSUPPORTED_DEVICES = frozenset({"mps", "npu", "neuron"}) + _INT64_UNSUPPORTED_DEVICES = frozenset({"mps", "npu", "neuron"}) + _DTYPE_DOWNCAST = {torch.float64: torch.float32, torch.int64: torch.int32} + _DTYPE_UNSUPPORTED_DEVICES = {torch.float64: _FP64_UNSUPPORTED_DEVICES, torch.int64: _INT64_UNSUPPORTED_DEVICES} + logger = logging.get_logger(__name__) # pylint: disable=invalid-name try: @@ -169,12 +175,6 @@ def backend_supports_training(device: str): return BACKEND_SUPPORTS_TRAINING[device] -_FP64_UNSUPPORTED_DEVICES = frozenset({"mps", "npu", "neuron"}) -_INT64_UNSUPPORTED_DEVICES = frozenset({"mps", "npu", "neuron"}) -_DTYPE_DOWNCAST = {torch.float64: torch.float32, torch.int64: torch.int32} -_DTYPE_UNSUPPORTED_DEVICES = {torch.float64: _FP64_UNSUPPORTED_DEVICES, torch.int64: _INT64_UNSUPPORTED_DEVICES} - - def maybe_adjust_dtype_for_device(dtype: "torch.dtype", device: "torch.device") -> "torch.dtype": unsupported = _DTYPE_UNSUPPORTED_DEVICES.get(dtype) return _DTYPE_DOWNCAST[dtype] if unsupported and device.type in unsupported else dtype From e98f17ef849dff6e7ca5317d04799f83fc69b227 Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Tue, 2 Jun 2026 11:24:26 +0200 Subject: [PATCH 22/24] Update src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py Co-authored-by: YiYi Xu --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index cd0b37744285..d08b6c5a5973 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1095,7 +1095,7 @@ def __call__( ) # 4. Prepare timesteps - is_neuron_device = hasattr(device, "type") and device.type == "neuron" + is_neuron_device = device.type == "neuron" if XLA_AVAILABLE or is_neuron_device: timestep_device = "cpu" else: From ab82699d90935a51244c5625e6a416ceda92a053 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Tue, 2 Jun 2026 09:38:55 +0000 Subject: [PATCH 23/24] review: apply maybe_adjust_dtype_for_device in pixart pipe --- .../pipelines/pixart_alpha/pipeline_pixart_alpha.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 0fa44a15b53a..11eaeaca7fc0 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -33,7 +33,7 @@ logging, replace_example_docstring, ) -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import maybe_adjust_dtype_for_device, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -913,13 +913,10 @@ def __call__( if not torch.is_tensor(current_timestep): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) - is_mps = latent_model_input.device.type == "mps" - is_npu = latent_model_input.device.type == "npu" - is_neuron = latent_model_input.device.type == "neuron" if isinstance(current_timestep, float): - dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 + dtype = maybe_adjust_dtype_for_device(torch.float64, latent_model_input.device) else: - dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64 + dtype = maybe_adjust_dtype_for_device(torch.int64, latent_model_input.device) current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) From 20af97babc7b4b6254efd8d0feaa43eee0335bd4 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Tue, 2 Jun 2026 09:44:47 +0000 Subject: [PATCH 24/24] review: update .ai/models.md --- .ai/models.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.ai/models.md b/.ai/models.md index 954cbd343781..6e37e742ae57 100644 --- a/.ai/models.md +++ b/.ai/models.md @@ -163,14 +163,14 @@ Boolean gate. If `False` (default), calling that method raises `ValueError`. All 3. **Capability flags without matching implementation.** for example, `_supports_gradient_checkpointing = True` only takes effect if `forward` actually has `if self.gradient_checkpointing:` branches calling `self._gradient_checkpointing_func` on each block. Setting the flag without those branches means training code silently no-ops the checkpoint and runs a normal forward. 4. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) — the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`. -5. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows: +5. **`torch.float64` anywhere in the model.** MPS, NPU, and Neuron backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows: - **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on. - - **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo: + - **Only if float32 visibly degrades output, use the `maybe_adjust_dtype_for_device` helper** from `diffusers.utils.torch_utils`. It centralizes the device-specific dtype downcast (float64→float32, int64→int32) for all restricted backends (mps, npu, neuron): ```python - is_mps = hidden_states.device.type == "mps" - is_npu = hidden_states.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + from diffusers.utils.torch_utils import maybe_adjust_dtype_for_device + + freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device) ``` - See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py` for reference usages. Never leave an unconditional `torch.float64` in the model. + See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py`, and `pipeline_pixart_alpha.py` for reference usages. Never leave an unconditional `torch.float64` in the model. 6. **Using `torch.empty`.** - Do not use `torch.empty` to initialize parameters. Use `torch.zeros` or `torch.ones`, instead. \ No newline at end of file