
[CUDA] Implement SegmentedMM #3238

Merged
angeloskath merged 8 commits into ml-explore:main from Lyxot:cuda/segmented_mm
Mar 11, 2026

Conversation


@Lyxot Lyxot commented Mar 10, 2026

Proposed changes

Implement mx.segmented_mm for the CUDA backend using CUTLASS grouped GEMM.

Performance

Benchmarked with MLX_ENABLE_TF32=0 and random segment boundaries.

float32

| Case              | MLX ms | Loop ms | Speedup |
|-------------------|--------|---------|---------|
| 128x128x1024x16   | 0.040  | 0.353   | 8.73x   |
| 128x128x1024x32   | 0.042  | 0.645   | 15.49x  |
| 256x256x2048x16   | 0.130  | 0.566   | 4.35x   |
| 512x512x4096x32   | 0.200  | 0.788   | 3.95x   |
| 1024x1024x4096x32 | 0.596  | 1.509   | 2.53x   |
| 1024x1024x8192x64 | 1.197  | 2.929   | 2.45x   |

bfloat16

| Case              | MLX ms | Loop ms | Speedup |
|-------------------|--------|---------|---------|
| 128x128x1024x16   | 0.041  | 0.184   | 4.54x   |
| 128x128x1024x32   | 0.032  | 0.327   | 10.35x  |
| 256x256x2048x16   | 0.061  | 0.186   | 3.05x   |
| 512x512x4096x32   | 0.161  | 0.395   | 2.46x   |
| 1024x1024x4096x32 | 0.515  | 0.791   | 1.54x   |
| 1024x1024x8192x64 | 1.095  | 1.593   | 1.46x   |

MLX ms = mx.segmented_mm (CUTLASS grouped GEMM), Loop ms = MLX loop-of-matmuls baseline.
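The loop-of-matmuls baseline can be sketched as follows. This is an illustrative numpy version (the actual benchmark uses MLX on a CUDA device), assuming the segmented_mm semantics where each (start, end) row of `segments` selects a slice of the shared inner dimension K:

```python
import numpy as np

def segmented_mm_loop(a, b, segments):
    """Loop-of-matmuls reference for segmented_mm.

    a: (M, K) left operand, b: (K, N) right operand.
    segments: (S, 2) array of (start, end) offsets into the K axis.
    Returns an (S, M, N) array where out[i] = a[:, s:e] @ b[s:e, :].
    """
    out = np.empty((len(segments), a.shape[0], b.shape[1]), dtype=a.dtype)
    for i, (start, end) in enumerate(segments):
        out[i] = a[:, start:end] @ b[start:end, :]
    return out

rng = np.random.default_rng(0)
a = rng.standard_normal((8, 32)).astype(np.float32)
b = rng.standard_normal((32, 4)).astype(np.float32)
# Two segments that partition K = 32.
segments = np.array([[0, 16], [16, 32]])
out = segmented_mm_loop(a, b, segments)
```

Because the two segments partition the K axis, summing the per-segment outputs recovers the full matmul, which is a convenient sanity check.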

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Copilot AI review requested due to automatic review settings March 10, 2026 09:39

Copilot AI left a comment


Pull request overview

Implements mx.segmented_mm on the CUDA backend by dispatching to a CUTLASS grouped GEMM path, enabling the existing Python BLAS segmented-mm tests to run on CUDA and adding a standalone benchmark script to compare against a loop-of-matmuls baseline.

Changes:

  • Enable CUDA support for the SegmentedMM primitive and remove the CUDA test skip.
  • Add SegmentedMM::eval_gpu implementation that prepares inputs and calls a new CUTLASS grouped-GEMM launcher.
  • Introduce a Python benchmark for mx.segmented_mm performance and numerical error checks.
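The argument-preparation step in the second bullet can be illustrated on the host. This is a hypothetical Python sketch of what the GPU prepare kernel computes per segment (assuming row-major A of shape (M, K), row-major B of shape (K, N), and an (S, M, N) output); it is not the actual CUDA code:

```python
def prepare_grouped_gemm_args(segments, M, N, itemsize):
    """Build per-segment problem sizes and byte offsets for a grouped GEMM.

    For segment i with K-range [s, e), group i multiplies
    A[:, s:e] (M x K_i) by B[s:e, :] (K_i x N) into C[i] (M x N).
    """
    problems, a_offsets, b_offsets, c_offsets = [], [], [], []
    for i, (s, e) in enumerate(segments):
        problems.append((M, N, e - s))          # per-group GEMM shape
        a_offsets.append(s * itemsize)          # column offset into row-major A
        b_offsets.append(s * N * itemsize)      # row offset into row-major B
        c_offsets.append(i * M * N * itemsize)  # i-th output slice
    return problems, a_offsets, b_offsets, c_offsets
```

Doing this on the GPU, as the PR does, avoids the host-side synchronization that would otherwise be needed to read the segment boundaries back before launching the GEMMs.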

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

| File | Description |
|------|-------------|
| python/tests/cuda_skip.py | Removes the CUDA skip for the segmented_mm BLAS test. |
| mlx/backend/cuda/primitives.cpp | Marks SegmentedMM as supported on CUDA (removes NO_GPU). |
| mlx/backend/cuda/matmul.cpp | Adds the CUDA implementation of SegmentedMM::eval_gpu. |
| mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu | Adds the segment-to-grouped-GEMM argument preparation kernel and the CUTLASS dispatch wrapper. |
| mlx/backend/cuda/gemms/grouped_gemm.h | Declares the cutlass_segmented_mm entry point. |
| benchmarks/python/segmented_mm_bench.py | Adds a benchmarking and correctness-checking script for segmented_mm. |


@angeloskath angeloskath requested a review from zcbenz March 10, 2026 20:12

@zcbenz zcbenz left a comment


Thanks for implementing the missing ops! I'm not sure why the tests are failing; can you try rebasing the branch?

Lyxot added 6 commits March 11, 2026 12:42

  • Replace the host-side cuBLAS loop with a single CUTLASS grouped GEMM dispatch. A GPU-side prepare kernel builds per-segment problem sizes and pointer offsets from the segments array, eliminating the host sync that was required to read segment boundaries.
  • CUTLASS handles K=0 segments correctly: the mainloop iterates zero times and the epilogue writes zeros to the output.
  • Compare mx.segmented_mm (grouped GEMM) against the MLX loop-of-matmuls baseline. Remove the torch dependency. Add accuracy checks: fp32 vs. numpy fp64, and fp16/bf16 vs. the script's own fp32 result.
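The K=0 behavior noted in the commit messages can be exercised with a small numpy emulation of the expected semantics (illustrative only; the real check would run mx.segmented_mm on a CUDA device): a segment whose start equals its end contributes zero mainloop iterations, so its output slice should come back as all zeros rather than uninitialized memory.

```python
import numpy as np

def segmented_mm_ref(a, b, segments):
    # Emulated segmented_mm semantics: out[i] = a[:, s:e] @ b[s:e, :].
    out = np.zeros((len(segments), a.shape[0], b.shape[1]), dtype=a.dtype)
    for i, (s, e) in enumerate(segments):
        if e > s:  # a K = 0 segment leaves its zero-initialized slice untouched
            out[i] = a[:, s:e] @ b[s:e, :]
    return out

a = np.ones((4, 8), dtype=np.float32)
b = np.ones((8, 3), dtype=np.float32)
# The second segment is empty (start == end), mirroring the K = 0 case.
out = segmented_mm_ref(a, b, [(0, 8), (8, 8)])
```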
@Lyxot Lyxot force-pushed the cuda/segmented_mm branch from 6bc92b2 to b800f69 Compare March 11, 2026 04:51
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@Lyxot Lyxot force-pushed the cuda/segmented_mm branch from b800f69 to d273114 Compare March 11, 2026 05:01
@Lyxot Lyxot requested a review from zcbenz March 11, 2026 05:04

Lyxot commented Mar 11, 2026

I can't reproduce the test failure on my setup.


zcbenz commented Mar 11, 2026

For failures that only happen in CI, a plausible explanation is that some code relies on uninitialized GPU memory being zeros, but we are not going to know for sure since it happened in a cuDNN kernel. It is also possible that we have memory corruption or out-of-bounds writes, but that should be reproducible locally. In any case, changing the order of tests by renaming can work around it, and I have gone through the code several times and don't think the problem is in this PR.

@angeloskath angeloskath merged commit a9573f9 into ml-explore:main Mar 11, 2026
16 checks passed
