Skip to content

[KDA] Add intra-card CP for chunk_delta_h forward in SM100#70

Open
cherhh wants to merge 10 commits into
inclusionAI:mainfrom
cherhh:dev.kcp.v2
Open

[KDA] Add intra-card CP for chunk_delta_h forward in SM100#70
cherhh wants to merge 10 commits into
inclusionAI:mainfrom
cherhh:dev.kcp.v2

Conversation

@cherhh
Copy link
Copy Markdown
Collaborator

@cherhh cherhh commented May 17, 2026

📌 Description

The serial bottleneck

chunk_gated_delta_rule_fwd_h runs a sequential chunk recurrence inside each sequence
(h_t = decay_t · h_{t-1} + k_t^T @ v_t), so within one sequence the work cannot
parallelize across chunks — only across the (NUM_V_BLOCKS, H × num_seqs) grid.

This becomes a bottleneck when two conditions hit at the same time:

  • H × num_seqs is small → the baseline grid doesn't fill the SMs. A single
    long sequence at H=4 occupies only 2 × 4 × 1 = 8 SM units on a 152-SM B200.
  • the varlen batch is highly uneven with a long-tail seq → the long seq's serial
    chunk recurrence dominates wall time, while short seqs finish early and leave SMs
    idle waiting on the one long chain.

Approach

Inspired by FLA's intracard-CP design (README),
this PR splits long sequences into sub-sequences on the same card and parallelizes the
recurrence across them via a 3-stage pipeline:

  1. Pre-scan — per sub-seq, compute packed (he, m): exit state [K,V] + decay matrix [K,K]
  2. Merge — prefix scan across sub-seqs of the same original seq → init states for non-first sub-seqs
  3. Forward H — run the existing chunk_gated_delta_rule_fwd_h on the split sub-seqs

Activation strategy

A lightweight CPU-side predicate (should_use_intracard_cp) decides up-front whether
to dispatch to CP — workloads that don't benefit pay zero overhead. Four guards:

  • Guard 0: baseline already saturates SMs (2·H·num_seqs ≥ SM)
  • Guard 1: longest seq too short to amortize CP overhead (max_chunks < 256 ↔ <16K tokens)
  • Guard 2: existing parallelism already high (Be·H > 10, where Be = Σchunks / max(chunks)
    is the length-weighted effective batch size — Be → 1 means one dominant seq,
    Be → num_seqs means balanced)
  • Guard 3 (post-split): total_subseqs · 2 · H > SM → fall back to baseline

Summary: CP fires when H × num_seqs is small && Be is close to 1 — i.e.,
the batch has a long-tail seq serializing the card.

The thresholds (MIN_SUBSEQ_CHUNKS=16, MIN_LONG_SEQ_CHUNKS=256, MAX_BE_H=10,
NUM_V_BLOCKS=2) are manually tuned on B200 sm100 based on bench sweeps. They are
isolated at the top of cula/ops/cp/chunk_delta_h.py.

🔍 Related Issues

Closes #20

🧪 Tests

python -m pytest tests/test_intracard_cp.py -v
platform linux -- Python 3.12.3, pytest-9.0.2
collected 29 items

tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens0-4-False] PASSED  [  3%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens1-4-True] PASSED   [  6%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens2-4-True] PASSED   [ 10%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens3-8-True] PASSED   [ 13%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens4-4-True] PASSED   [ 17%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens5-4-False] PASSED  [ 20%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens6-4-True] PASSED   [ 24%]
tests/test_intracard_cp.py::test_cp_autodispatch_matches_baseline[seq_lens7-8-True] PASSED   [ 27%]
tests/test_intracard_cp.py::test_cp_autodispatch_with_h0[seq_lens0-4] PASSED                 [ 31%]
tests/test_intracard_cp.py::test_cp_autodispatch_with_h0[seq_lens1-4] PASSED                 [ 34%]
tests/test_intracard_cp.py::test_cp_autodispatch_vs_fla[32768-4] PASSED                      [ 37%]
tests/test_intracard_cp.py::test_cp_autodispatch_vs_fla[65536-4] PASSED                      [ 41%]
tests/test_intracard_cp.py::test_cp_autodispatch_vs_fla[32768-8] PASSED                      [ 44%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens0-4-False-False] PASSED [ 48%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens1-4-True-True] PASSED   [ 51%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens2-4-True-True] PASSED   [ 55%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens3-4-True-False] PASSED  [ 58%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens4-4-False-True] PASSED  [ 62%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens5-4-True-True] PASSED   [ 65%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens6-4-True-False] PASSED  [ 68%]
tests/test_intracard_cp.py::test_intracard_cp_vs_pytorch_ref[seq_lens7-8-True-True] PASSED   [ 72%]
tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens0-4-False-False] PASSED [ 75%]
tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens1-4-True-True] PASSED  [ 79%]
tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens2-8-True-True] PASSED  [ 82%]
tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens3-4-True-True] PASSED  [ 86%]
tests/test_intracard_cp.py::test_intracard_cp_final_state_per_seq[seq_lens4-4-True-False] PASSED [ 89%]
tests/test_intracard_cp.py::test_intracard_cp_stress_repeat[single-32K-H4-gk-h0] PASSED     [ 93%]
tests/test_intracard_cp.py::test_intracard_cp_stress_repeat[multi-32K+4K-H4-gk-h0] PASSED   [ 96%]
tests/test_intracard_cp.py::test_intracard_cp_h0_none_equiv_h0_zeros PASSED                  [100%]

29 passed in 33.52s

⚡ Performance

python benchmarks/bench_intracard_cp.py
Device: NVIDIA GB200 (SM=152)
Bench : warmup=10, n_iters=100

--- H=4 ---
config                         T pred  sub    CP_off     CP_on   speedup
------------------------------------------------------------------------
T=4K                        4096    N     0     0.372     0.370     1.01x
T=8K                        8192    N     0     0.486     0.483     1.01x
T=32K                      32768    Y    16     1.315     0.924     1.42x
T=64K                      65536    Y    16     2.462     1.334     1.85x
T=128K                    131072    Y    16     4.769     2.196     2.17x
8x4K                       32768    N     0     0.549     0.550     1.00x
4x8K                       32768    N     0     0.658     0.658     1.00x
2x16K                      32768    Y    32     0.876     0.879     1.00x
16K+16K                    32768    Y    32     0.875     0.877     1.00x
24K+8K                     32768    Y    32     1.095     1.096     1.00x
28K+4K                     32768    Y    15     1.202     1.008     1.19x
32K+256+256                33280    Y    18     1.315     0.919     1.43x
40K+1K+8K                  50176    Y    25     1.676     1.676     1.00x
64K+512+256+128            66432    Y    19     2.475     1.342     1.84x
128K+1K                   132096    Y    17     4.783     2.201     2.17x

--- H=8 ---
config                         T pred  sub    CP_off     CP_on   speedup
------------------------------------------------------------------------
T=4K                        4096    N     0     0.372     0.377     0.99x
T=8K                        8192    N     0     0.518     0.516     1.00x
T=32K                      32768    Y     8     1.575     1.355     1.16x
T=64K                      65536    Y     8     3.004     2.203     1.36x
T=128K                    131072    Y     8     5.851     3.900     1.50x
8x4K                       32768    N     0     0.816     0.815     1.00x
4x8K                       32768    N     0     0.925     0.924     1.00x
2x16K                      32768    N     0     1.143     1.142     1.00x
16K+16K                    32768    N     0     1.143     1.142     1.00x
24K+8K                     32768    N     0     1.362     1.359     1.00x
28K+4K                     32768    Y     8     1.471     1.358     1.08x
32K+256+256                33280    Y    10     1.597     1.595     1.00x
40K+1K+8K                  50176    Y    12     2.094     2.093     1.00x
64K+512+256+128            66432    Y    11     3.024     3.025     1.00x
128K+1K                   132096    Y     9     5.869     3.917     1.50x

