Skip to content

feat: Add SM90 CuTe DSL bwd_dhu kernel#71

Open
yechenzhi wants to merge 24 commits into
inclusionAI:mainfrom
yechenzhi:feat_bwd_delta
Open

feat: Add SM90 CuTe DSL bwd_dhu kernel#71
yechenzhi wants to merge 24 commits into
inclusionAI:mainfrom
yechenzhi:feat_bwd_delta

Conversation

@yechenzhi
Copy link
Copy Markdown
Contributor

📌 Description

This PR adds an SM90 CuTe DSL implementation of chunk_gated_delta_rule_bwd_dhu.

The new path supports the current SM90 target shape constraints (K=V=128, BT=64, BV=64) and handles both non-varlen and packed varlen inputs. It implements the backward DHU recurrence, produces dh, optional dh0, and dv2.

The implementation uses static CTA scheduling and follows the existing chunk_delta_h forward kernel style, with warp-specialized TMA/WGMMA pipelines for loads, compute, and stores.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to cuLA! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing.

⚡ Performance

1(SHORT, not representative):

========================================================================================================================
                     BENCHMARK REPORT: chunk_delta_rule_bwd_dhu
                     CuTe DSL (Hopper SM90) vs FLA Triton
                     K=128  V=128  BT=64  dtype=bf16
                     Warmup=10  Iters=100
========================================================================================================================

  [Non-Varlen]
  ------------------------------------------------------------------------------------------------------------------------------------
  Config                                        |    max_abs    max_rel |   FLA(ms)  CuTe(ms)  Speedup Compiled
  ------------------------------------------------------------------------------------------------------------------------------------
  B= 1 T=  512 H=  4 mode=- [gk,dht]            |   0.000000  0.000e+00 |    0.0731    0.0253    2.89x      yes
  B= 1 T=  512 H=  4 mode=- [g,dht]             |   0.003906  5.208e-03 |    0.0728    0.0257    2.84x      yes
  B= 2 T= 1024 H= 64 mode=- [gk,dht,dh0]        |   0.000000  0.000e+00 |    0.1935    0.1774    1.09x      yes
  B= 1 T= 2048 H= 64 mode=- [gk,dht]            |   0.000000  0.000e+00 |    0.2167    0.1925    1.13x      yes
  ------------------------------------------------------------------------------------------------------------------------------------
  Geometric mean                                |                       |                        1.78x         

  [Varlen]
  ------------------------------------------------------------------------------------------------------------------------
                                                        Config |   max_diff    mean_diff |   FLA(ms)  CuTe(ms)  Speedup
  ------------------------------------------------------------------------------------------------------------------------
                 3seqs T=512 [117..221] avg=170 H=  2 [gk,dht] |   0.000000   0.00000000 |    0.0826    0.0236    3.49x
              4seqs T=768 [124..274] avg=192 H=  2 [g,dht,dh0] |   0.002436   0.00018821 |    0.0875    0.0234    3.74x
  ------------------------------------------------------------------------------------------------------------------------
                                                Geometric mean |                         |                        3.62x

========================================================================================================================

2:

========================================================================================================================
                     BENCHMARK REPORT: chunk_delta_rule_bwd_dhu
                     CuTe DSL (Hopper SM90) vs FLA Triton
                     K=128  V=128  BT=64  dtype=bf16
                     Warmup=10  Iters=100
