[KDA] Add intra-card CP for chunk_delta_h forward in SM100 #70
Code Review
This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operation, adding specialized kernels for pre-scan and merge stages along with an auto-dispatch mechanism for long sequences. Feedback focuses on optimizing the pre-scan kernel by gating the key gate (gk) loading and decay logic when unused, which would prevent potential out-of-memory errors from large zero-tensor allocations and reduce memory traffic.
@gemini-code-assist review
Code Review
This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operator, featuring a three-stage pipeline (pre-scan, merge, and forward) to optimize long sequence processing. It includes specialized CuTeDSL kernels for Blackwell architectures, an auto-dispatch heuristic, and comprehensive benchmarks and tests. Reviewers identified a potential cache key collision in multi-GPU setups and recommended using torch.cuda.current_device() for safer indexing. Additionally, a function-level import of the math module should be moved to the top of the file to adhere to PEP 8 standards.
Code Review
This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operator, introducing a three-stage pipeline (pre-scan, merge, and forward H) to parallelize the processing of long sequences. The implementation includes specialized CuTeDSL kernels for fused pre-scan and prefix-scan merge operations, along with an auto-dispatch heuristic that triggers CP based on sequence length and SM occupancy. A new benchmark script and comprehensive accuracy tests are also provided. The reviewer noted a potential performance bottleneck where the auto-dispatch logic might trigger a synchronous GPU-to-CPU transfer if cu_seqlens_cpu is not provided, suggesting that this behavior should be documented or the heuristic refined to avoid unnecessary synchronization.
@gemini-code-assist review
Code Review
This pull request introduces an intra-card context-parallel (CP) implementation for the chunk_delta_h operator, designed to accelerate the processing of long sequences on a single GPU. Key additions include a fused pre-scan kernel optimized for Blackwell architectures, a prefix-scan merge kernel, and an auto-dispatch heuristic that dynamically engages the CP path. The PR also includes a new benchmark suite and comprehensive accuracy tests. Reviewer feedback identifies opportunities to simplify the CP predicate API by removing unused parameters, optimize performance by moving local imports to the top level, and enhance robustness by replacing uninitialized dummy tensors with zero-initialized ones.
@gemini-code-assist review
Code Review
This pull request implements Intra-Card Context Parallel (CP) for the chunk_delta_h operation, optimizing long sequence processing through a three-stage pipeline consisting of pre-scan, prefix-scan merge, and forward H stages. The implementation leverages CuTeDSL for high-performance kernels and includes an auto-dispatch heuristic to balance CP overhead against potential speedups. Additionally, the PR adds a benchmark suite and extensive tests to ensure numerical consistency with the non-CP baseline and PyTorch references. Reviewer feedback highlights opportunities to improve CPU efficiency by optimizing tensor construction in _prepare_chunk_indices and streamlining the metadata caching logic in the merge step.
Replace _get_meta_tensors LRU cache with direct torch.tensor calls. The secondary cache required converting lists to tuples for hashing on every merge_fwd call. Since the metadata lists are tiny (1-4 int32 values), the hashing overhead outweighs any benefit. Each torch.tensor call produces an independently aligned allocation required by from_dlpack(assumed_align=16). A single merged tensor with slice views would not guarantee per-slice 16-byte alignment.
@gemini-code-assist review
Code Review
This pull request introduces Intra-Card Context Parallel (CP) for the chunk_delta_h operation, which improves performance for long sequences on a single GPU by parallelizing across the sequence dimension. The implementation includes new CuTeDSL kernels for pre-scan and merge stages, an auto-dispatch heuristic, and comprehensive benchmarking and testing tools. A review comment correctly identified that the global _intracard_cache is not thread-safe and recommended implementing a threading.Lock to prevent race conditions when the function is called concurrently.
| _intracard_cache: OrderedDict[tuple, _CacheEntry] = OrderedDict() | ||
| _INTRACARD_CACHE_MAXSIZE = 8 |
The global _intracard_cache is not thread-safe. OrderedDict operations are not atomic, which can lead to race conditions if intracard_fwd_h is called from multiple threads concurrently with different sequence layouts. This could corrupt the cache's internal state, leading to incorrect results or crashes.
To ensure thread safety, you should wrap all accesses and modifications to _intracard_cache with a threading.Lock.
For example:
import threading
_intracard_cache_lock = threading.Lock()
Then, wrap cache operations inside intracard_fwd_h:
# Reading from cache
with _intracard_cache_lock:
    cached = _intracard_cache.get(cache_key)
    if cached is not None:
        # ... cache validation and move_to_end
# Writing to cache
with _intracard_cache_lock:
    _intracard_cache[cache_key] = _CacheEntry(...)
    while len(_intracard_cache) > _INTRACARD_CACHE_MAXSIZE:
        _intracard_cache.popitem(last=False)
| ("32K+256+256", [32768, 256, 256]), | ||
| ("40K+1K+8K", [40960, 1024, 8192]), | ||
| ("64K+512+256+128", [65536, 512, 256, 128]), | ||
| ("128K+1K", [131072, 1024]), |
Just curious: what's the performance for cases like one 128K sequence plus several (say, 10) short 1K sequences?
| subseq_chunks = (seq_chunks + target_splits - 1) // target_splits | ||
| subseq_chunks = max(subseq_chunks, MIN_SUBSEQ_CHUNKS) | ||
|
|
||
| subseq_chunks = 2 ** round(math.log2(subseq_chunks)) |
Why should the chunk count be set to a power of 2?
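For readers following this thread, here is a minimal sketch of what the quoted lines compute. The helper name `subseq_chunks_for` is hypothetical, `MIN_SUBSEQ_CHUNKS = 16` is the value quoted in the PR description, and any code between the quoted lines in the real file is not reproduced.

```python
import math

MIN_SUBSEQ_CHUNKS = 16  # value quoted in the PR description


def subseq_chunks_for(seq_chunks: int, target_splits: int) -> int:
    """Hypothetical helper mirroring the three quoted lines."""
    # Ceil-divide the sequence's chunk count by the desired number of splits.
    subseq_chunks = (seq_chunks + target_splits - 1) // target_splits
    # Never go below the minimum sub-sequence length.
    subseq_chunks = max(subseq_chunks, MIN_SUBSEQ_CHUNKS)
    # Snap to the nearest power of two (nearest in log2 space).
    return 2 ** round(math.log2(subseq_chunks))


print(subseq_chunks_for(2048, 19))  # 108 chunks per split -> 128
print(subseq_chunks_for(300, 8))    # 38 chunks per split  -> 32
print(subseq_chunks_for(100, 16))   # 7, clamped to 16     -> 16
```

Because the rounding goes to the nearest power of two in log space, the resulting number of splits can land above or below `target_splits`: in the second call, 300 chunks at 32 per sub-sequence gives 10 splits rather than the requested 8.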
| NT = (seq_len + BT - 1) // BT | ||
|
|
||
| # Grid-level dispatch: he mode vs m mode | ||
| is_he_mode = tile_idx < num_v_tiles |
As discussed, for the later sm90 work we may want to evaluate the alternative approach, i.e. fusing he & m in the same CTA so that most of the gmem I/O can be fused.
📌 Description
The serial bottleneck
chunk_gated_delta_rule_fwd_h runs a sequential chunk recurrence inside each sequence
(h_t = decay_t · h_{t-1} + k_t^T @ v_t), so within one sequence the work cannot
parallelize across chunks, only across the (NUM_V_BLOCKS, H × num_seqs) grid.
This becomes a bottleneck when two conditions hit at the same time:
1. H × num_seqs is small → the baseline grid doesn't fill the SMs. A single long sequence at H=4 occupies only 2 × 4 × 1 = 8 SM units on a 152-SM B200.
2. One sequence is much longer than the rest → its chunk recurrence dominates wall time, while short seqs finish early and leave SMs idle waiting on the one long chain.
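As a rough illustration of the recurrence quoted in this section, the NumPy sketch below runs it serially and then checks that a run of chunks can be summarized as an affine map of the incoming state. It assumes a per-key decay vector, so the accumulated decay is a diagonal [K,K] matrix; the real kernel's transition presumably carries more structure than that. All names and shapes here are illustrative and not taken from the PR.

```python
import numpy as np

rng = np.random.default_rng(0)
K, V, n_chunks = 4, 3, 8

# Illustrative per-chunk inputs: a per-key decay in (0.5, 1) and a k^T @ v update.
decay = rng.uniform(0.5, 1.0, size=(n_chunks, K))
kv = rng.standard_normal((n_chunks, K, V))


def scan(h0, lo, hi):
    """Sequentially apply h_t = decay_t * h_{t-1} + kv_t for chunks lo..hi-1."""
    h = h0.copy()
    for t in range(lo, hi):
        h = decay[t][:, None] * h + kv[t]
    return h


def exit_state_and_decay(lo, hi):
    """Summarize chunks lo..hi-1 as (he, m) with h_out = m @ h_in + he."""
    he = scan(np.zeros((K, V)), lo, hi)          # exit state from a zero start, [K, V]
    m = np.diag(np.prod(decay[lo:hi], axis=0))   # accumulated decay, [K, K]
    return he, m


h0 = rng.standard_normal((K, V))
h_serial = scan(h0, 0, n_chunks)                 # one long serial chain

# Two halves summarized independently, then stitched together in order.
mid = n_chunks // 2
he1, m1 = exit_state_and_decay(0, mid)
he2, m2 = exit_state_and_decay(mid, n_chunks)
h_split = m2 @ (m1 @ h0 + he1) + he2

print(np.allclose(h_serial, h_split))
```

The two calls to `exit_state_and_decay` do not depend on each other, so they could run concurrently; only the final stitching step has to respect sequence order.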
Approach
Inspired by FLA's intracard-CP design (README),
this PR splits long sequences into sub-sequences on the same card and parallelizes the
recurrence across them via a 3-stage pipeline:
1. Pre-scan → (he, m): exit state [K,V] + decay matrix [K,K]
2. Merge: prefix-scan merge
3. Forward H: chunk_gated_delta_rule_fwd_h on the split sub-seqs
Activation strategy
A lightweight CPU-side predicate (should_use_intracard_cp) decides up-front whether
to dispatch to CP, so workloads that don't benefit pay zero overhead. Four guards:
1. 2·H·num_seqs ≥ SM
2. max_chunks < 256 (equivalently, < 16K tokens)
3. Be·H > 10, where Be = Σchunks / max(chunks) is the length-weighted effective batch size (Be → 1 means one dominant seq, Be → num_seqs means balanced)
4. total_subseqs · 2 · H > SM → fall back to baseline
Summary: CP fires when H × num_seqs is small && Be is close to 1, i.e.,
the batch has a long-tail seq serializing the card.
The thresholds (MIN_SUBSEQ_CHUNKS=16, MIN_LONG_SEQ_CHUNKS=256, MAX_BE_H=10, NUM_V_BLOCKS=2) are manually tuned on B200 sm100 based on bench sweeps. They are isolated at the top of cula/ops/cp/chunk_delta_h.py.
🔍 Related Issues
Closes #20
🧪 Tests
⚡ Performance