CP triggered (19 configs): geo-mean=1.30x  best=2.17x  worst=1.00x
CP bypassed  (11 configs): mean overhead=0.999x  max=1.013x  (1.00 = no regression)

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operation, adding specialized kernels for pre-scan and merge stages along with an auto-dispatch mechanism for long sequences. Feedback focuses on optimizing the pre-scan kernel by gating the key gate (gk) loading and decay logic when unused, which would prevent potential out-of-memory errors from large zero-tensor allocations and reduce memory traffic.

Comment thread cula/ops/cp/pre_scan.py Outdated
Comment thread cula/ops/cp/pre_scan.py Outdated
Comment thread cula/ops/cp/pre_scan.py Outdated
@cherhh cherhh changed the title Dev.kcp.v2 [KDA] Add intra-card CP for chunk_delta_h forward in SM100 May 17, 2026
@cherhh cherhh requested review from KevinZeng08 and icavan May 17, 2026 17:16
@cherhh
Copy link
Copy Markdown
Collaborator Author

cherhh commented May 18, 2026

@gemini-code-assist review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operator, featuring a three-stage pipeline (pre-scan, merge, and forward) to optimize long sequence processing. It includes specialized CuTeDSL kernels for Blackwell architectures, an auto-dispatch heuristic, and comprehensive benchmarks and tests. Reviewers identified a potential cache key collision in multi-GPU setups and recommended using torch.cuda.current_device() for safer indexing. Additionally, a function-level import of the math module should be moved to the top of the file to adhere to PEP 8 standards.

Comment thread cula/ops/cp/merge.py Outdated
Comment thread cula/ops/cp/chunk_delta_h.py Outdated
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operator, introducing a three-stage pipeline (pre-scan, merge, and forward H) to parallelize the processing of long sequences. The implementation includes specialized CuTeDSL kernels for fused pre-scan and prefix-scan merge operations, along with an auto-dispatch heuristic that triggers CP based on sequence length and SM occupancy. A new benchmark script and comprehensive accuracy tests are also provided. The reviewer noted a potential performance bottleneck where the auto-dispatch logic might trigger a synchronous GPU-to-CPU transfer if cu_seqlens_cpu is not provided, suggesting that this behavior should be documented or the heuristic refined to avoid unnecessary synchronization.

Comment thread cula/ops/chunk_delta_h.py Outdated
@cherhh
Copy link
Copy Markdown
Collaborator Author

cherhh commented May 18, 2026

@gemini-code-assist review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an intra-card context-parallel (CP) implementation for the chunk_delta_h operator, designed to accelerate the processing of long sequences on a single GPU. Key additions include a fused pre-scan kernel optimized for Blackwell architectures, a prefix-scan merge kernel, and an auto-dispatch heuristic that dynamically engages the CP path. The PR also includes a new benchmark suite and comprehensive accuracy tests. Reviewer feedback identifies opportunities to simplify the CP predicate API by removing unused parameters, optimize performance by moving local imports to the top level, and enhance robustness by replacing uninitialized dummy tensors with zero-initialized ones.

Comment thread cula/ops/cp/chunk_delta_h.py Outdated
Comment thread cula/ops/cp/chunk_delta_h.py Outdated
Comment thread cula/ops/cp/merge.py Outdated
Comment thread cula/ops/cp/pre_scan.py Outdated
cherhh and others added 3 commits May 18, 2026 11:40
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@cherhh
Copy link
Copy Markdown
Collaborator Author

cherhh commented May 18, 2026

