[CUDA] Implement SegmentedMM #3238
Conversation
Pull request overview
Implements mx.segmented_mm on the CUDA backend by dispatching to a CUTLASS grouped GEMM path, enabling the existing Python BLAS segmented-mm tests to run on CUDA and adding a standalone benchmark script to compare against a loop-of-matmuls baseline.
Changes:
- Enable CUDA support for the SegmentedMM primitive and remove the CUDA test skip.
- Add a SegmentedMM::eval_gpu implementation that prepares inputs and calls a new CUTLASS grouped-GEMM launcher.
- Introduce a Python benchmark for mx.segmented_mm performance and numerical error checks.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| python/tests/cuda_skip.py | Removes the CUDA skip for the segmented_mm BLAS test. |
| mlx/backend/cuda/primitives.cpp | Marks SegmentedMM as supported on CUDA (removes NO_GPU). |
| mlx/backend/cuda/matmul.cpp | Adds CUDA implementation for SegmentedMM::eval_gpu. |
| mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu | Adds segment-to-grouped-GEMM argument preparation kernel and CUTLASS dispatch wrapper. |
| mlx/backend/cuda/gemms/grouped_gemm.h | Declares cutlass_segmented_mm entrypoint. |
| benchmarks/python/segmented_mm_bench.py | Adds benchmarking and correctness checking script for segmented_mm. |
zcbenz
left a comment
Thanks for implementing the missing ops! I'm not sure why the tests are failing, can you try rebasing the branch?
Replace the host-side cuBLAS loop with a single CUTLASS grouped GEMM dispatch. A GPU-side prepare kernel builds per-segment problem sizes and pointer offsets from the segments array, eliminating the host sync that was required to read segment boundaries.
CUTLASS handles K=0 segments correctly: the mainloop iterates zero times and the epilogue writes zeros to the output.
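To illustrate what the prepare kernel computes, here is a host-side numpy sketch of the same per-segment bookkeeping: for each (start, end) pair in the segments array it emits an (M, N, K) problem size and the K offset into A's columns and B's rows. The function name and return layout are hypothetical, for illustration only; the actual work happens in a GPU kernel feeding CUTLASS.

```python
def prepare_grouped_gemm_args(segments, M, N):
    """Host-side illustration (hypothetical names) of the per-segment argument
    setup the GPU prepare kernel performs for the grouped GEMM."""
    problem_sizes = []
    k_offsets = []
    for start, end in segments:
        k = end - start  # K = 0 is legal: CUTLASS runs a zero-iteration
        problem_sizes.append((M, N, k))  # mainloop and the epilogue zeros the output
        k_offsets.append(start)  # element offset into A's columns / B's rows
    return problem_sizes, k_offsets

# Example: three segments over K = 20, one of them empty.
sizes, offsets = prepare_grouped_gemm_args([(0, 8), (8, 8), (8, 20)], M=4, N=6)
```

Because the offsets come straight from the segments array on device, no host synchronization is needed to read segment boundaries.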
Compare mx.segmented_mm (grouped GEMM) against MLX loop-of-matmuls baseline. Remove torch dependency. Add accuracy checks: fp32 vs numpy fp64, fp16/bf16 vs own fp32 result.
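As a sketch of the loop-of-matmuls baseline and the accuracy check, the segmented-mm semantics can be written in numpy (assuming segments partition the shared K dimension, with one output slice per segment; this stands in for the MLX calls in the actual benchmark script):

```python
import numpy as np

def segmented_mm_loop(a, b, segments):
    """Loop-of-matmuls reference: each output slice is a @ b restricted
    to one (start, end) segment of the shared K dimension."""
    return np.stack([a[:, s:e] @ b[s:e, :] for s, e in segments])

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 16)).astype(np.float32)
b = rng.standard_normal((16, 8)).astype(np.float32)
segments = [(0, 4), (4, 16)]

out = segmented_mm_loop(a, b, segments)
# Accuracy check in the style of the benchmark: fp32 result vs fp64 reference.
ref = segmented_mm_loop(a.astype(np.float64), b.astype(np.float64), segments)
max_err = np.abs(out - ref).max()
```

The fp16/bf16 variants are instead compared against the implementation's own fp32 result, since an fp64 reference would over-penalize the low-precision types.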
6bc92b2 to b800f69
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
b800f69 to d273114
Can't reproduce test failure on my setup
For failures that only happened in CI, a plausible explanation is that some code relies on uninitialized GPU memory being zeros, but we are not going to know for sure since it happened in a cuDNN kernel. It is also possible that we have memory corruption or out-of-bounds writes, but that should be reproducible locally. Anyway, changing the order of tests by renaming works around it, and I have gone through the code several times and I don't think the problem is in this PR.
Proposed changes
Implement mx.segmented_mm for the CUDA backend using CUTLASS grouped GEMM.

Performance
MLX_ENABLE_TF32=0, random segments.

float32

bfloat16

MLX ms = mx.segmented_mm (CUTLASS grouped GEMM), Loop ms = MLX loop-of-matmuls baseline.

Checklist
Put an x in the boxes that apply.
- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes