[Neuron] Add AWS Neuron (Trainium/Inferentia) as an officially supported device#13289

Open
JingyaHuang wants to merge 44 commits into
huggingface:mainfrom
JingyaHuang:add-neuron-backend

Contributor

@JingyaHuang JingyaHuang commented Mar 19, 2026

What does this PR do?

This PR adds AWS Neuron (Trainium/Inferentia) as an officially supported compute backend in Diffusers, on par with existing backends like CUDA, MPS, XPU, and MLU.

Changes

  • import_utils.py — adds is_torch_neuronx_available() detection, following the existing pattern for optional backends.
  • torch_utils.py — registers "neuron" in all backend dispatch tables (BACKEND_SUPPORTS_TRAINING, BACKEND_EMPTY_CACHE, BACKEND_DEVICE_COUNT, BACKEND_MANUAL_SEED, etc.) and adds a randn_tensor workaround since Neuron/XLA does not support creating random tensors directly on device (falls back to CPU).
  • utils/__init__.py: exports is_torch_neuronx_available.
  • pipeline_utils.py — adds two new DiffusionPipeline methods:
    • enable_neuron_compile(model_names, cache_dir, fullgraph) — wraps pipeline nn.Module components with torch.compile(backend="neuron") for whole-graph NEFF compilation. Supports optional NEFF caching via TORCH_NEURONX_NEFF_CACHE_DIR.
    • neuron_warmup(*args, **kwargs) — runs a single dummy forward pass to trigger upfront neuronx-cc compilation before timed inference.
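
As an illustration of the "existing pattern for optional backends" mentioned for import_utils.py, here is a minimal sketch of package detection that does not import the package. The helper name is_package_available and the exact body of is_torch_neuronx_available below are assumptions for illustration, not the code in this PR.

```python
import importlib.util


def is_package_available(name: str) -> bool:
    """Return True if `name` is importable, without actually importing it."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # find_spec can raise for missing parents or modules without a spec
        return False


# Hypothetical stand-in for the helper this PR adds to import_utils.py.
def is_torch_neuronx_available() -> bool:
    return is_package_available("torch") and is_package_available("torch_neuronx")


if __name__ == "__main__":
    print("torch_neuronx available:", is_torch_neuronx_available())
```

Probing with find_spec keeps the check cheap and free of side effects, which matters when the optional backend registers devices at import time.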

Usage

  • Eager mode
import torch
import torch_neuronx  # noqa: F401 (registers torch.neuron)

from diffusers import AutoPipelineForText2Image

# Load and move to Neuron device
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.bfloat16,
    variant="fp16",
)
pipe = pipe.to(torch.neuron.current_device())

# Warmup
pipe(prompt="warmup", height=512, width=512, num_inference_steps=1, guidance_scale=0.0)

# Inference
image = pipe(
    prompt="a golden retriever surfing a wave, photorealistic",
    height=512,
    width=512,
    num_inference_steps=1,
    guidance_scale=0.0,
).images[0]

image.save("output.png")

Validation

So far we have validated the following models. The idea is to manually validate a representative subset of models now, and to ensure broader model coverage with our CI afterward.

  • pixart
  • sdxl
  • flux2-klein-4B (runs on a single NeuronCore without tensor parallelism in eager mode at 1024x1024 resolution)

Next Steps

  • Enable torch.compile on Neuron device
  • Add tensor parallel support for memory-bound devices like Neuron
  • Make Diffusers compatible with the NKI kernels library to boost performance on Neuron.
  • Support sequence parallel on Neuron for video models.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions github-actions Bot added lora examples size/M PR with diff < 200 LOC and removed size/M PR with diff < 200 LOC labels Apr 9, 2026
@github-actions github-actions Bot added size/M PR with diff < 200 LOC and removed size/M PR with diff < 200 LOC labels Apr 10, 2026
@JingyaHuang
Contributor Author

JingyaHuang commented May 18, 2026

Thanks @sayakpaul, I have addressed the review comments. Shall I merge?

Comment thread tests/pipelines/flux2/test_pipeline_flux2_klein.py
Comment thread tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
Collaborator

@yiyixuxu yiyixuxu left a comment

thanks! i left some questions/comments

Comment thread src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py Outdated
Comment thread src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
Comment thread src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py Outdated

# 4. Prepare timesteps
if XLA_AVAILABLE:
is_neuron_device = hasattr(device, "type") and device.type == "neuron"
Collaborator

@claude-2-serge can you trace the code for how device is defined here and let me know if there is any code path where _execution_device would not return a torch.device type?

Member

Follow-up inline comments are something we're working on with @tarekziade. Should land pretty soon.

Contributor Author

@JingyaHuang JingyaHuang left a comment

addressed the comments! @sayakpaul @yiyixuxu

Comment thread src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py Outdated
Comment thread src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
Comment thread tests/pipelines/flux2/test_pipeline_flux2_klein.py
Comment thread src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py Outdated
@JingyaHuang
Contributor Author

The failing CI checks are unrelated to this PR, ok for us to merge? @sayakpaul @yiyixuxu

@JingyaHuang JingyaHuang requested review from sayakpaul and yiyixuxu May 26, 2026 15:18
Comment thread src/diffusers/models/transformers/transformer_flux2.py
@yiyixuxu
Collaborator

yiyixuxu commented Jun 2, 2026

@askserge can you do a review here? in addition to review-rules, also these 3 comments from my end

  1. for maybe_adjust_dtype_for_device, check if backward-compatible everywhere
  2. suggest how we update agent doc for maybe_adjust_dtype_for_device
  3. check the sdxl pipeline: can set_begin_index() be hardcoded at 0 for all scenarios? for example denoising_end

@yiyixuxu
Collaborator

yiyixuxu commented Jun 2, 2026

@claude can you do a review here? in addition to review-rules, also these 3 comments from my end

  1. for maybe_adjust_dtype_for_device, check if backward-compatible everywhere
  2. suggest how we update agent doc for maybe_adjust_dtype_for_device
  3. check the sdxl pipeline: can set_begin_index() be hardcoded at 0 for all scenarios? for example denoising_end. please answer this by tracing the code

@github-actions
Contributor

github-actions Bot commented Jun 2, 2026

Claude finished @yiyixuxu's task in 6m 53s


Review of [Neuron] Add AWS Neuron backend

  • Read review guidelines (AGENTS.md, models.md, pipelines.md)
  • Review maybe_adjust_dtype_for_device for backward compatibility
  • Suggest agent doc update for maybe_adjust_dtype_for_device
  • Trace set_begin_index() in SDXL pipeline for denoising_end scenarios
  • General code review of all changes under src/diffusers/

1. maybe_adjust_dtype_for_device — backward compatibility

Verdict: Backward-compatible everywhere, with one bonus bug fix on NPU.

The new centralized utility (src/diffusers/utils/torch_utils.py:107-110,178-180):

_FP64_UNSUPPORTED_DEVICES = frozenset({"mps", "npu", "neuron"})
_INT64_UNSUPPORTED_DEVICES = frozenset({"mps", "npu", "neuron"})
_DTYPE_DOWNCAST = {torch.float64: torch.float32, torch.int64: torch.int32}

replaces the scattered inline pattern:

is_mps = device.type == "mps"
is_npu = device.type == "npu"
dtype = torch.float32 if (is_mps or is_npu) else torch.float64

Device-by-device analysis:

| Device | Old behavior | New behavior | Status |
|---|---|---|---|
| MPS | float64→float32, int64→int32 | Same | No change |
| NPU | float64→float32, int64→int32 (in unets/controlnets) | Same | No change |
| NPU | float64→float64 (in get_timestep_embedding, embeddings.py:349) | float64→float32 | Bug fix: old code only checked mps, missed npu |
| Neuron | No handling | float64→float32, int64→int32 | New (desired) |
| CUDA/CPU/XPU | float64/int64 unchanged | Same | No change |

The get_timestep_embedding fix for NPU is a positive side-effect: models.md gotcha #5 documents that NPU does not support float64, so the old code was buggy.

One thing to note: maybe_adjust_dtype_for_device expects device to have a .type attribute (i.e., a torch.device object). All current callers pass tensor .device properties, so this is fine. But if any future caller passes a string, it would break. The type annotation makes this clear.
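
To make the device-by-device behavior concrete, here is a rough sketch of the downcast logic. It is not the implementation in torch_utils.py: dtype names are plain strings and FakeDevice is a hypothetical stand-in for torch.device, so the snippet runs without PyTorch. Only the three lookup tables mirror the constants quoted above.

```python
from dataclasses import dataclass

# Strings stand in for torch dtypes so this runs without PyTorch (assumption).
FP64_UNSUPPORTED = frozenset({"mps", "npu", "neuron"})
INT64_UNSUPPORTED = frozenset({"mps", "npu", "neuron"})
DTYPE_DOWNCAST = {"float64": "float32", "int64": "int32"}


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for torch.device: only .type is modelled."""

    type: str


def maybe_adjust_dtype(dtype: str, device: FakeDevice) -> str:
    """Downcast 64-bit dtypes on devices that lack them; otherwise pass through."""
    if dtype == "float64" and device.type in FP64_UNSUPPORTED:
        return DTYPE_DOWNCAST[dtype]
    if dtype == "int64" and device.type in INT64_UNSUPPORTED:
        return DTYPE_DOWNCAST[dtype]
    return dtype


if __name__ == "__main__":
    for name in ("cuda", "mps", "npu", "neuron"):
        dev = FakeDevice(name)
        print(name, maybe_adjust_dtype("float64", dev), maybe_adjust_dtype("int64", dev))
```

In this sketch, passing the bare string "neuron" instead of a device object raises AttributeError on the .type lookup, which is the failure mode flagged in the note above.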

2. Suggested agent doc update for maybe_adjust_dtype_for_device

models.md gotcha #5 currently documents the old inline pattern. It should be updated to reference the new utility. Suggested change:

In .ai/models.md, replace the current gotcha #5 fallback pattern:

# Old pattern (no longer preferred):
is_mps = hidden_states.device.type == "mps"
is_npu = hidden_states.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64

with:

from ...utils.torch_utils import maybe_adjust_dtype_for_device

freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device)

The full gotcha #5 text should remain (explaining why float64 is problematic, the "default: just use float32" guidance, and when to use the device-gated pattern), but the code example and the reference models should be updated. The reference list (transformer_flux.py, transformer_flux2.py, etc.) is still valid since those files now use the new pattern.

3. set_begin_index(0) in SDXL pipeline — code trace

Verdict: Yes, set_begin_index(0) can be safely hardcoded at 0 for all scenarios, including denoising_end.

Here's the full trace:

How set_begin_index works in schedulers:

set_begin_index(idx) stores idx in _begin_index. This is used by _init_step_index() (called on the first scheduler.step()) to set the initial step index:

# e.g. scheduling_dpmsolver_multistep.py, scheduling_euler_discrete.py
def _init_step_index(self, timestep):
    if self.begin_index is None:
        self._step_index = self.index_for_timestep(timestep)
    else:
        self._step_index = self._begin_index   # ← uses begin_index directly

Why 0 is correct for the base SDXL pipeline, even with denoising_end:

  1. denoising_end truncates the timesteps list, not the index. Lines 1170-1183:

    discrete_timestep_cutoff = int(round(
        self.scheduler.config.num_train_timesteps
        - (self.denoising_end * self.scheduler.config.num_train_timesteps)
    ))
    num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
    timesteps = timesteps[:num_inference_steps]  # ← truncation happens here

    The base pipeline always starts from the beginning of the timestep schedule and stops early. It never starts from the middle. So index 0 is correct.

  2. begin_index only matters for add_noise(), which the base pipeline doesn't call during the denoising loop. The denoising loop only calls scheduler.step(), which uses _init_step_index() on the first call. The begin_index is primarily relevant for img2img pipelines where add_noise() is called with a partially-denoised input.

  3. The img2img pipeline correctly uses non-zero set_begin_index. In pipeline_stable_diffusion_xl_img2img.py, get_timesteps() computes t_start from strength and calls set_begin_index(t_start * self.scheduler.order). This is correct because img2img starts denoising from the middle of the schedule.

  4. Custom timesteps are also fine. When custom timesteps are passed, retrieve_timesteps() calls scheduler.set_timesteps(timesteps=...) which resets _begin_index to None. Then set_begin_index(0) sets it back to 0, which is correct since the denoising loop always iterates from the start of whatever timesteps list was provided.

In short: The base pipeline's set_begin_index(0) is correct by construction — it always iterates from timesteps[0], whether the list is full or truncated by denoising_end.
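
The trace can be replayed with a toy model. ToyScheduler below is a hypothetical, heavily simplified stand-in that only mimics the begin_index and step_index bookkeeping described above; it is not a diffusers scheduler, and the 5-step schedule is made up for illustration.

```python
class ToyScheduler:
    """Toy model of begin_index / step_index bookkeeping; not a diffusers class."""

    def __init__(self, num_train_timesteps=1000):
        self.num_train_timesteps = num_train_timesteps
        self.timesteps = []
        self._begin_index = None
        self._step_index = None

    def set_timesteps(self, timesteps):
        self.timesteps = list(timesteps)
        self._begin_index = None  # reset, as described for set_timesteps()
        self._step_index = None

    def set_begin_index(self, begin_index=0):
        self._begin_index = begin_index

    def index_for_timestep(self, timestep):
        return self.timesteps.index(timestep)

    def _init_step_index(self, timestep):
        if self._begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(self, timestep):
        """Return the schedule index used for this step, then advance."""
        if self._step_index is None:
            self._init_step_index(timestep)
        used = self._step_index
        self._step_index += 1
        return used


def truncate_for_denoising_end(timesteps, denoising_end, num_train_timesteps):
    """Same cutoff arithmetic as the pipeline excerpt quoted above."""
    cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))
    kept = len([ts for ts in timesteps if ts >= cutoff])
    return timesteps[:kept]


scheduler = ToyScheduler()
scheduler.set_timesteps([999, 799, 599, 399, 199])  # hypothetical schedule
timesteps = truncate_for_denoising_end(scheduler.timesteps, 0.8, scheduler.num_train_timesteps)
scheduler.set_begin_index(0)
visited = [scheduler.step(t) for t in timesteps]

if __name__ == "__main__":
    print("timesteps:", timesteps, "indices visited:", visited)
```

Because truncation drops entries from the tail, the first timestep keeps index 0 in the full schedule, so the hardcoded begin index and the index_for_timestep lookup agree in this model.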


General review

Most issues from the previous Serge review appear to have been addressed (dead atol, env var leakage, BACKEND_SUPPORTS_TRAINING missing in tests, pixart device check). A few items remain:

Correctness issues

1. SDXL timestep workaround is pipeline-level, will need replication (pipeline_stable_diffusion_xl.py:1211)

t_unet = t.to(torch.float32).to(device) if is_neuron_device else t

This casts the scheduler's timestep to float32 before passing to the UNet, because Neuron doesn't support int64. The issue is that the UNet's get_time_embed method (unet_2d_condition.py:854-860) only adjusts dtype for non-tensor inputs — if t is already a tensor (which it is from the scheduler), the UNet passes it through unchanged.

This means every pipeline using a UNet on Neuron will need this same workaround. Consider whether the dtype adjustment should be moved into the UNet's get_time_embed instead:

def get_time_embed(self, sample, timestep):
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        dtype = maybe_adjust_dtype_for_device(...)
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+   else:
+       timesteps = timesteps.to(maybe_adjust_dtype_for_device(timesteps.dtype, sample.device))

This would make the neuron dtype handling automatic for all pipelines. Not blocking, but worth considering for the "next steps" list.

2. Minor inconsistency in neuron device detection (pipeline_stable_diffusion_xl.py:1098 vs pipeline_pixart_alpha.py:865)

SDXL:

is_neuron_device = hasattr(device, "type") and device.type == "neuron"

Pixart:

is_neuron_device = device.type == "neuron"

Since _execution_device always returns a torch.device, the hasattr guard is unnecessary. Both work, but the inconsistency could confuse future contributors. Minor — not blocking.

Remaining old-style pattern (pre-existing, not introduced by this PR)

src/diffusers/models/embeddings.py:1168-1170 in get_1d_rotary_pos_embed still has:

is_npu = freqs.device.type == "npu"
if is_npu:
    freqs = freqs.float()

This is a pre-existing old-style check that wasn't migrated. It converts freqs to float32 on NPU before .cos()/.sin(). Since the caller already passes freqs_dtype through maybe_adjust_dtype_for_device, the freqs tensor would already be float32 on NPU, making this a no-op. However, on Neuron, this check is missing — if freqs were float64 (it won't be since callers pass freqs_dtype through the utility), the .cos() would fail.

Not blocking since the callers handle it, but this old-style check could be cleaned up or extended to also cover neuron for completeness.

Tests

The integration tests (test_pipeline_flux2_klein.py, test_stable_diffusion_xl.py) look reasonable:

  • Properly gated with @require_torch_neuron
  • Environment variables are saved/restored in setUp/tearDown
  • Expected slices are compared with 5e-2 tolerance

The test_pixart.py changes appropriately relax tolerance for Neuron (1e-2 vs 1e-4) with a clear comment explaining why (bfloat16 precision).
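
For reference, the save/restore pattern for environment variables in setUp/tearDown can look roughly like the sketch below. The test class, the path, and the run_suite helper are hypothetical; only the variable name comes from the PR description.

```python
import os
import unittest

ENV_VAR = "TORCH_NEURONX_NEFF_CACHE_DIR"  # name taken from the PR description


class EnvIsolatedTest(unittest.TestCase):
    """Hypothetical test showing how to avoid leaking env vars between tests."""

    def setUp(self):
        # Remember the previous value (None means the variable was unset).
        self._saved = os.environ.get(ENV_VAR)
        os.environ[ENV_VAR] = "/tmp/neff-cache-for-test"  # hypothetical path

    def tearDown(self):
        # Restore the exact prior state, including "unset".
        if self._saved is None:
            os.environ.pop(ENV_VAR, None)
        else:
            os.environ[ENV_VAR] = self._saved

    def test_sees_override(self):
        self.assertEqual(os.environ[ENV_VAR], "/tmp/neff-cache-for-test")


def run_suite():
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(EnvIsolatedTest)
    return unittest.TextTestRunner(verbosity=0).run(suite)


if __name__ == "__main__":
    run_suite()
```

Distinguishing "unset" from "set to a value" in tearDown is what prevents the leakage issue mentioned in the earlier review round.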


Summary

The core plumbing is solid. maybe_adjust_dtype_for_device is a clean centralization that's backward-compatible. The set_begin_index(0) hardcoding in SDXL is correct. The main suggestion is to consider moving the timestep dtype workaround from the pipeline into the UNet to avoid per-pipeline replication, and to update .ai/models.md gotcha #5 to reference the new utility.


# 4. Prepare timesteps
if XLA_AVAILABLE:
is_neuron_device = hasattr(device, "type") and device.type == "neuron"
Collaborator

Suggested change
- is_neuron_device = hasattr(device, "type") and device.type == "neuron"
+ is_neuron_device = device.type == "neuron"

think it's ok, no?

is_neuron = latent_model_input.device.type == "neuron"
if isinstance(current_timestep, float):
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64
Collaborator

why not use the maybe_adjust_dtype_for_device here too?

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks so much for working on this!
Can you also update our agent.md on the new helper function you added here? https://github.com/huggingface/diffusers/blob/main/.ai/models.md (Claude CI review #2 #13289 (comment))


5 participants