========================================================================================================================

  [Non-Varlen]
  ------------------------------------------------------------------------------------------------------------------------------------
  Config                                        |    max_abs    max_rel |   FLA(ms)  CuTe(ms)  Speedup Compiled
  ------------------------------------------------------------------------------------------------------------------------------------
  B= 1 T= 8192 H= 64 mode=- [gk,dht,dh0]        |   0.000000  0.000e+00 |    0.7940    0.7054    1.13x      yes
  B= 2 T= 8192 H= 64 mode=- [gk,dht,dh0]        |   0.000000  0.000e+00 |    1.3946    1.3734    1.02x      yes
  B= 4 T= 8192 H= 64 mode=- [gk,dht,dh0]        |   0.000000  0.000e+00 |    2.7080    2.5960    1.04x      yes
  B= 8 T= 8192 H= 64 mode=- [gk,dht,dh0]        |   0.000000  0.000e+00 |    5.3910    5.1033    1.06x      yes
  ------------------------------------------------------------------------------------------------------------------------------------
  Geometric mean                                |                       |                        1.06x         

  [Varlen]
  ------------------------------------------------------------------------------------------------------------------------
                                                        Config |   max_diff    mean_diff |   FLA(ms)  CuTe(ms)  Speedup
  ------------------------------------------------------------------------------------------------------------------------
           20seqs T=8192 [296..571] avg=409 H= 64 [gk,dht,dh0] |   0.000000   0.00000000 |    0.8024    0.7463    1.08x
           25seqs T=8192 [197..558] avg=327 H= 64 [gk,dht,dh0] |   0.000000   0.00000000 |    0.8341    0.7716    1.08x
           20seqs T=8192 [205..762] avg=409 H= 64 [gk,dht,dh0] |   0.000000   0.00000000 |    0.8011    0.7446    1.08x
       20seqs T=32768 [1185..2286] avg=1638 H= 64 [gk,dht,dh0] |   0.000000   0.00000000 |    2.7649    2.7283    1.01x
        25seqs T=32768 [788..2233] avg=1310 H= 64 [gk,dht,dh0] |   0.000000   0.00000000 |    2.8013    2.7577    1.02x
  ------------------------------------------------------------------------------------------------------------------------
                                                Geometric mean |                         |                        1.05x

========================================================================================================================

3:

========================================================================================================================
                     BENCHMARK REPORT: chunk_delta_rule_bwd_dhu
                     CuTe DSL (Hopper SM90) vs FLA Triton
                     K=128  V=128  BT=64  dtype=bf16
                     Warmup=10  Iters=100
