fix(cuda/quantized): chunk long-prompt qmm launches to avoid gridDim>65535 and int overflow #652
Open
inureyes wants to merge 1 commit into
…65535 and int overflow

Long-prompt prefill of 4-bit models aborted with an invalid CUDA launch (cuGraphAddKernelNode / cuLaunchKernelEx "invalid argument", SIGABRT). Two distinct causes in the quantized-matmul path (#648):

- MoE GatherQMM sets qmm_sm80 grid.z = l = tokens * num_experts_per_tok, which exceeds the CUDA gridDim.z limit of 65535 once tokens*top_k >= 65536 (qwen3-30b-a3b @8192, mixtral-8x7b @32768).
- The dense LM head computes l = out.size() / (m * n) with int32 m,n, which overflows when tokens*vocab >= 2^31 (about 16744 tokens for a 128256 vocab), yielding l = 0 and an invalid grid.z of 0.

Fix, all in the mlxcel-owned overlay quantized.cpp:

- run_batch_chunked splits GatherQMM's gathered batch into <=65535 slices.
- run_row_chunked splits QuantizedMatmul's M into slices of min(65535, INT32_MAX/N) so each launch keeps grid.y <= 65535 and m*n < 2^31.
- int B = out.size() / M / N (sequential 64-bit division) so the overflow no longer yields B=0 and misroutes dispatch.

Both early-return to the original single call when no split is needed, so normal-size matmuls are byte-identical.

Validated on GB10: previously FAIL:bench long-prompt cells now pass (llama-3.1-8b-4bit @17000/@23491/@32768, qwen3-30b-a3b-4bit @8192), with coherent generation confirming the sub-array offsets are correct.

CUDA-only; the Metal backend uses out.size()/M/N and is unaffected.

Refs #648
Summary
Long-prompt prefill of 4-bit models aborted on the CUDA backend with an invalid launch (`cuGraphAddKernelNode` / `cuLaunchKernelEx` "invalid argument", SIGABRT). This surfaced as `FAIL:bench` cells in the long-prompt sweep (epic #623/#624) and as crashes in `mlxcel generate` on long prompts. Two distinct causes in the quantized-matmul path:

- MoE GatherQMM: `qmm_sm80` sets `grid.z = l = tokens * num_experts_per_tok`, which exceeds the CUDA `gridDim.z` limit of 65535 once `tokens*top_k >= 65536` (qwen3-30b-a3b @8192, mixtral-8x7b @32768).
- Dense LM head: `l = out.size() / (m * n)` is computed with `int32` `m`, `n`, which overflows when `tokens*vocab >= 2^31` (about 16744 tokens for a 128256 vocab), yielding `l = 0` and an invalid `grid.z` of 0.

Fix
All in the mlxcel-owned overlay `src/lib/mlx-cpp/patches/mlx/backend/cuda/quantized/quantized.cpp`:

- `run_batch_chunked` splits GatherQMM's gathered batch into `<= 65535` slices (sub-array views over the same buffer).
- `run_row_chunked` splits QuantizedMatmul's `M` into slices of `min(65535, INT32_MAX/N)` so each launch keeps `grid.y <= 65535` and `m*n < 2^31`. Gated on `single_batch = out.size() == size_t(M)*N` (the LM-head output is 3D `[1, tokens, vocab]`, so an `ndim==2` gate would miss it).
- `int B` is now computed as `out.size() / M / N` (sequential 64-bit division), so the overflow no longer yields `B=0` and misroutes dispatch.

Both helpers early-return to the original single call when no split is needed, so normal-size matmuls are byte-identical. The change is CUDA-only; the Metal backend already uses `out.size()/M/N` and has no 65535 grid cap, so it is unaffected.

Validation (GB10, NVIDIA sm_121)
Previously failing long-prompt cells now pass, exiting 0 with real throughput:

- `scripts/bench_longprompt.sh` runs: `llama-3.1-8b-4bit` @17000 / @32768, `qwen3-30b-a3b-4bit` @8192.
- `mlxcel generate`: `llama-3.1-8b-4bit` @23491 and `qwen3-30b-a3b-4bit` @~11k complete with coherent output, confirming the sub-array offsets are correct.

Closes #648