Guard scatter_axis against 64-bit outputs on the GPU #3695

Open: obchain wants to merge 2 commits into ml-explore:main from obchain:fix/scatter-axis-64bit-guard

obchain (Contributor) commented Jun 15, 2026
Proposed changes

Fixes #3690.

mx.put_along_axis / scatter_add_axis with a 64-bit element dtype (int64/uint64) fail the Metal library JIT build instead of raising a clean "unsupported dtype" error. In mlx/backend/metal/kernels/atomic.h, packing_size<T> = sizeof(uint)/sizeof(T) is 0 for 8-byte T, so uint_or_packed<T> declares a zero-length array (hard C++ error) and offset / packing_size<T> divides by zero. The whole mlx-metallib build then fails.
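The arithmetic behind the failure can be checked with a small sketch. It only mirrors the two expressions quoted above (sizeof(uint)/sizeof(T) and offset / packing_size<T>) using plain Python integers; the function names and the UINT_BYTES constant are illustrative and are not taken from atomic.h.

```python
UINT_BYTES = 4  # assumption: sizeof(uint) is 4 bytes, as on Metal


def packing_size(elem_bytes: int) -> int:
    """Integer division, mirroring sizeof(uint) / sizeof(T)."""
    return UINT_BYTES // elem_bytes


def packed_word_index(offset: int, elem_bytes: int) -> int:
    """Index of the uint word holding element `offset`,
    mirroring offset / packing_size<T>."""
    return offset // packing_size(elem_bytes)


# Element widths in bytes for a few dtypes.
for name, width in [("uint8", 1), ("int16", 2), ("int32", 4), ("int64", 8)]:
    print(name, packing_size(width))
```

For an 8-byte element the packing size is 0, so packed_word_index(5, 8) raises ZeroDivisionError. In the Metal source the same zero appears at compile time, as an array length and as a divisor, which is why the library build fails rather than a single call.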

The plain Scatter path already guards 8-byte outputs on the GPU — in scatter() (ops.cpp) and Scatter::eval_gpu (indexing.cpp) — but ScatterAxis was missing the equivalent guard. This adds it in both places, mirroring Scatter:

  • scatter_axis() in mlx/ops.cpp — raises on GPU for 8-byte dtypes, matching the existing scatter() guard. CPU is unaffected.
  • ScatterAxis::eval_gpu in mlx/backend/metal/indexing.cpp — same guard as Scatter::eval_gpu.

Repro before the fix (Metal device):

import mlx.core as mx
mx.set_default_device(mx.gpu)
x   = mx.zeros((4, 8), dtype=mx.int64)
idx = mx.array([[0],[1],[2],[3]])
upd = mx.ones((4, 1), dtype=mx.int64)
mx.eval(mx.put_along_axis(x, idx, upd, axis=1))  # was: Metal JIT build failure

After the fix this raises a clean ValueError on the GPU, and continues to work on the CPU.
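Since the CPU path keeps working, a caller that needs 64-bit outputs could route just those calls to the CPU. The following is a minimal sketch of that idea, not part of this PR: needs_cpu_fallback and put_along_axis_with_fallback are hypothetical helper names, and the sketch assumes the default device is what decides where the op runs.

```python
def needs_cpu_fallback(itemsize: int, on_metal_gpu: bool) -> bool:
    """True when the output element is 8 bytes wide and the op would run on Metal."""
    return on_metal_gpu and itemsize == 8


def put_along_axis_with_fallback(x, idx, upd, axis):
    """Hypothetical wrapper: run 64-bit scatters on the CPU stream, others as usual."""
    import mlx.core as mx  # imported lazily so the predicate above has no dependency

    on_metal_gpu = mx.metal.is_available() and mx.default_device() == mx.gpu
    # stream=None keeps the default stream; mx.cpu forces the CPU path.
    stream = mx.cpu if needs_cpu_fallback(x.itemsize, on_metal_gpu) else None
    return mx.put_along_axis(x, idx, upd, axis=axis, stream=stream)
```

With the arrays from the repro, put_along_axis_with_fallback(x, idx, upd, axis=1) would take the CPU branch on a Metal machine and the default branch elsewhere.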

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

put_along_axis / scatter_add_axis with int64/uint64 values failed the
Metal library JIT build instead of raising a clean error: packing_size<T>
in atomic.h is sizeof(uint)/sizeof(T) == 0 for 8-byte T, producing a
zero-length array and a divide-by-zero in the fallback atomic union.

The plain Scatter path already guards this in scatter() and
Scatter::eval_gpu; ScatterAxis had no equivalent guard. Mirror it in
scatter_axis() (GPU only, matching scatter()) and ScatterAxis::eval_gpu,
and add a test.

Comment thread on mlx/ops.cpp (outdated):
}

// TODO, remove when scatter_axis supports 64-bit outputs
if (to_stream(s).device == Device::gpu && size_of(a.dtype()) == 8) {

Collaborator

The check in metal/indexing.cpp alone is enough, and this actually works in the CUDA backend. The test should also be updated to be Metal-only.

The early throw in scatter_axis blocked 64-bit outputs on every GPU
stream, but only the Metal backend is missing support; CUDA scatters
them fine. Remove the op-level check and rely on the guard in
ScatterAxis::eval_gpu so non-Metal backends are unaffected, and scope
the test to the Metal backend while still exercising the CPU path.

obchain (Contributor, Author) commented Jun 22, 2026

Done — dropped the op-level check in scatter_axis and kept the guard in ScatterAxis::eval_gpu so only the Metal backend rejects 64-bit outputs and CUDA is unaffected. Scoped the test to mx.metal.is_available() (asserting on the gpu stream) while still exercising the CPU path.
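A test scoped this way might look like the sketch below. It is an illustration of the approach described in the comment, not the test added in this PR: the class and method names are made up, the import is guarded so the file still loads where MLX is not installed, and the ValueError expectation follows the PR description.

```python
import unittest

try:
    import mlx.core as mx
except ImportError:  # lets the sketch load on machines without MLX
    mx = None


@unittest.skipIf(mx is None, "mlx is not installed")
class TestScatterAxis64Bit(unittest.TestCase):
    def test_put_along_axis_int64(self):
        x = mx.zeros((4, 8), dtype=mx.int64)
        idx = mx.array([[0], [1], [2], [3]])
        upd = mx.ones((4, 1), dtype=mx.int64)

        # The CPU path must keep working for 64-bit outputs.
        out = mx.put_along_axis(x, idx, upd, axis=1, stream=mx.cpu)
        rows = out.tolist()
        self.assertEqual(rows[1][1], 1)
        self.assertEqual(sum(sum(r) for r in rows), 4)

        # Only the Metal backend is expected to reject 64-bit outputs.
        if mx.metal.is_available():
            with self.assertRaises(ValueError):
                mx.eval(mx.put_along_axis(x, idx, upd, axis=1, stream=mx.gpu))
```

Run it with python -m unittest. On a CUDA or CPU-only build the Metal branch is skipped, so the backend-level guard is not asserted there.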

zcbenz (Collaborator) left a comment

I'm good with this since it does the exact same check as Scatter::eval_gpu.


Development

Successfully merging this pull request may close these issues.

[BUG] int64/uint64 put_along_axis / scatter_add_axis crashes the Metal JIT build