[KDA] sm100 GVA enhance #65
Code Review
This pull request implements Grouped V-head Attention (GVA) support across the KDA kernels for both SM90 and SM100 architectures. Key changes include decoupling head counts for Q/K and V/G tensors, updating TMA descriptors and tile scheduling logic to handle these grouped configurations, and adding comprehensive validation checks. The Python API and test suite have been updated to support and verify GVA functionality. Feedback from the review identifies a documentation mismatch regarding tensor layouts in the SM100 mainloop and suggests correcting terminology in Python error messages to distinguish between head count and head dimension.
int row = (idx_in_wg / 32) * 16 + (idx_in_wg % 16);

- // GMEM output address: layout [total_len, d, h], stride [d*h, 1, d]
+ // GMEM output address: layout [total_len, d, h_v], stride [d*h_v, 1, d]
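To make the stride annotation in the comment above concrete, here is a minimal Python sketch of the offset arithmetic it implies. The helper name `gmem_offset` and the tiny sizes are hypothetical and are not taken from the kernel; only the layout `[total_len, d, h_v]` and strides `[d*h_v, 1, d]` come from the comment.

```python
def gmem_offset(t, i, head, d, h_v):
    """Flat element offset for the logical index (t, i, head).

    Logical layout [total_len, d, h_v] with strides [d * h_v, 1, d]:
    the d elements of one head are contiguous, and the h_v heads of a
    token are packed back to back.
    """
    return t * (d * h_v) + i * 1 + head * d


# Tiny illustrative sizes (assumed, not from the kernel).
total_len, d, h_v = 3, 4, 2
offsets = sorted(
    gmem_offset(t, i, h, d, h_v)
    for t in range(total_len)
    for i in range(d)
    for h in range(h_v)
)
```

With these strides every logical index maps to a distinct offset, so `offsets` covers `range(total_len * d * h_v)` with no gaps or collisions. Using `h` instead of `h_v` in the stride would only be equivalent when `HV == HQK`.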
f"v must share (B, T) with k; got k.shape={k.shape}, v.shape={v.shape}"
)
assert HV > 0 and HQK > 0 and HV % HQK == 0, (
f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})"
The error message incorrectly uses the term 'head-dim' when referring to HV and HQK, which represent the number of heads (head count). The head dimension is represented by K.
- f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})"
+ f"v head count (HV={HV}) must be a positive multiple of k head count (HQK={HQK})"
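For reference, a minimal standalone sketch of the check with the corrected wording. It operates on plain shape tuples rather than tensors, and the helper name `check_gva_heads` is hypothetical; the assertions mirror the ones quoted above, with HQK read from `k.shape[2]` and HV from `v.shape[2]`.

```python
def check_gva_heads(k_shape, v_shape):
    """Validate (B, T, H, D)-style shapes and return heads_per_group."""
    B, T, HQK, _ = k_shape
    assert v_shape[0] == B and v_shape[1] == T, (
        f"v must share (B, T) with k; got k.shape={k_shape}, v.shape={v_shape}"
    )
    HV = v_shape[2]
    # HQK > 0 is tested before the modulo, so HQK == 0 cannot divide by zero.
    assert HV > 0 and HQK > 0 and HV % HQK == 0, (
        f"v head count (HV={HV}) must be a positive multiple of "
        f"k head count (HQK={HQK})"
    )
    return HV // HQK
```

For example, `check_gva_heads((2, 64, 4, 128), (2, 64, 8, 128))` returns 2, while a v shape with 6 heads against 4 q/k heads fails the assertion.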
Follow the GVA pattern used in the SM90 KDA (and in gated_delta_rule GVA) so that the SM100 KDA forward pass can handle num_v_heads > num_qk_heads.

C++ changes:
- tile_scheduler: Params now carries heads_per_group; decode_tile_coord enumerates tiles in v-head space and returns both v_head_idx and qk_head_idx (= v_head_idx / heads_per_group). When HV == HQK this degenerates to the previous behaviour.
- kda_config: KDA_fwd_intra_params / KDA_fwd_recomp_w_u_params split h into h_qk and h_v and cache heads_per_group; Akk and w/u/kg/qg layouts now live in v-head space.
- intra kernel/mainloop: Q/K TMA descriptors use shape_QK (total, d, h_qk); g TMA uses shape_VG (total, d, h_v). Load warp slices Q/K with qk_head_idx and g with v_head_idx; Aqk row stride and beta stride now use params.h_v.
- recomp_w_u kernel/mainloop: K/Q TMA descriptors use shape_QK; V/g TMA use shape_VG; Akk TMA uses shape_Akk (total, BT, h_v). Load warp slices K/Q with qk_head_idx and V/g/Akk with v_head_idx; w/u/kg/qg write stride and beta stride now use params.h_v.

API / Python:
- kda_sm100.cu: derive h_qk from Q/K and h_v from V/g; validate HV % HQK == 0 and beta/qg_out shapes.
- cula/kda/chunk_intra.py: infer HQK from k.shape[2] and HV from v.shape[2]; allocate Aqk, Akk, w, kg, qg in v-head space; add shape assertions.

Backward compatible: when HV == HQK, heads_per_group == 1 and qk_head_idx == v_head_idx, and all shapes/strides reduce to the pre-GVA layout.
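The head mapping described in the tile_scheduler item can be sketched in a few lines of Python. This is an illustration only: the linearization order (sequence tile fastest, then v-head, then batch) and the `num_seq_tiles` parameter are assumptions made for the sketch, and the return tuple follows the order given in the PR description. Only the formula `qk_head_idx = v_head_idx / heads_per_group` comes from the commit message.

```python
def decode_tile_coord(tile_idx, num_seq_tiles, h_v, heads_per_group):
    """Map a linear tile index to (batch_idx, v_head_idx, seq_idx, qk_head_idx).

    Tiles are enumerated in v-head space. The ordering used here is an
    assumption for illustration, not the scheduler's actual order.
    """
    seq_idx = tile_idx % num_seq_tiles
    v_head_idx = (tile_idx // num_seq_tiles) % h_v
    batch_idx = tile_idx // (num_seq_tiles * h_v)
    # Integer division: consecutive v-heads share one q/k head.
    qk_head_idx = v_head_idx // heads_per_group
    return batch_idx, v_head_idx, seq_idx, qk_head_idx


h_qk, h_v = 2, 4
heads_per_group = h_v // h_qk
# q/k head used by each v-head when there is one sequence tile per head.
qk_heads = [decode_tile_coord(t, 1, h_v, heads_per_group)[3] for t in range(h_v)]
```

Here `qk_heads` is `[0, 0, 1, 1]`. With `heads_per_group == 1` the mapping is the identity, which is the backward-compatible case noted above.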
e0e3494 to 58535e2
@KevinZeng08 Could you please take a quick look and check whether the scope of the changes and the format/specification of the benchmark scripts are as expected? If everything looks good, I’ll start running the benchmarks on Blackwell.
Thanks for your contribution. You can try to first run
b291751 to c6a492b
PR: feat/kda-sm100-gva
Summary
Adds GVA (Grouped Value Attention) support to the KDA training path (chunk_kda) on SM100 (Blackwell), allowing num_v_heads (HV) > num_qk_heads (HQK) with HV % HQK == 0.

Tensor layout follows the gated delta rule GVA convention:
- q, k: HQK
- v, g, beta, o, state: HV

When HV == HQK, behavior matches the existing MHA path.

Motivation
GVA shares fewer Q/K heads across multiple V heads, reducing Q/K compute and memory while preserving model capacity. This PR wires native HQK/HV shapes through the full cuLA KDA stack (SM100 CUDA kernels, Triton backward, Python orchestration) instead of only simulating GVA via host-side repeat_interleave.

Changes
1. SM100 CUDA kernels (csrc/kda/sm100/ + csrc/api/kda_sm100.cu)
- kda_config.hpp: Split h into h_qk, h_v, and heads_per_group; document Q/K vs V/g/beta/A layouts and strides for GVA.
- tile_scheduler.hpp: Enumerate tiles by v-head; decode_tile_coord returns (batch_idx, v_head_idx, seq_idx, qk_head_idx) with qk_head_idx = v_head_idx / heads_per_group.
- kda_fwd_intra_*: TMA/load uses h_qk strides for Q/K and h_v strides for V/g.
- kda_fwd_recomp_w_u_*: Same head-space split; intermediates (w, u, kg, qg, etc.) remain in HV space.
- kda_sm100.cu: Derive h_qk / h_v from q.size(2) and g.size(2) (or v.size(2)), validate h_v % h_qk == 0, and pass heads_per_group into the tile scheduler.

2. Python training orchestration (cula/kda/)
- chunk.py: chunk_kda entry asserts: q/k are [B, T, HQK, D]; v/g/beta use HV
- chunk_fwd.py: repeat_interleave on q before fwd_o when HV > HQK (host compat layer; fwd_o still expects a unified head dim)
- chunk_intra.py: B × HV; i_hqk = i_h // (HV // HQK); HQK strides for q/k/dq/dk, HV strides for g/beta/dA
- chunk_bwd.py: q/k from dAv; add HQK constexpr to wy_dqkg_fused and fix q/k/dq/dk pointer offsets

3. Tests
- tests/test_kda_gva_intra_sm100.py (new): GVA tests for SM100 intra / recomp
  - HV == HQK matches non-GVA
  - HV % HQK
- tests/test_kda.py: test_chunk_kda_gva and test_chunk_kda_gva_varlen, end-to-end forward/backward vs FLA reference (with q/k expanded to HV); dq/dk compared after summing over the group axis

4. Benchmarks
- benchmarks/utils.py: Add prepare_safe_gate_inputs_gva (q/k stay in HQK; v/g/beta in HV)
- bench_kda.py / bench_kda_fwd_bwd_e2e.py: Unified --hv flag (GVA when HV is a multiple of H), aligned with bench_kda_fused_fwd.py
- bench_kda_chunk_intra.py: GVA configs and comparisons

5. Misc
- cula/utils.py: Update get_kda_fused_fwd docstring (Blackwell fused prefill still NotImplementedError)

Out of scope (not changed in this PR)
- cula/kda/blackwell_fused_fwd.py / ops/kda_fully_fused_wip.py: Blackwell fused prefill still requires HQK == HV
- ops/fwd_o.py / ops/chunk_delta_h.py: No in-kernel native GVA; fwd_o relies on host-side q expansion in chunk_fwd
- csrc/kda/sm90/ in this branch; Hopper prefill GVA lives in the existing hopper_fused_fwd + kda_sm90.cu path

Design notes
Test plan
On an SM100 machine:
- tests/test_kda_gva_intra_sm100.py passes
- test_chunk_kda_gva / test_chunk_kda_gva_varlen pass
- HV == HQK) show no regression

pytest tests/test_kda_gva_intra_sm100.py -v
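The end-to-end tests compare against a reference that sees q/k expanded to HV heads, then reduce dq/dk over the group axis. A minimal sketch of that bookkeeping in plain Python, with lists standing in for the head axis of a tensor. The helper names are hypothetical and this is not the test code itself. The reduction is a sum because each shared q/k head feeds every v-head in its group, so by the chain rule its gradient is the sum of the gradients of its expanded copies.

```python
def repeat_interleave_heads(per_head, group):
    """Expand HQK entries to HV: [h0, h1] -> [h0, h0, h1, h1] for group == 2.

    This ordering matches qk_head_idx = v_head_idx // group.
    """
    return [x for x in per_head for _ in range(group)]


def sum_over_group(per_v_head, group):
    """Reduce HV per-head values to HQK by summing each consecutive group."""
    assert len(per_v_head) % group == 0
    return [
        sum(per_v_head[i:i + group])
        for i in range(0, len(per_v_head), group)
    ]


HQK, HV = 2, 4
group = HV // HQK
q_expanded = repeat_interleave_heads(["q0", "q1"], group)

# Illustrative per-v-head gradients of the expanded q (made-up values).
dq_expanded = [1.0, 2.0, 3.0, 4.0]
dq = sum_over_group(dq_expanded, group)
```

Here `q_expanded` has one entry per v-head and `dq` has one entry per q/k head, which is the shape a native GVA backward is expected to produce for dq/dk.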