feat: replace RMSNorm backward with persistent CuTile kernel #60
Open
aghilann wants to merge 4 commits into NVIDIA:main from
Conversation
added 3 commits on February 14, 2026 at 20:10
…on benchmark

Forward kernels (gather + static persistent) remain unchanged, except the persistent kernel now also stores rstd so backward works from both modes.

Backward: replaced the old one-row-per-block approach (M×N temp buffer) with Bastile's grid-stride persistent kernel (grid × TILE_N partial sums for dw).

- Both forward modes now support backward (previously only gather did)
- Removed unused ConstInt/ConstFloat/PAD_ZERO aliases, import math, and experimental_kernel
- Added bench_rmsnorm_tilegym_vs_bastile.py comparison benchmark
- All 8 correctness tests pass; benchmark numbers unchanged
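For intuition, here is a hedged pure-PyTorch sketch of the grid × TILE_N partial-sum pattern this commit describes: each persistent block grid-strides over rows, accumulating a private fp32 partial for dw, and one small reduction over the grid axis replaces the old M×N temp buffer. The function name and `num_blocks` are illustrative assumptions, not names from the PR.

```python
import torch

def dw_two_stage_sketch(dy, x, rstd, num_blocks=132):
    """Illustrative only: the grid-stride partial-sum pattern for dw.

    dy, x: (M, N); rstd: (M,). num_blocks stands in for the launch
    grid size (an assumption, not the PR's actual value).
    """
    M, N = x.shape
    # Stage 1: each "persistent block" accumulates its own fp32 row of
    # partial sums -- a (num_blocks, N) buffer instead of (M, N).
    partial = torch.zeros(num_blocks, N, dtype=torch.float32, device=x.device)
    for blk in range(num_blocks):
        for row in range(blk, M, num_blocks):   # grid-stride loop over rows
            partial[blk] += dy[row].float() * x[row].float() * rstd[row].float()
    # Stage 2: one small reduction over the grid axis yields dw (N,).
    return partial.sum(dim=0)
```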
…cision

The rms_norm_backward_torch reference was computing x*dy in bf16/fp16 before casting to fp32, losing precision. The CuTile kernel correctly operates in fp32 throughout. Fixed the reference to cast to fp32 upfront so both agree.

All 13 tests now pass (5 experimental backward + 8 fwd+bwd).
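A minimal sketch of the fixed reference, assuming the standard RMSNorm backward formulas (dx = rstd · (w·dy − x̂ · mean(w·dy·x̂)), dw = Σ_rows dy·x̂ with x̂ = x·rstd). The function name and signature here are illustrative; the repo's rms_norm_backward_torch may differ in shape conventions.

```python
import torch

def rms_norm_backward_fp32_ref(dy, x, w, rstd):
    """Sketch of the fixed reference: cast to fp32 *before* any
    product, so the x*dy term is never formed in bf16/fp16.

    dy, x: (M, N); w: (N,); rstd: (M,) = 1/sqrt(mean(x**2) + eps).
    """
    dy32, x32, w32 = dy.float(), x.float(), w.float()
    r = rstd.float().unsqueeze(1)                 # (M, 1)
    xhat = x32 * r                                # normalized input
    dw = (dy32 * xhat).sum(dim=0)                 # (N,) weight grad
    wdy = dy32 * w32
    dx = r * (wdy - xhat * (wdy * xhat).mean(dim=1, keepdim=True))
    return dx.to(x.dtype), dw.to(w.dtype)
```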
Contributor (Author)
Could I get a review when you get a chance, @hannahli-nv!
Collaborator
/ok to test 4d947d3
Description
Replaces the old one-row-per-block RMSNorm backward kernel with a persistent grid-stride kernel that fuses dw accumulation into a compact (grid × TILE_N) partial-sum buffer instead of allocating an M×N temp buffer. I wrote this kernel and was able to exceed the performance of Liger's Triton kernels and get quite close to Quack's CuteDSL kernel (which I'm somewhat assuming is near peak performance).

Changes:
- _rms_bwd rewritten as a grid-stride loop with fused dw accumulation
- The persistent forward kernel now also stores rstd, so backward works from both modes
- rms_norm_backward_torch now casts to fp32 upfront, matching kernel precision

Backward kernel throughput (GB/s) — standalone, M=4096
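The throughput table itself did not survive extraction. As a hedged sketch of how a standalone GB/s number like the one captioned above might be measured (this is not the PR's bench_rmsnorm_tilegym_vs_bastile.py; `fn` and the byte accounting are assumptions):

```python
import torch

def bwd_gbs(fn, dy, x, w, rstd, iters=100):
    """Illustrative GB/s measurement for a standalone backward call.

    Byte accounting (read dy/x/rstd/w, write dx/dw) is an assumption;
    intermediate partial-sum traffic is ignored.
    """
    fn(dy, x, w, rstd)                               # warm-up
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(dy, x, w, rstd)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    nbytes = (3 * x.numel() * x.element_size()       # dy + x read, dx write
              + 2 * w.numel() * w.element_size()     # w read, dw write
              + rstd.numel() * rstd.element_size())  # rstd read
    return nbytes / (ms * 1e-3) / 1e9
```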
CI Configuration
Checklist
- The code is formatted (./format.sh)