@gemini-code-assist review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operation, optimizing long sequence processing through a three-stage pipeline consisting of pre-scan, prefix-scan merge, and forward H stages. The implementation leverages CuTeDSL for high-performance kernels and includes an auto-dispatch heuristic to balance CP overhead against potential speedups. Additionally, the PR adds a benchmark suite and extensive tests to ensure numerical consistency with the non-CP baseline and PyTorch references. Reviewer feedback highlights opportunities to improve CPU efficiency by optimizing tensor construction in _prepare_chunk_indices and streamlining the metadata caching logic in the merge step.

Comment thread cula/ops/cp/chunk_delta_h.py Outdated
Comment thread cula/ops/cp/merge.py Outdated
cherhh and others added 3 commits May 18, 2026 12:55
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Replace _get_meta_tensors LRU cache with direct torch.tensor calls.
The secondary cache required converting lists to tuples for hashing
on every merge_fwd call. Since the metadata lists are tiny (1-4 int32
values), the hashing overhead outweighs any benefit.

Each torch.tensor call produces an independently aligned allocation
required by from_dlpack(assumed_align=16). A single merged tensor with
slice views would not guarantee per-slice 16-byte alignment.
@cherhh
Copy link
Copy Markdown
Collaborator Author

cherhh commented May 18, 2026

@gemini-code-assist review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces Intra-Card Context Parallel (CP) for the chunk_delta_h operation, which improves performance for long sequences on a single GPU by parallelizing across the sequence dimension. The implementation includes new CuTeDSL kernels for pre-scan and merge stages, an auto-dispatch heuristic, and comprehensive benchmarking and testing tools. A review comment correctly identified that the global _intracard_cache is not thread-safe and recommended implementing a threading.Lock to prevent race conditions when the function is called concurrently.

Comment on lines +82 to +83
_intracard_cache: OrderedDict[tuple, _CacheEntry] = OrderedDict()
_INTRACARD_CACHE_MAXSIZE = 8
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The global _intracard_cache is not thread-safe. OrderedDict operations are not atomic, which can lead to race conditions if intracard_fwd_h is called from multiple threads concurrently with different sequence layouts. This could corrupt the cache's internal state, leading to incorrect results or crashes.

To ensure thread safety, you should wrap all accesses and modifications to _intracard_cache with a threading.Lock.

For example:

import threading

_intracard_cache_lock = threading.Lock()

Then, wrap cache operations inside intracard_fwd_h:

# Reading from cache
with _intracard_cache_lock:
    cached = _intracard_cache.get(cache_key)
    if cached is not None:
        # ... cache validation and move_to_end

# Writing to cache
with _intracard_cache_lock:
    _intracard_cache[cache_key] = _CacheEntry(...)
    while len(_intracard_cache) > _INTRACARD_CACHE_MAXSIZE:
        _intracard_cache.popitem(last=False)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cherhh could you check this

("32K+256+256", [32768, 256, 256]),
("40K+1K+8K", [40960, 1024, 8192]),
("64K+512+256+128", [65536, 512, 256, 128]),
("128K+1K", [131072, 1024]),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just be curious, what's the performance of more cases like one 128K + several, say, 10 1K short seqs?

Comment on lines +82 to +83
_intracard_cache: OrderedDict[tuple, _CacheEntry] = OrderedDict()
_INTRACARD_CACHE_MAXSIZE = 8
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cherhh could you check this

subseq_chunks = (seq_chunks + target_splits - 1) // target_splits
subseq_chunks = max(subseq_chunks, MIN_SUBSEQ_CHUNKS)

subseq_chunks = 2 ** round(math.log2(subseq_chunks))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should the chunk count be set to a power of 2?

Comment thread cula/ops/cp/pre_scan.py
NT = (seq_len + BT - 1) // BT

# Grid-level dispatch: he mode vs m mode
is_he_mode = tile_idx < num_v_tiles
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, we might need to check the alternative method in the later sm90 work, i.e. fuse he & m in the same CTA so that we could fuse most of the gmem io.

Copy link
Copy Markdown
Collaborator

@icavan icavan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Identify and implement CUDA optimization opportunities for Intracard CP (single-card sequence splitting)

2 participants