========================================================================================================================

  [Non-Varlen]
  ------------------------------------------------------------------------------------------------------------------------------------
  Config                                        |    max_abs    max_rel |   FLA(ms)  CuTe(ms)  Speedup Compiled
  ------------------------------------------------------------------------------------------------------------------------------------
  B= 1 T= 2048 H= 16 mode=A                     |   0.000000  0.000e+00 |    0.0715    0.0738    0.97x      yes
  B= 1 T= 2048 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.0762    0.0737    1.03x      yes
  B= 1 T= 2048 H= 16 mode=C [g]                 |   0.003906  4.854e-03 |    0.0888    0.1109    0.80x      yes
  B= 1 T= 2048 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    0.0730    0.0809    0.90x      yes
  B= 1 T= 2048 H= 16 mode=E [g,gk]              |   0.003906  4.926e-03 |    0.1027    0.1128    0.91x      yes
  B= 1 T= 4096 H= 16 mode=A                     |   0.000000  0.000e+00 |    0.1041    0.1425    0.73x      yes
  B= 1 T= 4096 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.1067    0.1370    0.78x      yes
  B= 1 T= 4096 H= 16 mode=C [g]                 |   0.003937  4.484e-03 |    0.1706    0.2173    0.79x      yes
  B= 1 T= 4096 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    0.1271    0.1551    0.82x      yes
  B= 1 T= 4096 H= 16 mode=E [g,gk]              |   0.003906  4.630e-03 |    0.1957    0.2195    0.89x      yes
  B= 1 T= 8192 H= 16 mode=A                     |   0.000000  0.000e+00 |    0.2024    0.2780    0.73x      yes
  B= 1 T= 8192 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.2065    0.2576    0.80x      yes
  B= 1 T= 8192 H= 16 mode=C [g]                 |   0.003906  4.762e-03 |    0.3346    0.4293    0.78x      yes
  B= 1 T= 8192 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    0.2457    0.3017    0.81x      yes
  B= 1 T= 8192 H= 16 mode=E [g,gk]              |   0.003906  4.762e-03 |    0.3832    0.4261    0.90x      yes
  B= 1 T=16384 H= 16 mode=A                     |   0.000000  0.000e+00 |    0.4013    0.5330    0.75x      yes
  B= 1 T=16384 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.4031    0.4938    0.82x      yes
  B= 1 T=16384 H= 16 mode=C [g]                 |   0.003906  4.739e-03 |    0.6558    0.8615    0.76x      yes
  B= 1 T=16384 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    0.4749    0.5994    0.79x      yes
  B= 1 T=16384 H= 16 mode=E [g,gk]              |   0.003906  4.878e-03 |    0.7593    0.8481    0.90x      yes
  B= 1 T= 2048 H= 32 mode=A                     |   0.000000  0.000e+00 |    0.0862    0.0840    1.03x      yes
  B= 1 T= 2048 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.1069    0.0872    1.23x      yes
  B= 1 T= 2048 H= 32 mode=C [g]                 |   0.004883  5.061e-03 |    0.1295    0.1439    0.90x      yes
  B= 1 T= 2048 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    0.1108    0.0862    1.28x      yes
  B= 1 T= 2048 H= 32 mode=E [g,gk]              |   0.003906  4.630e-03 |    0.1397    0.1451    0.96x      yes
  B= 1 T= 4096 H= 32 mode=A                     |   0.000000  0.000e+00 |    0.1803    0.1627    1.11x      yes
  B= 1 T= 4096 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.2027    0.1650    1.23x      yes
  B= 1 T= 4096 H= 32 mode=C [g]                 |   0.005859  6.707e-03 |    0.2442    0.2813    0.87x      yes
  B= 1 T= 4096 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    0.2116    0.1683    1.26x      yes
  B= 1 T= 4096 H= 32 mode=E [g,gk]              |   0.004883  5.252e-03 |    0.2717    0.2852    0.95x      yes
  B= 1 T= 8192 H= 32 mode=A                     |   0.000000  0.000e+00 |    0.3275    0.3203    1.02x      yes
  B= 1 T= 8192 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.4024    0.3211    1.25x      yes
  B= 1 T= 8192 H= 32 mode=C [g]                 |   0.003906  4.167e-03 |    0.4813    0.5565    0.86x      yes
  B= 1 T= 8192 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    0.4079    0.3296    1.24x      yes
  B= 1 T= 8192 H= 32 mode=E [g,gk]              |   0.003967  4.184e-03 |    0.5375    0.5604    0.96x      yes
  B= 1 T=16384 H= 32 mode=A                     |   0.000000  0.000e+00 |    0.6455    0.6313    1.02x      yes
  B= 1 T=16384 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.7710    0.6303    1.22x      yes
  B= 1 T=16384 H= 32 mode=C [g]                 |   0.005859  5.515e-03 |    0.9364    1.1287    0.83x      yes
  B= 1 T=16384 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    0.7994    0.6478    1.23x      yes
  B= 1 T=16384 H= 32 mode=E [g,gk]              |   0.003906  4.505e-03 |    1.0477    1.1210    0.93x      yes
  B= 2 T= 2048 H= 16 mode=A                     |   0.000000  0.000e+00 |    0.1296    0.0837    1.55x      yes
  B= 2 T= 2048 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.1264    0.0871    1.45x      yes
  B= 2 T= 2048 H= 16 mode=C [g]                 |   0.004883  5.020e-03 |    0.1837    0.1347    1.36x      yes
  B= 2 T= 2048 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    0.1300    0.0864    1.50x      yes
  B= 2 T= 2048 H= 16 mode=E [g,gk]              |   0.003906  4.630e-03 |    0.1882    0.1386    1.36x      yes
  B= 2 T= 4096 H= 16 mode=A                     |   0.000000  0.000e+00 |    0.2366    0.1617    1.46x      yes
  B= 2 T= 4096 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.2406    0.1643    1.46x      yes
  B= 2 T= 4096 H= 16 mode=C [g]                 |   0.004028  4.911e-03 |    0.3459    0.2650    1.31x      yes
  B= 2 T= 4096 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    0.2525    0.1668    1.51x      yes
  B= 2 T= 4096 H= 16 mode=E [g,gk]              |   0.003906  4.762e-03 |    0.3604    0.2707    1.33x      yes
  B= 2 T= 8192 H= 16 mode=A                     |   0.000000  0.000e+00 |    0.4638    0.3183    1.46x      yes
  B= 2 T= 8192 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.4690    0.3188    1.47x      yes
  B= 2 T= 8192 H= 16 mode=C [g]                 |   0.004883  5.036e-03 |    0.6720    0.5187    1.30x      yes
  B= 2 T= 8192 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    0.5010    0.3294    1.52x      yes
  B= 2 T= 8192 H= 16 mode=E [g,gk]              |   0.003906  4.902e-03 |    0.7039    0.5302    1.33x      yes
  B= 2 T=16384 H= 16 mode=A                     |   0.000000  0.000e+00 |    0.9347    0.6254    1.49x      yes
  B= 2 T=16384 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.9357    0.7239    1.29x      yes
  B= 2 T=16384 H= 16 mode=C [g]                 |   0.003952  4.251e-03 |    1.3144    1.0568    1.24x      yes
  B= 2 T=16384 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    0.9872    0.6497    1.52x      yes
  B= 2 T=16384 H= 16 mode=E [g,gk]              |   0.003906  4.202e-03 |    1.3826    1.0738    1.29x      yes
  B= 2 T= 2048 H= 32 mode=A                     |   0.000000  0.000e+00 |    0.1830    0.1829    1.00x      yes
  B= 2 T= 2048 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.2234    0.1890    1.18x      yes
  B= 2 T= 2048 H= 32 mode=C [g]                 |   0.005859  6.356e-03 |    0.2551    0.2878    0.89x      yes
  B= 2 T= 2048 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    0.2237    0.1866    1.20x      yes
  B= 2 T= 2048 H= 32 mode=E [g,gk]              |   0.003906  4.950e-03 |    0.2524    0.2931    0.86x      yes
  B= 2 T= 4096 H= 32 mode=A                     |   0.000000  0.000e+00 |    0.3515    0.3454    1.02x      yes
  B= 2 T= 4096 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.4309    0.3495    1.23x      yes
  B= 2 T= 4096 H= 32 mode=C [g]                 |   0.005859  6.250e-03 |    0.4896    0.5606    0.87x      yes
  B= 2 T= 4096 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    0.4395    0.3589    1.22x      yes
  B= 2 T= 4096 H= 32 mode=E [g,gk]              |   0.004883  4.371e-03 |    0.4950    0.5771    0.86x      yes
  B= 2 T= 8192 H= 32 mode=A                     |   0.000000  0.000e+00 |    0.6742    0.6881    0.98x      yes
  B= 2 T= 8192 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.8662    0.6919    1.25x      yes
  B= 2 T= 8192 H= 32 mode=C [g]                 |   0.005859  5.515e-03 |    0.9650    1.1256    0.86x      yes
  B= 2 T= 8192 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    0.8729    0.6918    1.26x      yes
  B= 2 T= 8192 H= 32 mode=E [g,gk]              |   0.004883  4.596e-03 |    0.9756    1.1718    0.83x      yes
  B= 2 T=16384 H= 32 mode=A                     |   0.000000  0.000e+00 |    1.3294    1.3696    0.97x      yes
  B= 2 T=16384 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    1.7488    1.3776    1.27x      yes
  B= 2 T=16384 H= 32 mode=C [g]                 |   0.004517  4.167e-03 |    1.9099    2.2911    0.83x      yes
  B= 2 T=16384 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    1.7651    1.3777    1.28x      yes
  B= 2 T=16384 H= 32 mode=E [g,gk]              |   0.003906  4.167e-03 |    1.9474    2.3519    0.83x      yes
  B= 4 T= 2048 H= 16 mode=A                     |   0.000000  0.000e+00 |    0.1909    0.1800    1.06x      yes
  B= 4 T= 2048 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.1949    0.1874    1.04x      yes
  B= 4 T= 2048 H= 16 mode=C [g]                 |   0.004211  5.036e-03 |    0.2645    0.2763    0.96x      yes
  B= 4 T= 2048 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    0.2077    0.1854    1.12x      yes
  B= 4 T= 2048 H= 16 mode=E [g,gk]              |   0.005859  5.769e-03 |    0.2798    0.2801    1.00x      yes
  B= 4 T= 4096 H= 16 mode=A                     |   0.000000  0.000e+00 |    0.3742    0.3454    1.08x      yes
  B= 4 T= 4096 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.3807    0.3566    1.07x      yes
  B= 4 T= 4096 H= 16 mode=C [g]                 |   0.004395  4.630e-03 |    0.5073    0.5344    0.95x      yes
  B= 4 T= 4096 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    0.4077    0.3615    1.13x      yes
  B= 4 T= 4096 H= 16 mode=E [g,gk]              |   0.003906  4.902e-03 |    0.5418    0.5648    0.96x      yes
  B= 4 T= 8192 H= 16 mode=A                     |   0.000000  0.000e+00 |    0.7496    0.6848    1.09x      yes
  B= 4 T= 8192 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.7435    0.6907    1.08x      yes
  B= 4 T= 8192 H= 16 mode=C [g]                 |   0.005859  5.682e-03 |    0.9770    1.0851    0.90x      yes
  B= 4 T= 8192 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    0.8061    0.6882    1.17x      yes
  B= 4 T= 8192 H= 16 mode=E [g,gk]              |   0.004089  4.219e-03 |    1.0573    1.1215    0.94x      yes
  B= 4 T=16384 H= 16 mode=A                     |   0.000000  0.000e+00 |    1.4918    1.3633    1.09x      yes
  B= 4 T=16384 H= 16 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    1.4919    1.3692    1.09x      yes
  B= 4 T=16384 H= 16 mode=C [g]                 |   0.004150  4.167e-03 |    1.9358    2.1635    0.89x      yes
  B= 4 T=16384 H= 16 mode=D [gk]                |   0.000000  0.000e+00 |    1.6108    1.3651    1.18x      yes
  B= 4 T=16384 H= 16 mode=E [g,gk]              |   0.004150  4.167e-03 |    2.0890    2.2347    0.93x      yes
  B= 4 T= 2048 H= 32 mode=A                     |   0.000000  0.000e+00 |    0.3185    0.3315    0.96x      yes
  B= 4 T= 2048 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.3762    0.3421    1.10x      yes
  B= 4 T= 2048 H= 32 mode=C [g]                 |   0.004883  4.687e-03 |    0.3757    0.4690    0.80x      yes
  B= 4 T= 2048 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    0.3703    0.3373    1.10x      yes
  B= 4 T= 2048 H= 32 mode=E [g,gk]              |   0.005859  5.357e-03 |    0.4304    0.4788    0.90x      yes
  B= 4 T= 4096 H= 32 mode=A                     |   0.000000  0.000e+00 |    0.6486    0.6631    0.98x      yes
  B= 4 T= 4096 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    0.7399    0.6715    1.10x      yes
  B= 4 T= 4096 H= 32 mode=C [g]                 |   0.007812  7.353e-03 |    0.7301    0.9382    0.78x      yes
  B= 4 T= 4096 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    0.7354    0.6695    1.10x      yes
  B= 4 T= 4096 H= 32 mode=E [g,gk]              |   0.005859  6.757e-03 |    0.8509    0.9532    0.89x      yes
  B= 4 T= 8192 H= 32 mode=A                     |   0.000000  0.000e+00 |    1.3216    1.3160    1.00x      yes
  B= 4 T= 8192 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    1.4675    1.3256    1.11x      yes
  B= 4 T= 8192 H= 32 mode=C [g]                 |   0.007812  7.194e-03 |    1.4705    1.9226    0.76x      yes
  B= 4 T= 8192 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    1.4720    1.3397    1.10x      yes
  B= 4 T= 8192 H= 32 mode=E [g,gk]              |   0.005859  5.396e-03 |    1.6942    1.9580    0.87x      yes
  B= 4 T=16384 H= 32 mode=A                     |   0.000000  0.000e+00 |    2.5995    2.5899    1.00x      yes
  B= 4 T=16384 H= 32 mode=B [dht,dh0]           |   0.000000  0.000e+00 |    2.9316    2.5995    1.13x      yes
  B= 4 T=16384 H= 32 mode=C [g]                 |   0.007812  8.264e-03 |    2.9571    3.8062    0.78x      yes
  B= 4 T=16384 H= 32 mode=D [gk]                |   0.000000  0.000e+00 |    2.9560    2.6620    1.11x      yes
  B= 4 T=16384 H= 32 mode=E [g,gk]              |   0.005859  4.934e-03 |    3.3428    3.8417    0.87x      yes
  ------------------------------------------------------------------------------------------------------------------------------------
  Geometric mean                                |                       |                        1.04x         

