feat: replace RMSNorm backward with persistent CuTile kernel #60

Open

aghilann wants to merge 4 commits into NVIDIA:main from aghilann:benchmark-rmsnorm-comparison

Conversation

aghilann commented Feb 15, 2026

Description

Replaces the old one-row-per-block RMSNorm backward kernel with a persistent grid-stride kernel that fuses dw accumulation into a compact (grid × TILE_N) partial-sum buffer instead of allocating an M×N temp buffer. I wrote this kernel and was able to exceed the performance of Liger's Triton kernels and to get quite close to the performance of Quack's CuteDSL kernel (which I'm assuming is near peak performance).
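
To make the dw fusion concrete, here is a rough PyTorch emulation of the accumulation scheme (illustrative only: `num_ctas` and the host-side loops are assumptions, not the kernel's actual identifiers, and the sketch uses the full row width N where the kernel tiles by TILE_N):

```python
import torch

def rms_bwd_dw_partial_sums(x, dy, rstd, num_ctas=132):
    """Emulates the persistent kernel's dw path on the host.

    Rather than materializing an (M, N) temp buffer of per-row dw
    contributions, each persistent block walks rows in a grid-stride
    loop and accumulates into its own row of a small (num_ctas, N)
    partial-sum buffer, which is reduced once at the end.
    """
    M, N = x.shape
    partial = torch.zeros(num_ctas, N, dtype=torch.float32, device=x.device)
    for cta in range(num_ctas):                 # persistent blocks
        for row in range(cta, M, num_ctas):     # grid-stride over rows
            x_hat = x[row].float() * rstd[row].float()
            partial[cta] += dy[row].float() * x_hat   # per-row dw contribution
    return partial.sum(dim=0)                   # final reduction -> dw of shape (N,)
```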

Changes:

  • Backward kernel: New persistent _rms_bwd with grid-stride loop and fused dw accumulation
  • Forward kernels: Unchanged (gather + static persistent), except persistent now also stores rstd so backward works from both modes
  • Reference fix: rms_norm_backward_torch now casts to fp32 upfront, matching kernel precision (see the sketch after this list)
  • Removed unused aliases and imports
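
For reviewers who want the math being checked, a minimal fp32 reference for RMSNorm backward looks like the following (a plain PyTorch sketch; the name and signature are illustrative and not necessarily identical to the repo's rms_norm_backward_torch):

```python
import torch

def rms_norm_backward_ref(x, w, dy, rstd):
    """Minimal fp32 reference: upcast everything before any products."""
    x32, w32, dy32 = x.float(), w.float(), dy.float()
    rstd32 = rstd.float().unsqueeze(-1)              # (M, 1)
    N = x.shape[-1]
    x_hat = x32 * rstd32                             # normalized input
    wdy = w32 * dy32
    dw = (dy32 * x_hat).sum(dim=0)                   # (N,)
    c = (wdy * x_hat).sum(dim=-1, keepdim=True) / N  # per-row correction term
    dx = rstd32 * (wdy - x_hat * c)                  # (M, N)
    return dx.to(x.dtype), dw.to(w.dtype)
```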

Backward kernel throughput (GB/s) — standalone, M=4096

| N | Old Version | New Version | PyTorch | Old→New | vs PyTorch |
|------:|------:|------:|------:|------:|------:|
| **bf16** | | | | | |
| 1024 | 1,534 | 3,020 | 482 | 2.0x | 6.3x |
| 2048 | 2,274 | 6,709 | 549 | 2.9x | 12.2x |
| 4096 | 2,823 | 6,100 | 538 | 2.2x | 11.3x |
| 8192 | 3,652 | 8,008 | 552 | 2.2x | 14.5x |
| 16384 | 3,762 | 4,135 | 573 | 1.1x | 7.2x |
| **fp16** | | | | | |
| 1024 | 1,561 | 3,197 | 483 | 2.0x | 6.6x |
| 2048 | 2,368 | 4,454 | 553 | 1.9x | 8.1x |
| 4096 | 2,933 | 6,131 | 540 | 2.1x | 11.4x |
| 8192 | 3,578 | 7,987 | 553 | 2.2x | 14.4x |
| 16384 | 4,020 | 4,332 | 574 | 1.1x | 7.5x |
| **fp32** | | | | | |
| 1024 | 2,268 | 4,940 | 955 | 2.2x | 5.2x |
| 2048 | 2,923 | 7,357 | 1,099 | 2.5x | 6.7x |
| 4096 | 3,901 | 7,140 | 1,030 | 1.8x | 6.9x |
| 8192 | 3,667 | 9,298 | 1,051 | 2.5x | 8.8x |
| 16384 | 2,634 | 6,943 | 1,090 | 2.6x | 6.4x |

CI Configuration

config:
  build: true
  # valid options are "ops" and "benchmark"
  test: ["ops", "benchmark"]

Checklist

  • Code formatted and imports sorted via repo specifications (./format.sh)
  • Documentation updated (if needed)
  • CI configuration reviewed

root added 3 commits February 14, 2026 20:10
…on benchmark

Forward kernels (gather + static persistent) remain unchanged except
the persistent kernel now also stores rstd so backward works from both modes.

Backward: replaced old one-row-per-block approach (M×N temp buffer) with
Bastile's grid-stride persistent kernel (grid × TILE_N partial sums for dw).

- Both forward modes now support backward (previously only gather did)
- Removed unused ConstInt/ConstFloat/PAD_ZERO aliases, import math, experimental_kernel
- Added bench_rmsnorm_tilegym_vs_bastile.py comparison benchmark
- All 8 correctness tests pass, benchmark numbers unchanged
…cision

The rms_norm_backward_torch reference was computing x*dy in bf16/fp16
before casting to fp32, losing precision. The CuTile kernel correctly
operates in fp32 throughout. Fixed reference to cast to fp32 upfront
so both agree. All 13 tests now pass (5 experimental backward + 8 fwd+bwd).
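
The rounding gap is easy to reproduce in isolation (illustrative snippet, not repo code):

```python
import torch

x  = torch.randn(1 << 14, dtype=torch.bfloat16)
dy = torch.randn(1 << 14, dtype=torch.bfloat16)
late  = (x * dy).float().sum()           # old reference: product rounded in bf16 first
early = (x.float() * dy.float()).sum()   # fixed reference / kernel: fp32 throughout
print((late - early).abs().item())       # small but nonzero gap from bf16 rounding
```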

copy-pr-bot bot commented Feb 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

aghilann changed the title from "feat: re-write RMSNorm backward kernel for greatly improved performance" to "feat: replace RMSNorm backward with persistent CuTile kernel" on Feb 15, 2026

aghilann commented Feb 15, 2026

Could I get a review when you get a chance, @hannahli-nv!


xjmxyt commented Feb 17, 2026

/ok to test 4d947d3
