fix(cuda/quantized): chunk long-prompt qmm launches to avoid gridDim>65535 and int overflow #652

Open
inureyes wants to merge 1 commit into main from fix/648-cuda-longprompt-grid-limit

@inureyes inureyes commented Jul 4, 2026

Member

Summary

Long-prompt prefill of 4-bit models aborted on the CUDA backend with an invalid launch (cuGraphAddKernelNode / cuLaunchKernelEx "invalid argument", SIGABRT). This surfaced as FAIL:bench cells in the long-prompt sweep (epic #623/#624) and as crashes in mlxcel generate on long prompts. There are two distinct causes, both in the quantized-matmul path:

  • MoE GatherQMM: qmm_sm80 sets grid.z = l = tokens * num_experts_per_tok, which exceeds the CUDA gridDim.z limit of 65535 once tokens*top_k >= 65536 (qwen3-30b-a3b @8192, mixtral-8x7b @32768).
  • Dense LM head: l = out.size() / (m * n) is computed with int32 m,n, which overflows when tokens*vocab >= 2^31 (about 16744 tokens for a 128256 vocab), yielding l = 0 and an invalid grid.z of 0.
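The arithmetic behind both failures can be reproduced on the host without a GPU. The following is a minimal sketch, not the code from quantized.cpp: the helper names (`batch_count_32bit_product`, `gather_grid_z`, `fits_grid_z`) are hypothetical, the 32-bit wraparound is emulated with unsigned arithmetic to avoid undefined behavior, and a 64-bit `size_t` is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// CUDA limits gridDim.y and gridDim.z to 65535.
constexpr std::int64_t kMaxGridYZ = 65535;

// Hypothetical model of the old computation l = out.size() / (m * n), where the
// product m * n is held in 32 bits. The wraparound is emulated through uint32_t;
// once the product reaches 2^31 it reads back as a negative int32, which turns
// into a huge divisor after conversion to size_t.
inline std::size_t batch_count_32bit_product(std::size_t out_size,
                                             std::int32_t m,
                                             std::int32_t n) {
  const std::uint32_t wrapped =
      static_cast<std::uint32_t>(m) * static_cast<std::uint32_t>(n);
  const std::int32_t mn = static_cast<std::int32_t>(wrapped);
  assert(mn != 0);  // a product that wraps to exactly 0 is outside this sketch
  return out_size / static_cast<std::size_t>(mn);
}

// grid.z requested by the gathered MoE matmul: one slice per (token, expert) pair.
inline std::int64_t gather_grid_z(std::int64_t tokens, std::int64_t top_k) {
  return tokens * top_k;
}

// A launch is only valid when the dimension lies in [1, 65535].
inline bool fits_grid_z(std::int64_t z) {
  return z >= 1 && z <= kMaxGridYZ;
}
```

With a 128256-entry vocabulary, 16743 tokens still give a batch count of 1, while 16744 tokens give 0, which is not a valid grid.z. For the MoE case, 8192 tokens with top_k = 8 or 32768 tokens with top_k = 2 both request 65536 slices, one past the limit (the top_k values are assumptions consistent with the thresholds quoted above).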

Fix

All in the mlxcel-owned overlay src/lib/mlx-cpp/patches/mlx/backend/cuda/quantized/quantized.cpp:

  • run_batch_chunked splits GatherQMM's gathered batch into <= 65535 slices (sub-array views over the same buffer).
  • run_row_chunked splits QuantizedMatmul's M into slices of min(65535, INT32_MAX/N) so each launch keeps grid.y <= 65535 and m*n < 2^31. Gated on single_batch = out.size() == size_t(M)*N (the LM-head output is 3D [1, tokens, vocab], so an ndim==2 gate would miss it).
  • B is now computed as int B = out.size() / M / N (sequential 64-bit division), so the overflow can no longer produce B=0 and misroute dispatch.
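The size computations in this list can be sketched as plain host-side helpers. This is an illustration under assumptions, not the overlay's actual code: the function names are hypothetical, `size_t` is assumed to be 64 bits wide, and M and N are assumed positive.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

// CUDA limits gridDim.y and gridDim.z to 65535.
constexpr std::int64_t kMaxGridDim = 65535;

// Rows per launch: min(65535, INT32_MAX / N). The first bound keeps grid.y in
// range; the second keeps rows * N below 2^31 so 32-bit indexing stays valid.
inline std::int64_t row_chunk_size(std::int64_t n) {
  assert(n > 0);
  const std::int64_t by_index = std::numeric_limits<std::int32_t>::max() / n;
  return std::min(kMaxGridDim, by_index);
}

// Gate for the row-chunked path: the output holds exactly one M x N matrix,
// regardless of how many leading size-1 dimensions it has.
inline bool is_single_batch(std::size_t out_size, int m, int n) {
  return out_size == static_cast<std::size_t>(m) * static_cast<std::size_t>(n);
}

// Batch count via sequential 64-bit division; no 32-bit product is ever formed.
inline int batch_count(std::size_t out_size, int m, int n) {
  assert(m > 0 && n > 0);
  return static_cast<int>(out_size / static_cast<std::size_t>(m) /
                          static_cast<std::size_t>(n));
}
```

For N = 128256 the row slice comes out at 16743, so a 32768-token LM-head matmul would need two launches under this scheme, while a small N such as 4096 is bounded by the 65535 grid limit instead.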

Both helpers early-return to the original single call when no split is needed, so normal-size matmuls are byte-identical. The change is CUDA-only; the Metal backend already uses out.size()/M/N and has no 65535 grid cap, so it is unaffected.
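The early-return behavior can be illustrated with a small host-side loop. This is only a sketch of the general pattern: `LaunchFn`, `for_each_chunk`, and `count_launches` are hypothetical names, and the callback stands in for whatever the real helpers do to launch a kernel over a sub-array view.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>

// Hypothetical stand-in for one kernel launch over the half-open range [begin, end).
using LaunchFn = std::function<void(std::int64_t begin, std::int64_t end)>;

// Issues a single launch when `total` fits within `max_per_launch`, so the
// common case follows the original path; otherwise issues one launch per slice.
inline void for_each_chunk(std::int64_t total,
                           std::int64_t max_per_launch,
                           const LaunchFn& launch) {
  assert(total > 0 && max_per_launch > 0);
  if (total <= max_per_launch) {
    launch(0, total);  // no split needed
    return;
  }
  for (std::int64_t begin = 0; begin < total; begin += max_per_launch) {
    launch(begin, std::min(begin + max_per_launch, total));
  }
}

// Counts the launches issued and checks that the slices are contiguous and
// cover [0, total) exactly once.
inline int count_launches(std::int64_t total, std::int64_t max_per_launch) {
  int launches = 0;
  std::int64_t covered = 0;
  for_each_chunk(total, max_per_launch,
                 [&](std::int64_t begin, std::int64_t end) {
                   assert(begin == covered);
                   covered = end;
                   ++launches;
                 });
  assert(covered == total);
  return launches;
}
```

A 4096-element dimension produces one launch, 65536 gathered slices produce two (65535 plus 1), and 32768 rows with a 16743-row slice also produce two.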

Validation (GB10, NVIDIA sm_121)

Previously failing long-prompt cells now pass, exiting with code 0 and reporting real throughput:

  • Bench (what scripts/bench_longprompt.sh runs): llama-3.1-8b-4bit @17000 / @32768, qwen3-30b-a3b-4bit @8192.
  • mlxcel generate: llama-3.1-8b-4bit @23491 and qwen3-30b-a3b-4bit @~11k complete with coherent output, confirming the sub-array offsets are correct.

Closes #648

…65535 and int overflow

Long-prompt prefill of 4-bit models aborted with an invalid CUDA launch
(cuGraphAddKernelNode / cuLaunchKernelEx "invalid argument", SIGABRT). Two
distinct causes in the quantized-matmul path (#648):

- MoE GatherQMM sets qmm_sm80 grid.z = l = tokens * num_experts_per_tok,
  which exceeds the CUDA gridDim.z limit of 65535 once tokens*top_k >= 65536
  (qwen3-30b-a3b @8192, mixtral-8x7b @32768).
- The dense LM head computes l = out.size() / (m * n) with int32 m,n, which
  overflows when tokens*vocab >= 2^31 (about 16744 tokens for a 128256 vocab),
  yielding l = 0 and an invalid grid.z of 0.

Fix, all in the mlxcel-owned overlay quantized.cpp:
- run_batch_chunked splits GatherQMM's gathered batch into <=65535 slices.
- run_row_chunked splits QuantizedMatmul's M into slices of
  min(65535, INT32_MAX/N) so each launch keeps grid.y <= 65535 and m*n < 2^31.
- int B = out.size() / M / N (sequential 64-bit division) so the overflow no
  longer yields B=0 and misroutes dispatch.
Both early-return to the original single call when no split is needed, so
normal-size matmuls are byte-identical.

Validated on GB10: previously FAIL:bench long-prompt cells now pass
(llama-3.1-8b-4bit @17000/@23491/@32768, qwen3-30b-a3b-4bit @8192), with
coherent generation confirming the sub-array offsets are correct. CUDA-only;
the Metal backend uses out.size()/M/N and is unaffected.

Refs #648
@inureyes inureyes added area:benchmark Benchmark harness and performance measurement (bench_*.sh, /update-benchmarks) area:inference Generation, sampling, decoding (incl. speculative, DRY) priority:high High priority type:bug Bug fixes, error corrections, or issue resolutions labels Jul 4, 2026

Development

Successfully merging this pull request may close these issues.

CUDA: long-prompt prefill aborts (exit 134, invalid launch config) when a kernel grid dim exceeds 65535