========================================================================================================================

Reviewer Notes

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the SM90 CuTe DSL implementation for the backward DHU kernel, supporting both non-varlen and varlen modes along with various gating and state layout options. The PR also includes a benchmark script and a correctness test suite. A review comment identifies a performance concern where the kernel compilation is cached on concrete problem dimensions, suggesting a refactor to use symbolic dimensions to better handle dynamic shapes and prevent excessive recompilation.

Comment on lines +1115 to +1132
@functools.lru_cache(maxsize=64)
def _compile_bwd_dhu_sm90(
B: int,
T: int,
N: int,
NT: int,
H: int,
K: int,
V: int,
is_varlen: bool,
use_g: bool,
use_gk: bool,
use_dht: bool,
use_dh0: bool,
use_exp2: bool,
transpose_state_layout: bool,
scale: float,
):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of _compile_bwd_dhu_sm90 is cached on problem dimensions (B, T, N, NT), which will trigger a new kernel compilation for every unique input shape. This leads to long warm-up times and high memory usage from the compilation cache, making the kernel inefficient for dynamic shapes.

To address this, I recommend refactoring to use symbolic dimensions for fake tensors during compilation, and passing the concrete problem dimensions at runtime. This is a common pattern in cutlass-python and is used in other parts of this repository (e.g., cula.ops.chunk_delta_h.ChunkDeltaRuleFwdH).

