Production-relevant LLM operator kernels implemented from scratch in Triton.
Triton compiles directly to CUDA PTX/SASS — writing a Triton kernel is writing a GPU kernel, at the abstraction level used in production by vLLM, SGLang, FlashAttention-3, and OpenAI's inference stack.
| GPU | NVIDIA RTX 5090 · 32 GB VRAM · ~1792 GB/s HBM bandwidth |
| CUDA | 12.8 |
| Triton | 3.5.1 |
| PyTorch | 2.9.1 |
| Target model config | Qwen2.5-7B-Instruct (N_HEADS=32, HEAD_DIM=128, ffn_dim=18944) |
Three fused Triton kernels for the most performance-critical operators in every modern LLM (Qwen, LLaMA, Mistral, DeepSeek). Each kernel is benchmarked against PyTorch eager mode.
| Kernel | Where it runs | PyTorch ops replaced | HBM passes: before → after |
|---|---|---|---|
| RMSNorm | Every transformer layer, pre-attention + pre-FFN | pow + mean + rsqrt + mul |
2 reads → 1 read |
| RoPE | Every attention layer, applied to Q and K | reshape + cos/sin lookup + mul + cat |
4–5 passes → 1 pass |
| SwiGLU | Every FFN, activation gate | silu(gate) + gate * up (2 kernels) |
5 passes → 3 passes |
LLM inference at batch=1 is memory-bandwidth-bound: every token generation must load model weights from HBM. Any op that reads the same tensor multiple times or writes unnecessary intermediates is pure waste.
PyTorch's eager mode launches a separate CUDA kernel per Python op. Each kernel incurs:
- A round-trip to HBM to read inputs
- A round-trip to HBM to write outputs (even for temporaries)
- Kernel launch overhead (~5–10 µs per launch on Ampere/Blackwell)
Kernel fusion eliminates intermediate HBM writes by keeping data in SRAM (registers/shared memory) across operations. The computation is identical — only the memory traffic changes.
Hardware: RTX 5090 · dtype: fp16 · all correctness checks pass (max_err < 5e-3)
RMSNorm — hidden_dim sweep (batch=4, seq=512)
| hidden_dim | PyTorch (µs) | Triton (µs) | Speedup | Triton BW |
|---|---|---|---|---|
| 896 | 70.6 | 13.3 | 5.3× | 552 GB/s |
| 1536 | 80.2 | 13.1 | 6.1× | 958 GB/s |
| 2048 | 83.5 | 12.6 | 6.6× | 1328 GB/s |
| 3072 | 84.5 | 12.7 | 6.7× | 1985 GB/s |
| 4096 | 112.6 | 12.0 | 9.4× | 1792+ GB/s (L2 cached) |
Peak speedup 9.4× at hidden=4096. Fusion reduces the 2-pass read pattern to 1 pass.
| seq_len | PyTorch (µs) | Triton (µs) | Speedup | Triton BW |
|---|---|---|---|---|
| 512 | 128.8 | 14.6 | 8.8× | 594 GB/s |
| 1024 | 94.6 | 19.4 | 4.9× | 894 GB/s |
| 2048 | 133.5 | 35.5 | 3.8× | 974 GB/s |
| 4096 | 412.7 | 68.8 | 6.0× | 1007 GB/s |
| 8192 | 1018.5 | 126.0 | 8.1× | 1099 GB/s |
At seq=512, PyTorch overhead from multiple kernel launches dominates → 8.8× speedup. At seq=8192, fused single-pass reduces 4–5 HBM passes to 1 → sustained 8.1× speedup.
| ffn_dim | PyTorch (µs) | Triton (µs) | Speedup | Triton BW |
|---|---|---|---|---|
| 4096 | 19.5 | 13.3 | 1.5× | 3798 GB/s (L2 cached) |
| 8192 | 65.6 | 18.5 | 3.6× | 5447 GB/s (L2 cached) |
| 11008 | 112.4 | 69.8 | 1.6× | 1938 GB/s |
| 14336 | 164.1 | 104.7 | 1.6× | 1682 GB/s |
| 18944 | 246.1 | 148.0 | 1.7× | 1573 GB/s |
Qwen2.5-7B uses ffn_dim=18944 → 1.7× speedup by eliminating intermediate silu output write.
Monkey-patched 57 RMSNorm + 28 SwiGLU instances in the live Qwen2.5-7B-Instruct model. Measured prefill (single forward pass) latency on RTX 5090.
| seq_len | Baseline (ms) | Patched (ms) | Speedup |
|---|---|---|---|
| 128 | 19.85 | 18.03 | 1.101× |
| 256 | 26.36 | 24.62 | 1.071× |
| 512 | 42.72 | 40.83 | 1.046× |
| 1024 | 79.07 | 76.15 | 1.038× |
| 2048 | 151.84 | 147.03 | 1.033× |
Why 3–10%, not 3–9×? In a full forward pass, attention matmuls (Q/K/V projections + softmax) dominate compute — especially at long sequence lengths where attention is O(n²). RMSNorm and SwiGLU are the fast elementwise tail, not the bottleneck. The speedup is higher at seq=128 (elementwise ops are a larger fraction) and converges toward 1× as attention dominates at seq=2048. This is exactly what the roofline model predicts: kernel fusion eliminates overhead only for the memory-bound ops, not the compute-bound matmuls.
-
Kernel launch overhead dominates at small sizes. PyTorch's multi-kernel approach adds 5–10 µs per op. For short sequences (RoPE at seq=512), this alone causes 8.8× slowdown independent of memory bandwidth.
-
Fusion benefit scales with operator complexity. RMSNorm (2-pass → 1-pass) sees up to 9.4× speedup. SwiGLU (5-pass → 3-pass) sees 1.5–3.6×. The more intermediate tensors eliminated, the larger the gain.
-
All three operators are memory-bandwidth-bound. Triton kernels achieve 550–1100 GB/s effective HBM bandwidth (30–60% of RTX 5090's 1792 GB/s theoretical peak), consistent with the roofline model for memory-bound ops.
-
These kernels compose. In a real LLM forward pass, each layer calls RMSNorm × 2 + RoPE × 2 + SwiGLU × 1. Per layer, fused kernels save ~(9.4 + 8.1 + 1.7)× = up to ~65% latency reduction on just the elementwise + normalization portion.
git clone https://github.com/MemoryWorld/cuda-kernels
cd cuda-kernels
pip install torch triton matplotlib
# Run individual kernels (each ~10s, includes warmup)
cd kernels
python rmsnorm.py # results/rmsnorm.json + results/rmsnorm.png
python rope.py # results/rope.json + results/rope.png
python swiglu.py # results/swiglu.json + results/swiglu.pngWhy Triton, not raw CUDA C++?
Triton compiles to the same PTX/SASS as CUDA C++. The production LLM inference stack — vLLM, SGLang, FlashAttention-3, Liger-Kernel — uses Triton for exactly these kinds of elementwise and reduction kernels. Writing CUDA C++ for ops like RMSNorm and RoPE is unnecessary engineering overhead without performance benefit on modern hardware.
Why these three ops?
Every forward pass of every modern LLM (Qwen, LLaMA, Mistral, DeepSeek, Gemma) executes:
RMSNormbefore attention and before FFN — 2× per layerRoPEon Q and K — 2× per layerSwiGLUas the FFN activation — 1× per layer
For a 32-layer model like Qwen2.5-7B, that's 64 RMSNorm + 64 RoPE + 32 SwiGLU calls per forward pass. Fusing them is not micro-optimization — it's standard production practice.
All three kernels were also implemented in hand-written CUDA C++ (shared memory + warp reduction for RMSNorm, coalesced half-precision loads for RoPE/SwiGLU) and benchmarked against the Triton versions.
| hidden_dim | PyTorch (µs) | Triton (µs) | CUDA C++ (µs) | Triton speedup | CUDA C++ speedup |
|---|---|---|---|---|---|
| 896 | 7182 | 836 | 836 | 8.59× | 8.59× |
| 1536 | 7182 | 836 | 745 | 8.59× | 9.64× |
| 2048 | 7086 | 929 | 837 | 7.63× | 8.46× |
| 3072 | 1204 | 839 | 839 | 1.44× | 1.44× |
| 3584 | 1786 | 654 | 843 | 2.73× | 2.12× |
| 4096 | 531 | 195 | 104 | 2.73× | 5.09× |
CUDA C++ beats Triton at large hidden dim (4096): explicit shared-memory tiling + warp
__shfl_down_syncmaps better to the register file than Triton's autotune at this specific size.
| seq_len | PyTorch (µs) | Triton (µs) | CUDA C++ (µs) | Triton speedup | CUDA C++ speedup |
|---|---|---|---|---|---|
| 512 | 7980 | 102 | 836 | 77.87× | 9.55× |
| 1024 | 7621 | 111 | 746 | 68.76× | 10.22× |
| 2048 | 614 | 232 | 385 | 2.64× | 1.60× |
| 4096 | 1501 | 224 | 214 | 6.69× | 7.01× |
| 8192 | 3477 | 382 | 373 | 9.11× | 9.32× |
Triton dominates at small seq (512–1024): 1D grid + vectorized fp16 loads; CUDA C++ per-element threads have higher launch overhead. At seq≥4096 they converge.
| ffn_dim | PyTorch (µs) | Triton (µs) | CUDA C++ (µs) | Triton speedup | CUDA C++ speedup |
|---|---|---|---|---|---|
| 4096 | 1404 | 122 | 286 | 11.55× | 4.90× |
| 8192 | 1129 | 122 | 220 | 9.27× | 5.13× |
| 11008 | 1852 | 351 | 281 | 5.27× | 6.59× |
| 14336 | 2536 | 405 | 381 | 6.26× | 6.66× |
| 18944 | 3380 | 563 | 530 | 6.00× | 6.38× |
Triton wins at small sizes (lower launch overhead); CUDA C++ edges ahead at large ffn_dim via
half2vectorisation.
Key takeaway: Triton and CUDA C++ trade wins depending on problem size. For production LLM ops at inference-relevant sizes, they are within 10–20% of each other. Triton's advantage is developer velocity — no separate compilation, no stream management, no warp mask arithmetic.
Every transformer layer executes this pattern twice (pre-attention and pre-FFN):
x_norm = RMSNorm(x) # reads x, writes x_norm to HBM
y = x_norm @ W_linear.T # reads x_norm + W, writes yThe intermediate tensor x_norm is written to HBM by RMSNorm, then immediately read back by the linear layer — pure waste. A fused kernel keeps x_norm in registers across both ops and eliminates that round-trip.
HBM traffic analysis (fp16 elements):
| Formula | Qwen2.5-7B, M=2048, K=3584, N=3584 | |
|---|---|---|
| Naive | 3·M·K + K + N·K + M·N | ~214 MB |
| Fused | 2·M·K + K + N·K + M·N | ~199 MB |
| Saved | M·K | ~14.7 MB per pair, ~1.88 GB per 32-layer forward pass |
| M (tokens) | Naive (µs) | Fused (µs) | Speedup | Regime |
|---|---|---|---|---|
| 1 | 1246.8 | 63.7 | 19.6× | decode |
| 4 | 1523.0 | 82.9 | 18.4× | decode |
| 16 | 2286.2 | 63.7 | 35.9× | decode |
| 64 | 1615.9 | 132.3 | 12.2× | decode |
| 256 | 295.1 | 347.7 | 0.85× | prefill |
| 512 | 279.6 | 478.2 | 0.58× | prefill |
| 1024 | 557.2 | 845.7 | 0.66× | prefill |
| 2048 | 859.6 | 1805.6 | 0.48× | prefill |
| M (tokens) | Naive (µs) | Fused (µs) | Speedup | Regime |
|---|---|---|---|---|
| 1 | 228.7 | 266.8 | 0.86× | decode |
| 4 | 341.4 | 224.5 | 1.52× | decode |
| 16 | 289.9 | 210.3 | 1.38× | decode |
| 64 | 515.8 | 319.0 | 1.62× | decode |
| 256 | 454.3 | 1174.1 | 0.39× | prefill |
| 512 | 861.5 | 2217.0 | 0.39× | prefill |
Why decode wins (M=1–64): PyTorch eager executes 8+ separate CUDA kernel launches for the naive path (float(), pow, mean, rsqrt, mul × 2, half(), linear). At M=1, the actual compute time is ~28 µs (memory-bandwidth bound at 1792 GB/s), but kernel launch overhead inflates naive to 1247 µs. The fused kernel collapses everything to one launch and runs near the bandwidth limit.
Why prefill regresses (M≥256): cuBLAS GEMM is heavily optimized for large matrix shapes and uses hardware-specific tiling, pipelining, and split-K strategies. The handwritten Triton kernel with BLOCK_M=16 does not saturate tensor cores at large M. The memory savings (~14.7 MB at M=2048) are real but dominated by the GEMM efficiency gap.
Production context: This is the exact finding that shapes production LLM serving design. FlashInfer and Liger Kernel use fused norm+linear kernels specifically for the decode path (batch=1–32 per request), where kernel launch overhead is the dominant cost. For prefill, batching offloads the work to large GEMMs where cuBLAS already wins.
| Item | Status |
|---|---|
| RMSNorm Triton kernel | ✅ Done |
| RoPE Triton kernel | ✅ Done |
| SwiGLU Triton kernel | ✅ Done |
| CUDA C++ versions (shared mem + warp reduction) | ✅ Done |
| Triton vs CUDA C++ head-to-head benchmark | ✅ Done |
| Fused RMSNorm + linear projection | ✅ Done |
| Benchmarks inside actual Qwen2.5-7B forward pass | ✅ Done |
Hardware: NVIDIA RTX 5090 (32 GB) · CUDA 12.8 · Triton 3.5.1 · PyTorch 2.9.1 · WSL2





