[Neuron] Add AWS Neuron (Trainium/Inferentia) as an officially supported device#13289
[Neuron] Add AWS Neuron (Trainium/Inferentia) as an officially supported device#13289JingyaHuang wants to merge 44 commits into
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
… into add-neuron-backend
… into add-neuron-backend
|
Thanks @sayakpaul, I addressed the review comments, shall I merge? |
yiyixuxu
left a comment
There was a problem hiding this comment.
thanks! i left some questions/comments
|
|
||
| # 4. Prepare timesteps | ||
| if XLA_AVAILABLE: | ||
| is_neuron_device = hasattr(device, "type") and device.type == "neuron" |
There was a problem hiding this comment.
@claude-2-serge can you trace the code for how device is defined here and let me know if there is any code path that _exeuction_devide would not return torch.device type?
There was a problem hiding this comment.
Follow-up inline comments are something we're working with @tarekziade. Should land pretty soon.
… into add-neuron-backend
Co-authored-by: YiYi Xu <yixu310@gmail.com>
JingyaHuang
left a comment
There was a problem hiding this comment.
addressed the comments! @sayakpaul @yiyixuxu
… into add-neuron-backend
|
The failing CI checks are unrelated to this PR, ok for us to merge? @sayakpaul @yiyixuxu |
|
@askserge can you do a review here? in addition to review-rules, also these 3 commetns from my end
|
|
@claude can you do a review here? in addition to review-rules, also these 3 commetns from my end
|
|
Claude finished @yiyixuxu's task in 6m 53s —— View job Review of [Neuron] Add AWS Neuron backend
1.
|
| Device | Old behavior | New behavior | Status |
|---|---|---|---|
| MPS | float64→float32, int64→int32 | Same | No change |
| NPU | float64→float32, int64→int32 (in unets/controlnets) | Same | No change |
| NPU | float64→float64 (in get_timestep_embedding, embeddings.py:349) |
float64→float32 | Bug fix — old code only checked mps, missed npu |
| Neuron | No handling | float64→float32, int64→int32 | New (desired) |
| CUDA/CPU/XPU | float64/int64 unchanged | Same | No change |
The get_timestep_embedding fix for NPU is a positive side-effect: models.md gotcha #5 documents that NPU does not support float64, so the old code was buggy.
One thing to note: maybe_adjust_dtype_for_device expects device to have a .type attribute (i.e., a torch.device object). All current callers pass tensor .device properties, so this is fine. But if any future caller passes a string, it would break. The type annotation makes this clear.
2. Suggested agent doc update for maybe_adjust_dtype_for_device
models.md gotcha #5 currently documents the old inline pattern. It should be updated to reference the new utility. Suggested change:
In .ai/models.md, replace the current gotcha #5 fallback pattern:
# Old pattern (no longer preferred):
is_mps = hidden_states.device.type == "mps"
is_npu = hidden_states.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64with:
from ...utils.torch_utils import maybe_adjust_dtype_for_device
freqs_dtype = maybe_adjust_dtype_for_device(torch.float64, hidden_states.device)The full gotcha #5 text should remain (explaining why float64 is problematic, the "default: just use float32" guidance, and when to use the device-gated pattern), but the code example and the reference models should be updated. The reference list (transformer_flux.py, transformer_flux2.py, etc.) is still valid since those files now use the new pattern.
3. set_begin_index(0) in SDXL pipeline — code trace
Verdict: Yes, set_begin_index(0) can be safely hardcoded at 0 for all scenarios, including denoising_end.
Here's the full trace:
How set_begin_index works in schedulers:
set_begin_index(idx) stores idx in _begin_index. This is used by _init_step_index() (called on the first scheduler.step()) to set the initial step index:
# e.g. scheduling_dpmsolver_multistep.py, scheduling_euler_discrete.py
def _init_step_index(self, timestep):
if self.begin_index is None:
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index # ← uses begin_index directlyWhy 0 is correct for the base SDXL pipeline, even with denoising_end:
-
denoising_endtruncates the timesteps list, not the index. Lines 1170-1183:discrete_timestep_cutoff = int(round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) )) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # ← truncation happens here
The base pipeline always starts from the beginning of the timestep schedule and stops early. It never starts from the middle. So index 0 is correct.
-
begin_indexonly matters foradd_noise(), which the base pipeline doesn't call during the denoising loop. The denoising loop only callsscheduler.step(), which uses_init_step_index()on the first call. Thebegin_indexis primarily relevant for img2img pipelines whereadd_noise()is called with a partially-denoised input. -
The img2img pipeline correctly uses non-zero
set_begin_index. Inpipeline_stable_diffusion_xl_img2img.py,get_timesteps()computest_startfromstrengthand callsset_begin_index(t_start * self.scheduler.order). This is correct because img2img starts denoising from the middle of the schedule. -
Custom
timestepsare also fine. When custom timesteps are passed,retrieve_timesteps()callsscheduler.set_timesteps(timesteps=...)which resets_begin_indextoNone. Thenset_begin_index(0)sets it back to 0, which is correct since the denoising loop always iterates from the start of whatever timesteps list was provided.
In short: The base pipeline's set_begin_index(0) is correct by construction — it always iterates from timesteps[0], whether the list is full or truncated by denoising_end.
General review
Most issues from the previous Serge review appear to have been addressed (dead atol, env var leakage, BACKEND_SUPPORTS_TRAINING missing in tests, pixart device check). A few items remain:
Correctness issues
1. SDXL timestep workaround is pipeline-level, will need replication (pipeline_stable_diffusion_xl.py:1211)
t_unet = t.to(torch.float32).to(device) if is_neuron_device else tThis casts the scheduler's timestep to float32 before passing to the UNet, because Neuron doesn't support int64. The issue is that the UNet's get_time_embed method (unet_2d_condition.py:854-860) only adjusts dtype for non-tensor inputs — if t is already a tensor (which it is from the scheduler), the UNet passes it through unchanged.
This means every pipeline using a UNet on Neuron will need this same workaround. Consider whether the dtype adjustment should be moved into the UNet's get_time_embed instead:
def get_time_embed(self, sample, timestep):
timesteps = timestep
if not torch.is_tensor(timesteps):
dtype = maybe_adjust_dtype_for_device(...)
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ else:
+ timesteps = timesteps.to(maybe_adjust_dtype_for_device(timesteps.dtype, sample.device))This would make the neuron dtype handling automatic for all pipelines. Not blocking, but worth considering for the "next steps" list.
2. Minor inconsistency in neuron device detection (pipeline_stable_diffusion_xl.py:1098 vs pipeline_pixart_alpha.py:865)
SDXL:
is_neuron_device = hasattr(device, "type") and device.type == "neuron"Pixart:
is_neuron_device = device.type == "neuron"Since _execution_device always returns a torch.device, the hasattr guard is unnecessary. Both work, but the inconsistency could confuse future contributors. Minor — not blocking.
Remaining old-style pattern (pre-existing, not introduced by this PR)
src/diffusers/models/embeddings.py:1168-1170 in get_1d_rotary_pos_embed still has:
is_npu = freqs.device.type == "npu"
if is_npu:
freqs = freqs.float()This is a pre-existing old-style check that wasn't migrated. It converts freqs to float32 on NPU before .cos()/.sin(). Since the caller already passes freqs_dtype through maybe_adjust_dtype_for_device, the freqs tensor would already be float32 on NPU, making this a no-op. However, on Neuron, this check is missing — if freqs were float64 (it won't be since callers pass freqs_dtype through the utility), the .cos() would fail.
Not blocking since the callers handle it, but this old-style check could be cleaned up or extended to also cover neuron for completeness.
Tests
The integration tests (test_pipeline_flux2_klein.py, test_stable_diffusion_xl.py) look reasonable:
- Properly gated with
@require_torch_neuron - Environment variables are saved/restored in
setUp/tearDown - Expected slices are compared with
5e-2tolerance
The test_pixart.py changes appropriately relax tolerance for Neuron (1e-2 vs 1e-4) with a clear comment explaining why (bfloat16 precision).
Summary
The core plumbing is solid. maybe_adjust_dtype_for_device is a clean centralization that's backward-compatible. The set_begin_index(0) hardcoding in SDXL is correct. The main suggestion is to consider moving the timestep dtype workaround from the pipeline into the UNet to avoid per-pipeline replication, and to update .ai/models.md gotcha #5 to reference the new utility.
|
|
||
| # 4. Prepare timesteps | ||
| if XLA_AVAILABLE: | ||
| is_neuron_device = hasattr(device, "type") and device.type == "neuron" |
There was a problem hiding this comment.
| is_neuron_device = hasattr(device, "type") and device.type == "neuron" | |
| is_neuron_device = device.type == "neuron" |
think it's ok, no?
| is_neuron = latent_model_input.device.type == "neuron" | ||
| if isinstance(current_timestep, float): | ||
| dtype = torch.float32 if (is_mps or is_npu) else torch.float64 | ||
| dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 |
There was a problem hiding this comment.
why not use the maybe_adjust_dtype_for_device here too?
There was a problem hiding this comment.
thanks so much for working on this!
Can you also update our agent.md on the new helper function you added here? https://github.com/huggingface/diffusers/blob/main/.ai/models.md (Claude CI review #2 #13289 (comment))
What does this PR do?
This PR adds AWS Neuron (Trainium/Inferentia) as an officially supported compute backend in Diffusers, on par with existing backends like CUDA, MPS, XPU, and MLU.
Changes
Usage
Validation
So far we validated the following models, the idea is to manually validated a part of representative models, and ensure the model coverage with pur CIs afterward.
Next Steps
torch.compileon Neuron device