This would involve:

  1. Removing B, T, N, NT from the _compile_bwd_dhu_sm90 function signature and its lru_cache key.
  2. Using cute.sym_int() to define symbolic dimensions for creating fake tensors inside _compile_bwd_dhu_sm90.
  3. Modifying ChunkDeltaRuleBwdDHUSm90 to accept problem dimensions at runtime in its __call__ or kernel method, rather than in __init__.
  4. Updating chunk_gated_delta_rule_bwd_dhu_sm90 to pass the problem dimensions to the compiled kernel at runtime.

This change will allow a single compiled kernel to handle various input sizes, significantly improving performance and usability.

@yechenzhi
Copy link
Copy Markdown
Contributor Author

yechenzhi commented May 19, 2026

Kernel overview

This kernel implements the SM90 CuTe DSL path for chunk_gated_delta_rule_bwd_dhu.

Each CTA owns one (V tile, batch/sequence, head) work unit. For the current target shape, K=V=128 and BV=64, so each (batch/sequence, head) has two V-tile CTAs. Each CTA carries a [BV, BK] = [64, 128] backward state tile in fp32 registers and scans chunks in reverse order.

The recurrence is:

D_t      = carried dh state before processing chunk t
dh[t]    = D_t
dv2[t]   = dv[t] + K[t] @ D_t
D_{t-1}  = decay(D_t) + scale * Q[t]^T @ do[t] - W[t]^T @ dv2[t]

