feat: Add SM90 CuTe DSL bwd_dhu kernel#71
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the SM90 CuTe DSL implementation for the backward DHU kernel, supporting both non-varlen and varlen modes along with various gating and state layout options. The PR also includes a benchmark script and a correctness test suite. A review comment identifies a performance concern where the kernel compilation is cached on concrete problem dimensions, suggesting a refactor to use symbolic dimensions to better handle dynamic shapes and prevent excessive recompilation.
| @functools.lru_cache(maxsize=64) | ||
| def _compile_bwd_dhu_sm90( | ||
| B: int, | ||
| T: int, | ||
| N: int, | ||
| NT: int, | ||
| H: int, | ||
| K: int, | ||
| V: int, | ||
| is_varlen: bool, | ||
| use_g: bool, | ||
| use_gk: bool, | ||
| use_dht: bool, | ||
| use_dh0: bool, | ||
| use_exp2: bool, | ||
| transpose_state_layout: bool, | ||
| scale: float, | ||
| ): |
There was a problem hiding this comment.
The current implementation of _compile_bwd_dhu_sm90 is cached on problem dimensions (B, T, N, NT), which will trigger a new kernel compilation for every unique input shape. This leads to long warm-up times and high memory usage from the compilation cache, making the kernel inefficient for dynamic shapes.
To address this, I recommend refactoring to use symbolic dimensions for fake tensors during compilation, and passing the concrete problem dimensions at runtime. This is a common pattern in cutlass-python and is used in other parts of this repository (e.g., cula.ops.chunk_delta_h.ChunkDeltaRuleFwdH).
This would involve:
- Removing
B,T,N,NTfrom the_compile_bwd_dhu_sm90function signature and itslru_cachekey. - Using
cute.sym_int()to define symbolic dimensions for creating fake tensors inside_compile_bwd_dhu_sm90. - Modifying
ChunkDeltaRuleBwdDHUSm90to accept problem dimensions at runtime in its__call__orkernelmethod, rather than in__init__. - Updating
chunk_gated_delta_rule_bwd_dhu_sm90to pass the problem dimensions to the compiled kernel at runtime.
This change will allow a single compiled kernel to handle various input sizes, significantly improving performance and usability.
Kernel overviewThis kernel implements the SM90 CuTe DSL path for Each CTA owns one The recurrence is: High-level pseudocode:one CTA handles one (v_tile, batch/sequence, head). Warp rolesThe CTA uses 7 warps:
Each specialized warp runs its own reverse-chunk loop. The loops are synchronized through the load/store pipeline barriers rather than by sharing a single syntactic loop body. |
|
Hi, @KevinZeng08 Performance questionI ran the current SM90
For example, with My guess is that this shape only launches
Modes with Do you have suggestions for the next optimization direction?
If you have access to other Hopper machines, it would also be great to get additional performance numbers on H100/H200, since I currently only tested on my available H800 PCIE setup. |
📌 Description
This PR adds an SM90 CuTe DSL implementation of
chunk_gated_delta_rule_bwd_dhu.The new path supports the current SM90 target shape constraints (
K=V=128,BT=64,BV=64) and handles both non-varlen and packed varlen inputs. It implements the backward DHU recurrence, producesdh, optionaldh0, anddv2.The implementation uses static CTA scheduling and follows the existing
chunk_delta_hforward kernel style, with warp-specialized TMA/WGMMA pipelines for loads, compute, and stores.🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to cuLA! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
⚡ Performance
1(SHORT, not representative):
2:
3:
Reviewer Notes