[KDA] sm100 GVA enhance #65

Open
sjmshsh wants to merge 11 commits into inclusionAI:main from sjmshsh:feat/kda-sm100-gva

Contributor

@sjmshsh sjmshsh commented May 7, 2026

PR: feat/kda-sm100-gva

Summary

Adds GVA (Grouped Value Attention) support to the KDA training path (chunk_kda) on SM100 (Blackwell), allowing num_v_heads (HV) > num_qk_heads (HQK) with HV % HQK == 0.

Tensor layout follows the gated delta rule GVA convention:

| Tensors | Number of heads |
| --- | --- |
| q, k | HQK |
| v, g, beta, o, state | HV |

When HV == HQK, behavior matches the existing MHA path.
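
The shape contract above can be sketched as a small host-side check. This is an illustrative sketch only, not the PR's code: the function name `check_gva_shapes` is hypothetical, and the exact shapes of `g` ([B, T, HV, D]) and `beta` ([B, T, HV]) are assumptions based on the layouts described in this PR.

```python
def check_gva_shapes(q_shape, k_shape, v_shape, g_shape, beta_shape):
    """Validate GVA shapes and return (HQK, HV, heads_per_group).

    Hypothetical helper: assumes q/k are [B, T, HQK, D], v/g are
    [B, T, HV, D], and beta is [B, T, HV].
    """
    B, T, HQK, D = q_shape
    if tuple(k_shape) != tuple(q_shape):
        raise ValueError(f"q and k must match; got {q_shape} vs {k_shape}")
    HV = v_shape[2]
    if tuple(v_shape[:2]) != (B, T):
        raise ValueError(f"v must share (B, T) with k; got {v_shape}")
    if tuple(g_shape[:3]) != (B, T, HV):
        raise ValueError(f"g must use HV heads; got {g_shape}")
    if tuple(beta_shape) != (B, T, HV):
        raise ValueError(f"beta must be [B, T, HV]; got {beta_shape}")
    if HQK <= 0 or HV <= 0 or HV % HQK != 0:
        raise ValueError(
            f"v head count (HV={HV}) must be a positive multiple "
            f"of q/k head count (HQK={HQK})"
        )
    return HQK, HV, HV // HQK


# MHA is the degenerate case: HV == HQK gives a group size of 1.
hqk, hv, group = check_gva_shapes(
    (2, 64, 4, 128), (2, 64, 4, 128), (2, 64, 4, 128), (2, 64, 4, 128), (2, 64, 4)
)
```

With HV = 8 and HQK = 4 the same call would return a group size of 2, and HV = 6 with HQK = 4 would be rejected.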


Motivation

GVA shares each Q/K head across a group of V heads, so fewer Q/K heads are needed; this reduces Q/K compute and memory while preserving model capacity. This PR wires native HQK/HV shapes through the full cuLA KDA stack (SM100 CUDA kernels, Triton backward, Python orchestration) instead of only simulating GVA via host-side repeat_interleave.


Changes

1. SM100 CUDA kernels (csrc/kda/sm100/ + csrc/api/kda_sm100.cu)

  • kda_config.hpp: Split h into h_qk, h_v, and heads_per_group; document Q/K vs V/g/beta/A layouts and strides for GVA.
  • tile_scheduler.hpp: Enumerate tiles by v-head; decode_tile_coord returns (batch_idx, v_head_idx, seq_idx, qk_head_idx) with qk_head_idx = v_head_idx / heads_per_group.
  • kda_fwd_intra_*: TMA/load uses h_qk strides for Q/K and h_v strides for V/g.
  • kda_fwd_recomp_w_u_*: Same head-space split; intermediates (w, u, kg, qg, etc.) remain in HV space.
  • kda_sm100.cu: Derive h_qk / h_v from q.size(2) and g.size(2) (or v.size(2)), validate h_v % h_qk == 0, and pass heads_per_group into the tile scheduler.
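
As a rough illustration of the v-head tile enumeration described above, the decode step can be modeled in Python. The return tuple and the `qk_head_idx = v_head_idx / heads_per_group` rule come from this PR; the linear ordering of tiles (v-head fastest, then sequence tile, then batch) and the parameter names are assumptions made only for this sketch, and the real CUDA scheduler may order tiles differently.

```python
from typing import NamedTuple


class TileCoord(NamedTuple):
    batch_idx: int
    v_head_idx: int
    seq_idx: int
    qk_head_idx: int


def decode_tile_coord(tile_id, num_seq_tiles, h_v, heads_per_group):
    """Map a linear tile id to (batch, v-head, seq tile, qk-head).

    Assumed ordering for illustration: v-head varies fastest,
    then the sequence tile, then the batch index.
    """
    v_head_idx = tile_id % h_v
    rest = tile_id // h_v
    seq_idx = rest % num_seq_tiles
    batch_idx = rest // num_seq_tiles
    # Each group of `heads_per_group` consecutive v-heads shares one q/k head.
    qk_head_idx = v_head_idx // heads_per_group
    return TileCoord(batch_idx, v_head_idx, seq_idx, qk_head_idx)


# Example: h_v = 4, h_qk = 2, so heads_per_group = 2.
coord = decode_tile_coord(tile_id=7, num_seq_tiles=3, h_v=4, heads_per_group=2)
```

When `heads_per_group == 1`, `qk_head_idx` equals `v_head_idx`, which is the MHA behaviour.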

2. Python training orchestration (cula/kda/)

| File | Change |
| --- | --- |
| chunk.py | chunk_kda entry asserts: q/k are [B, T, HQK, D]; v/g/beta use HV |
| chunk_fwd.py | repeat_interleave on q before fwd_o when HV > HQK (host compat layer; fwd_o still expects a unified head dim) |
| chunk_intra.py | Triton bwd intra: grid B × HV; i_hqk = i_h // (HV // HQK); HQK strides for q/k/dq/dk, HV strides for g/beta/dA |
| chunk_bwd.py | Remove unused q/k from dAv; add HQK constexpr to wy_dqkg_fused and fix q/k/dq/dk pointer offsets |

3. Tests

  • tests/test_kda_gva_intra_sm100.py (new): GVA tests for SM100 intra / recomp
    • uniform and varlen layouts
    • degenerate case HV == HQK matches non-GVA
    • output shapes and rejection of invalid HV % HQK
  • tests/test_kda.py: test_chunk_kda_gva and test_chunk_kda_gva_varlen — end-to-end forward/backward vs FLA reference (with q/k expanded to HV); dq/dk compared after summing over the group axis
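
The dq/dk comparison above relies on a simple fact: when a reference implementation sees q/k expanded to HV heads, the gradient of each shared HQK head is the sum of the gradients of its expanded copies. A minimal pure-Python sketch of that reduction follows; the helper names are hypothetical, and nested lists stand in for tensors with only the head axis and one feature axis shown.

```python
def repeat_interleave_heads(heads, group_size):
    """Expand HQK heads to HV = HQK * group_size.

    Each head is repeated consecutively, so expanded head h maps back
    to original head h // group_size.
    """
    return [list(h) for h in heads for _ in range(group_size)]


def reduce_group_grads(d_expanded, group_size):
    """Sum per-v-head gradients back onto the shared q/k head."""
    assert len(d_expanded) % group_size == 0
    reduced = []
    for start in range(0, len(d_expanded), group_size):
        group = d_expanded[start:start + group_size]
        # Element-wise sum across the heads in this group.
        reduced.append([sum(col) for col in zip(*group)])
    return reduced


q = [[1.0, 2.0], [3.0, 4.0]]                  # HQK = 2 heads
q_expanded = repeat_interleave_heads(q, 2)    # HV = 4 heads
dq_expanded = [[1, 2], [3, 4], [5, 6], [7, 8]]
dq = reduce_group_grads(dq_expanded, 2)       # back to HQK = 2 heads
```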

4. Benchmarks

  • benchmarks/utils.py: Add prepare_safe_gate_inputs_gva (q/k stay in HQK; v/g/beta in HV)
  • bench_kda.py / bench_kda_fwd_bwd_e2e.py: Unified --hv flag (GVA when HV is a multiple of H), aligned with bench_kda_fused_fwd.py
  • bench_kda_chunk_intra.py: GVA configs and comparisons

5. Misc

  • cula/utils.py: Update get_kda_fused_fwd docstring (Blackwell fused prefill still NotImplementedError)

Out of scope (not changed in this PR)

  • cula/kda/blackwell_fused_fwd.py / ops/kda_fully_fused_wip.py: Blackwell fused prefill still requires HQK == HV
  • ops/fwd_o.py / ops/chunk_delta_h.py: No in-kernel native GVA; fwd_o relies on host-side q expansion in chunk_fwd
  • SM90 (Hopper): No changes under csrc/kda/sm90/ in this branch; Hopper prefill GVA lives in the existing hopper_fused_fwd + kda_sm90.cu path

Design notes

GVA mapping:  qk_head = v_head // (HV // HQK)

SM100 tile scheduler:  enumerate tiles by HV; each CTA handles one v-head
SM100 intra CUDA:      Q/K TMA uses h_qk; V/g/beta TMA uses h_v
Triton backward:       grid = B × HV; q/k pointers use i_hqk
fwd_o (CuTe):          host repeat_interleave(q) for now — ops unchanged
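
To make the head-space split concrete, the sketch below computes where a program working on v-head `i_h` would read q and v. It assumes contiguous row-major [B, T, H, D] tensors, which is an assumption for illustration; the function names are hypothetical and the actual kernels use their own stride parameters.

```python
def element_offset(b, t, h, d, T, H, D):
    """Flat offset into a contiguous row-major [B, T, H, D] tensor."""
    return ((b * T + t) * H + h) * D + d


def gva_offsets(b, t, i_h, T, HQK, HV, D):
    """Return (q_offset, v_offset) for the program handling v-head i_h.

    q/k are indexed in HQK space via i_hqk; v/g/beta stay in HV space.
    """
    i_hqk = i_h // (HV // HQK)
    q_off = element_offset(b, t, i_hqk, 0, T, HQK, D)
    v_off = element_offset(b, t, i_h, 0, T, HV, D)
    return q_off, v_off


# T = 4, HQK = 2, HV = 4, D = 8: v-head 3 reads q/k head 1.
q_off, v_off = gva_offsets(b=0, t=1, i_h=3, T=4, HQK=2, HV=4, D=8)
```

Using HV strides for q (or HQK strides for v) would read the wrong rows as soon as HV differs from HQK, which is the class of bug the pointer-offset fixes in chunk_bwd.py address.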

Test plan

On an SM100 machine:

# Intra / recomp GVA
pytest tests/test_kda_gva_intra_sm100.py -v

# chunk_kda end-to-end GVA
pytest tests/test_kda.py -k gva -v

# Benchmarks (optional)
python benchmarks/bench_kda.py --hv 32          # e.g. H=16 → GVA group size 2
python benchmarks/bench_kda_fwd_bwd_e2e.py --hv 32
python benchmarks/bench_kda_chunk_intra.py
  • tests/test_kda_gva_intra_sm100.py passes
  • test_chunk_kda_gva / test_chunk_kda_gva_varlen pass
  • Existing MHA tests (HV == HQK) show no regression
  • (Optional) GVA benchmark sanity vs MHA configs

pytest tests/test_kda_gva_intra_sm100.py -v

root@6e1bd959f395:~/cuLA# pytest tests/test_kda_gva_intra_sm100.py -v
===================================================== test session starts =====================================================
platform linux -- Python 3.12.3, pytest-9.0.3, pluggy-1.6.0 -- /usr/bin/python3.12
cachedir: .pytest_cache
rootdir: /root/cuLA
configfile: pyproject.toml
plugins: anyio-4.11.0
collected 28 items                                                                                                            

tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T256-HQK2-gs2-D128-recomp] PASSED                          [  3%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T256-HQK2-gs2-D128-no_recomp] PASSED                       [  7%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B2-T512-HQK4-gs2-D128-recomp] PASSED                          [ 10%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B2-T512-HQK4-gs2-D128-no_recomp] PASSED                       [ 14%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T1024-HQK2-gs4-D128-recomp] PASSED                         [ 17%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T1024-HQK2-gs4-D128-no_recomp] PASSED                      [ 21%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B2-T1024-HQK4-gs4-D128-recomp] PASSED                         [ 25%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B2-T1024-HQK4-gs4-D128-no_recomp] PASSED                      [ 28%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T500-HQK2-gs2-D128-recomp] PASSED                          [ 32%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T500-HQK2-gs2-D128-no_recomp] PASSED                       [ 35%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T1000-HQK4-gs2-D128-recomp] PASSED                         [ 39%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T1000-HQK4-gs2-D128-no_recomp] PASSED                      [ 42%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK2-gs2-D128-ns3-recomp] PASSED                               [ 46%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK2-gs2-D128-ns3-no_recomp] PASSED                            [ 50%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK4-gs2-D128-ns4-recomp] PASSED                               [ 53%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK4-gs2-D128-ns4-no_recomp] PASSED                            [ 57%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK2-gs4-D128-ns5-recomp] PASSED                               [ 60%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK2-gs4-D128-ns5-no_recomp] PASSED                            [ 64%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK4-gs2-D128-ns10-recomp] PASSED                              [ 67%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK4-gs2-D128-ns10-no_recomp] PASSED                           [ 71%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_degenerate_equals_non_gva[B1-T512-H4-D128-recomp] PASSED              [ 75%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_degenerate_equals_non_gva[B1-T512-H4-D128-no_recomp] PASSED           [ 78%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_degenerate_equals_non_gva[B2-T1024-H4-D128-recomp] PASSED             [ 82%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_degenerate_equals_non_gva[B2-T1024-H4-D128-no_recomp] PASSED          [ 85%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_output_shapes[1] PASSED                                               [ 89%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_output_shapes[2] PASSED                                               [ 92%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_output_shapes[4] PASSED                                               [ 96%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_rejects_non_multiple_ratio PASSED                                     [100%]

===================================================== 28 passed in 8.14s ======================================================

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements Grouped V-head Attention (GVA) support across the KDA kernels for both SM90 and SM100 architectures. Key changes include decoupling head counts for Q/K and V/G tensors, updating TMA descriptors and tile scheduling logic to handle these grouped configurations, and adding comprehensive validation checks. The Python API and test suite have been updated to support and verify GVA functionality. Feedback from the review identifies a documentation mismatch regarding tensor layouts in the SM100 mainloop and suggests correcting terminology in Python error messages to distinguish between head count and head dimension.

int row = (idx_in_wg / 32) * 16 + (idx_in_wg % 16);

- // GMEM output address: layout [total_len, d, h], stride [d*h, 1, d]
+ // GMEM output address: layout [total_len, d, h_v], stride [d*h_v, 1, d]
Contributor


medium

The comment mentions layout [total_len, d, h_v], but the stride [d*h_v, 1, d] and the code logic actually correspond to a [total_len, h_v, d] layout (where d is the inner-most dimension).

                // GMEM output address: layout [total_len, h_v, d], stride [d*h_v, 1, d]
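
The arithmetic behind this comment can be checked with a short sketch. Under the assumption that the stride triple is listed in the same order as the logical coordinates (total_len, d, h_v), the strides (d*h_v, 1, d) address the same memory as a contiguous row-major [total_len, h_v, d] array. The helper names below are hypothetical.

```python
def strided_offset(coord, stride):
    """Offset of a coordinate under an explicit per-axis stride."""
    return sum(c * s for c, s in zip(coord, stride))


def row_major_offset(t, h, i, h_v, d):
    """Offset into a contiguous row-major [total_len, h_v, d] array."""
    return (t * h_v + h) * d + i


total_len, d, h_v = 3, 4, 2
stride = (d * h_v, 1, d)  # paired with coordinates (t, i, h)

# Every logical coordinate lands on the same address in both views.
layouts_agree = all(
    strided_offset((t, i, h), stride) == row_major_offset(t, h, i, h_v, d)
    for t in range(total_len)
    for i in range(d)
    for h in range(h_v)
)
```

If the layout label is instead written in memory order as [total_len, h_v, d], the stride listed in that same order would be [d*h_v, d, 1].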

Comment thread cula/kda/chunk_intra.py Outdated
f"v must share (B, T) with k; got k.shape={k.shape}, v.shape={v.shape}"
)
assert HV > 0 and HQK > 0 and HV % HQK == 0, (
f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})"
Contributor


medium

The error message incorrectly uses the term 'head-dim' when referring to HV and HQK, which represent the number of heads (head count). The head dimension is represented by K.

Suggested change
- f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})"
+ f"v head count (HV={HV}) must be a positive multiple of k head count (HQK={HQK})"

Follow the GVA pattern used in the SM90 KDA (and in gated_delta_rule GVA) so that the SM100 KDA forward pass can handle num_v_heads > num_qk_heads.

C++ changes:

- tile_scheduler: Params now carries heads_per_group; decode_tile_coord enumerates tiles in v-head space and returns both v_head_idx and qk_head_idx (= v_head_idx / heads_per_group). When HV == HQK this degenerates to the previous behaviour.

- kda_config: KDA_fwd_intra_params / KDA_fwd_recomp_w_u_params split h into h_qk and h_v and cache heads_per_group; Akk and w/u/kg/qg layouts now live in v-head space.

- intra kernel/mainloop: Q/K TMA descriptors use shape_QK (total, d, h_qk); g TMA uses shape_VG (total, d, h_v). Load warp slices Q/K with qk_head_idx and g with v_head_idx; Aqk row stride and beta stride now use params.h_v.

- recomp_w_u kernel/mainloop: K/Q TMA descriptors use shape_QK; V/g TMA use shape_VG; Akk TMA uses shape_Akk (total, BT, h_v). Load warp slices K/Q with qk_head_idx and V/g/Akk with v_head_idx; w/u/kg/qg write stride and beta stride now use params.h_v.

API / Python:

- kda_sm100.cu: derive h_qk from Q/K and h_v from V/g; validate HV % HQK == 0 and beta/qg_out shapes.

- cula/kda/chunk_intra.py: infer HQK from k.shape[2] and HV from v.shape[2]; allocate Aqk, Akk, w, kg, qg in v-head space; add shape assertions.

Backward compatible: when HV == HQK, heads_per_group == 1 and qk_head_idx == v_head_idx, and all shapes/strides reduce to the pre-GVA layout.
@sjmshsh sjmshsh force-pushed the feat/kda-sm100-gva branch from e0e3494 to 58535e2 Compare May 7, 2026 03:02
@sjmshsh sjmshsh changed the title [KDA] sm100 GVA enhance 【Draft】[KDA] sm100 GVA enhance May 7, 2026
@sjmshsh sjmshsh changed the title 【Draft】[KDA] sm100 GVA enhance [KDA] sm100 GVA enhance May 19, 2026
@sjmshsh
Contributor Author

sjmshsh commented May 19, 2026

@KevinZeng08 Could you please take a quick look and check whether the scope of the changes and the format/specification of the benchmark scripts are as expected? If everything looks good, I’ll start running the benchmarks on Blackwell.

@KevinZeng08
Collaborator

> @KevinZeng08 Could you please take a quick look and check whether the scope of the changes and the format/specification of the benchmark scripts are as expected? If everything looks good, I’ll start running the benchmarks on Blackwell.

Thanks for your contribution. You can first try running benchmarks/bench_kda_chunk_intra.py, because this PR seems to support GVA for chunk_intra and recompute_wu. If that looks OK, you may refactor the code to support GVA only for chunk_intra and recompute_wu, together with the kda_chunk_intra benchmarks, without modifying the end-to-end implementation. Then we can merge it first.
For delta_h and fwd_o with the CuTeDSL implementation and for the FLA v0.5.0 upgrade (#67), I will open two separate PRs. After the upgrade, the Triton code is the same as FLA.
After these changes, we can verify end-to-end correctness and run the benchmarks.

@sjmshsh sjmshsh force-pushed the feat/kda-sm100-gva branch from b291751 to c6a492b Compare May 19, 2026 11:56