High-level pseudocode:

one CTA handles one (v_tile, batch/sequence, head).

rState = dht if dht is not None else 0  # fp32 [BV, BK]

for chunk in reversed(chunks):
    # 1. emit current carried state as dh[chunk]
    publish_to_smem_for_store(dh[chunk] = bf16(rState))

    # 2. compute dv2 = dv + K @ dh
    # implemented as V-major tile:
    #   acc_dv[V, T] = rState[V, K] @ K[chunk]^T[K, T]
    acc_dv = wgmma(rState_bf16, K_chunk)
    if use_g:
        acc_dv *= exp(g_last - g_cur)
    dv2_tile = dv_chunk + acc_dv
    publish_to_smem_for_store(dv2[chunk] = bf16(dv2_tile))

    # 3. compute QDO
    if use_g:
        # scalar g is loaded directly by compute warps, not TMA-loaded
        do_for_qdo = do_chunk * exp(g_cur)
    else:
        do_for_qdo = do_chunk

    acc_qdo = wgmma(do_for_qdo, Q_chunk)

    # 4. apply decay to carried state
    if use_g:
        rState *= exp(g_last)
    if use_gk:
        rState *= exp(gk_last)

    # 5. compute WDV and update carried state
    acc_wdv = wgmma(dv2_tile, W_chunk)
    rState = rState + scale * acc_qdo - acc_wdv

