Add MiniMax-M3 (MSA: MiniMax Sparse Attention) support#24908
Add MiniMax-M3 (MSA: MiniMax Sparse Attention) support#24908timkhronos wants to merge 46 commits into
Conversation
Text-only port that re-uses existing components: MiniMax-M2 style GQA with per-head QK-norm and partial rotary, DeepSeek-V3 style leading-dense and routed/shared experts, and swigluoai activation. Sparse attention is not yet supported (dense fallback); vision tower and MTP heads are dropped.
…ch per group block picking
|
Hi @timkhronos, thanks for your contribution! Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:
Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below. |
|
Thanks. This is full support for a new model architecture (MiniMax-M3), so the core: arch definition, graph, KV-cache handling, indexer and conversion, has to land together to be functional; those aren't independently mergeable. The vision tower / mmproj is the one separable piece, but it's a small, self-contained addition (a standard CLIP-style ViT, the only model-specific parts are the patch embed, 3D-RoPE, and patch-merge projector), so I'd prefer to keep it in for complete model support in one pass rather than split it out. Of course, if a maintainer would still rather see vision as a follow-up, I'll split it. I am happy to open a tracking issue / RFC for the architecture now if you'd like the design discussed before review, though the PR description already covers the design, correctness validation, and limitations in detail. I went straight to the PR since this began as a vision-tower change and grew into full MSA support once I traced the long-context degradation past ~6k to the sparse-attention path, but I'm glad to write up an issue retroactively if that's the preferred process. |
4-way paths. Full debug harness remains at <8136a9c68ed7a5eb009aa67bba3fda8062f4648f> for reproducing the selection-parity validation.
…ollow value naming convention
Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
CISC
left a comment
There was a problem hiding this comment.
Ok, this should be clean enough to proceed with reviewing MSA (@fairydreaming do you want to take a look?) and mmproj (unless @ngxson want that in a separate PR?).
Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
|
What should perf expectations be when testing? I tried the earlier commit w/Avar6 gguf and perf is poor: TG=6t/s and PP seems even slower (!?) R6KP(attn)+EPYC(moe) . I don't see how this could be anyone's "daily driver" as described above by @Ganju0 |
Well I dont drive very fast. 6t/s is what I was getting |
|
@usrlocalben with the Q4 gguf from Avar's repo, and the majority of model in sys ram, I have been getting |
|
Please split the multimodal changes into a dedicated PR |
@CISC Sorry, I don't have this model downloaded and examined in detail yet, just skimmed over the paper so I'm not qualified. |
Large part of the code is clearly AI-written. Please be honest about AI usage. Proof: you manually modified AI-generated comments to make it human-like: 2e82759 (To be clear: I'm ok with using AI and I myself use it from time to time, however, dishonesty is not something I will tolerate) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Explained in #24908 (comment) and OP. |
Overview
Adds support for MiniMax-M3, a 60-layer / 128-expert MoE (3 dense + 57 MoE layers) using MiniMax Sparse Attention (MSA), a GQA-based block-sparse attention where a lightweight per-GQA-group indexer selects a small set of key/value blocks (top-k), and the main attention runs only over the selected blocks. This PR implements the architecture from the MSA paper arXiv:2606.13392v2 and follows the transformers reference (modular_minimax_m3_vl) for the indexer construction, selection, and RoPE.
Vision support moved to separate PR #25113
This implements MiniMax-M3 support: text tower (MSA attention + MoE)
and the vision tower / mmproj (CLIP-style ViT, Conv3d patch embed, 3D-RoPE, the two-stage patch-merge projector).MTP heads are not currently present in the released model file.Additional information
What MSA does (and why the code looks the way it does)
Per sparse layer, for each query and GQA group:
Index branch scores the causal context against a single shared index-key head via a small Q_idx·K_idxᵀ, max-pools scores into blocks ( with each block size being: block_size = 128), and selects top_k = 16 blocks. The local (query's own) block is always force-included.
Main branch attends only over the tokens in the selected blocks, which is a fixed budget of k·B_k = 16·128 = 2048 keys per query, independent of context length, meaning decode attention cost stops scaling with context.
This PR implements two execution paths, which are selected at runtime:
MSA Decode path
(build_attn_msa_decode, single-token batches ONLY): top-k -> gather the 2048 selected KV -> flash attention. This is the fast path, which only attends over the tokens in the selected blocks.
4-way / prefill path
(build_attn_msa_4way, single or multi-token batches): builds a per-group block mask and runs masked flash attention over the context. By default, only used for prefill (n_tokens > 1) and as the simpler, more directly-auditable reference path.
Both are per-GQA-group (4 groups, one index selection each, shared across the group's 16 query heads), matching the reference.
Correctness
vs. paper + HF reference: indexer projections/norms, partial-RoPE on the index heads, block max-pool, per-group top-k, forced-local-block, and the 1/√d omission (order-preserving, irrelevant to top-k) all follow the reference. The +1 Gemma-RMSNorm baking is applied to the indexer norms during conversion (index_{q,k}_norm); dumped converted weights show: (means ≈ 1.2, indistinguishable from attn_q_norm).
Selection parity, MSA decode vs. 4-way: SELDUMP(dumps the selected blocks for every index head) was ran on the first decoded token at a fixed prefix, and shows the two selectors agree except for occasional single-block flips at the rank-16 boundary, which are TF32-vs-f32 score-precision ties only appearing near the marginal, lowest-mass blocks. Forced blocks (sink, local) and high-mass blocks are always identical.
Debugs (MSA_DUMP_SEL, MSA_VERIFY_BS4, MSA_BYPASS, MSA_FORCE_4WAY) are left in, tagged with DEBUGREMOVE, I'll strip them before merge on request.
Index KV-cache (notable implementation detail)
The indexer needs its own per-layer key cache (single head, index_head_size = 128). It is allocated on the same backend buffer as the main K/V cache, not on a host buffer. An earlier host-resident version caused a full [d_idx, n_kv] host->device copy of the index cache on every sparse layer, every decode step, which dominated long-context decode (a ~26 ms/token slope at 30k that grew with context). Keeping it co-resident with K/V eliminates that.
However this introduced a separate issue. The cache is a strided view into the ring buffer, and an M=4 (index-heads) matmul against a strided operand falls off the fast path. The solution I found, is to have the index-cache view fed to the scoring mul_mat be forced contiguous (ggml_cont) before the matmul. Without it, long-context decode regresses badly even with the cache on-GPU. That cont is load bearing.
Flash-attention requirement
The MSA in this PR requires flash attention. Both build_attn_msa_decode and build_attn_msa_4way call ggml_flash_attn_ext directly and assume the non-transposed V layout that llama.cpp only provides when FA is enabled (v_trans = !cparams.flash_attn). There is intentionally no full-attention-materialization fallback, as the only environment that can't run flash_attn_ext is CPU-class, and materializing full attention scores 57×/token over long context, would perform catastrophically there. Dense fallback is strictly better behavior in that case.
The PR gates MSA on resolved cparams.flash_attn:
FA on -> MSA (decode for single-token, 4way for prefill).
FA off -> dense build_attn (no sparsity) + one-time warning that output is degraded.
MSA_BYPASS forces dense; MSA_FORCE_4WAY forces the 4way path at decode (debug flags, will be stripped before merging).
The FA nodes are named with the LLAMA_TENSOR_NAME_FATTN convention so the --flash-attn auto resolver in sched_reserve recognizes them (it identifies FA nodes by name to check their device assignment). Without the naming, server auto aborts at the FATTN-name assert before ever reaching the gate.
Auto-FA resolution is verified on CUDA only. The gate keys on resolved cparams.flash_attn as mentioned earlier, so other backends should behave correctly where they support flash_attn_ext for head-dim 128 (falling back to dense otherwise), however ROCm/HIP, Vulkan, and Metal are untested. Reports from those backends welcome.
Limitations/scope
Quantization
Quants should keep indexer projections (index_{q,k}_proj) at fp32, they drive block selection, so quant error there changes which blocks are read (a discrete retrieval loss), unlike a value projection. Strongly recommend fp32 for those even in otherwise-aggressive quants. The weights are bf16, but fp16 conversion might introduce issues. The difference is about 1GB, and these are genuinely some of the highest value tensors in the model.
Performance
At the tested config (decode bounded by CPU-offloaded experts), MSA decode tracks the dense baseline (~8 tok/s at 20–30k), and prefill benefits slightly from the reduced sparse-attention FLOPs (~400+ tok/s). On a fully GPU-resident deployment (experts on device), MSA decode should approach the official implementation's decode characteristics by construction, both compute attention over only the fixed 2048 selected KV per query rather than the full cache. This is an architectural expectation, not a measured result; the tested config here is expert-offload-bound and only demonstrates parity with the dense baseline.
Conversion
convert_hf_to_gguf.py support for MiniMaxM3SparseForCausalLM: merges routed experts, bakes Gemma (1+w) norms (incl. indexer norms), emits indexer hparams (head_count, key_length, top_k, block_size, local_blocks) and leading_dense_block_count (derived from the leading zeros of moe_layer_freq).
Further added support for MiniMaxM3VisionModel.Follor up PR#25113How to test
All tests assume a GGUF of M3 specifically converted with this branch. Weights with the indexers stripped (such as those by Unsloth) will NOT work. All tests were ran against https://huggingface.co/avar6/minimax-m3-MSA-gguf , which keeps the Indexer weights. (It does incorrectly quantize the non norm indexer tensors to Q4. This is due to the GGUFs being made early on in the implementation process.)Also assumed CUDA build (-DGGML_CUDA=ON), but should work on other FA compatible backends.Currently the existing GGUFs no longer work due to a tensor rename. New ones will have to be reconverted. Meanwhile the previous branch available over here is still able to load the Avar6 GGUFs in case anyone wants to experiment with the PR without converting their own GGUFs.
Performance can be compared directly with the dense attention. To test speed. feed a 4k and a 30k-token prompt and generate; decode throughput should stay roughly flat (the per-token attention budget is fixed at 2048 KV regardless of context). Compare against MSA_BYPASS=1 (dense) to see dense's KV-read grow while MSA stays fixed.
Debug variables:
MSA_DUMP_SEL=<token_position> # dumps per-group selected blocks for the specified token
MSA_FORCE_4WAY=1 # forces the 4way selector at decode (A/B vs decode path)
MSA_BYPASS=1 # dense attention, no MSA
Decode-path and 4way-path selections should match except for occasional single-block flips at the top-k boundary (score-precision ties). These env probes are tagged DEBUGREMOVE and will be stripped before merge.
Multimodal
Multimodal support has been moved into a separate PR upon maintained request. PR #25113
Vision tower (mmproj)A CLIP-style ViT, separate biased q/k/v/o projections, LayerNorm, GELU MLP, full bidirectional attention (no mask, no windowing). It differs from a vanilla CLIP encoder in four specific ways, all handled:Conv3D patch embed, run as summed Conv2D temporal slices. The HF model uses a Conv3D patch embedding with temporal_patch_size slices; conversion splits the 5D Conv3D weight into per-slice Conv2D kernels (V_ENC_EMBD_PATCH + .weight.{t}), and the graph sums the Conv2D outputs. For still images this is exact (video is out of scope here, see below). No patch-embed bias (asserted).Custom 3D (T/H/W) RoPE. Position encoding is a 3-axis RoPE: axis_dim = 2·((2·(d_head/2)/3)/2) = 26, rope_dim = 3·axis_dim = 78, applied to the first rope_dim channels of each head with HF rotate_half semantics, tail passed through. cos/sin are host-precomputed and fed as graph inputs (minimax_cos/minimax_sin). Because rope_dim (78) < d_head (80), this is a partial rotary, the same partial-RoPE pattern as the text tower, just 3-axis.2×2 spatial-merge token reduction. Patches are reordered raster->block (matching the HF flatten) and merged 2×2, so the projector consumes groups of 4 patches. spatial_merge_size is emitted in conversion.No class token, no absolute position table, no pre/post-layernorm asymmetry beyond a pre_layernorm and no post_layernorm (both asserted). Sinks/abs-pos/class-embedding are all absent and asserted null.Projector is two on-disk modules: a per-patch MLP (multi_modal_projector.linear_{1,2} -> mm.{1,2}) applied first, then the 2×2 group-of-4 concatenation, then a merge MLP (patch_merge_mlp.linear_{1,2} -> mm.merge.fc{1,2}), both GELU. Output is [proj_dim, n_pos/4] into the text embedding space.Current vision limitation:still images only (n_batch == 1). The Conv3D-as-summed-Conv2D path is image-correct; video (multiple temporal frames) is not yet wired and is out of scope for this PR. If preferred I could fold it in, however I believe it might be best left as a separate PR.Vision testingRequires a gguf mmproj. One can be found here: https://huggingface.co/Serpen/Minimax_M3_MMPROJ_GGUF/blob/main/MiniMaxAI-MiniMax-M3-bf16.ggufCredits
The initial MiniMax-M3 scaffolding was based on work by @danielhanchen. The sparse-attention (MSA) indexer, the decode and 4-way paths, the KV-cache index handling, FA gating, and the vision tower in this PR are new work.
Requirements