if dh0 is requested:
    dh0 = rState

Warp roles

The CTA uses 7 warps:

  1. load_warp_id: preloads K, dv, and optional gk for the next reverse chunk.
  2. load_current_warp_id: loads current-chunk do, q, and w.
  3. compute warps: carry rState in registers, directly load scalar g when enabled, and run the WGMMA recurrence.
  4. store_warp_id: stores dh and dv2 after compute warps publish SMEM tiles.

Each specialized warp runs its own reverse-chunk loop. The loops are synchronized through the load/store pipeline barriers rather than by sharing a single syntactic loop body.

@yechenzhi
Copy link
Copy Markdown
Contributor Author

Hi, @KevinZeng08

Performance question

I ran the current SM90 bwd_dhu kernel against FLA. The long fwd-aligned preset is roughly on par overall, but I see two weak areas and would like to ask for guidance before making the kernel more complex.

  1. Low-parallelism shapes are weak, especially B=1, H=16.

For example, with B=1, H=16, many cases are below FLA:

T=4096  mode=A        ~0.73x
T=8192  mode=A        ~0.73x
T=16384 mode=A        ~0.75x
T=8192  mode=D [gk]   ~0.81x

My guess is that this shape only launches 2 * H = 32 CTAs because V=128, BV=64, so the GPU is underfilled. I tried more fine-grained V tiling, but BV=32 does not fit the current WGMMA path, so this does not look straightforward.

  1. The scalar g path is also weaker.

Modes with g are often below FLA, while no-g / gk-only modes are more competitive. My current understanding is that use_g adds direct scalar-g loads, exp computation, sG staging, and gated-do materialization before QDO.

Do you have suggestions for the next optimization direction?

  • Is the current static (V tile, batch/sequence, head) CTA mapping acceptable, with B=1,H=16 left as a known weak case?
  • For use_g, would you prefer keeping the simpler path for now, or trying a dedicated gated-QDO fast path?
  • Could you share your thoughts on whether the current design looks reasonable as a base kernel? Any suggestions for the next optimization direction would be very helpful.

If you have access to other Hopper machines, it would also be great to get additional performance numbers on H100/H200, since I currently only tested on my available H800 PCIE setup.

@yechenzhi yechenzhi marked this pull request as ready for review May 19, 2026 11:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant