[Metal] Fused Flash Attention backward (VJP) kernels#3241
Brooooooklyn wants to merge 1 commit into ml-explore:main
Conversation
Add fused Flash Attention backward pass (VJP) kernels for Apple Silicon
GPUs, implementing the two-kernel architecture from Flash Attention 2
(Dao, 2023). The fused backward eliminates O(L^2) attention matrix
materialization, reducing peak memory by 70-95% with auto-dispatch
routing between fused and unfused paths.
Key additions:
- Two Metal kernels: steel_attention_vjp_dq and steel_attention_vjp_dkv
- JIT compilation with baked constants for dead-code elimination
- Delta precomputation as lazy MLX graph ops
- Threadgroup memory aliasing (23KB -> 14.8KB for D=128 dKV)
- Sparse block mask support via function constant gating
- Auto-dispatch: causal L thresholds + 1GB memory ceiling
- Support for D={64,96,128}, float16/bfloat16, causal, GQA
- 2-pass vector kernel LSE output for VJP logsumexp
Performance (M3 Max, B=1 H=32 causal float16, fused vs unfused):
D=64: 1.22-1.40x faster
D=96: 1.29-1.35x faster
D=128: 0.77-0.81x (memory trade-off, 70-82% savings)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
I know that submitting a PR of this size to a framework like mlx is very challenging, and I don't want to create additional review burden for the reviewers here. The background is that I'm developing https://github.com/mlx-node/mlx-node, and I'm experimenting with local RL on small LLMs like qwen 3.5 0.8b on macOS via Node.js. The reason for building on Node.js is that I think it makes it simpler for JavaScript developers to use this framework to write reward functions, and thereby tackle different RL tasks. My development is still at a very early stage. When attempting RL (GRPO), I found that even 128GB of memory is too small, so I identified this potential optimization. If optimizing mlx training performance on macOS is not currently a major goal for the mlx team, you can ignore this PR for now. I will post here if I make new progress.
Summary
This PR adds fused Flash Attention backward pass (VJP) kernels for Apple Silicon GPUs, implementing the two-kernel architecture from Flash Attention 2 (Dao, 2023) using MLX's STEEL tiling framework. The fused backward eliminates the O(L^2) attention matrix materialization, reducing peak memory by 70-95% at the cost of additional recomputation. An auto-dispatch policy routes between fused and unfused paths based on sequence length and memory pressure.

Key additions:
- Two Metal kernels: `steel_attention_vjp_dq` (dQ gradients) and `steel_attention_vjp_dkv` (dK/dV gradients)
- JIT compilation with baked constants (`gqa_factor`, `scale`, `scale_log2`, alignment flags, causal, block mask) for dead-code elimination
- Threadgroup memory aliasing of `red_smem` over `Q_smem` + `dO_smem` (temporally disjoint), reducing D=128 dKV smem from 23 KB to 14.8 KB (enables 2 TGs/core)

Background
The Flash Attention Backward Recomputation Trade-off
Flash Attention (Dao et al., NeurIPS 2022) achieves O(N) memory for attention by never materializing the full N x N attention matrix. The forward pass tiles the computation with no extra FLOPs — it's the same work, just blocked differently. But the backward pass fundamentally changes the compute profile.
Forward pass (2 matmuls per block-pair):
Backward pass (5 matmuls per block-pair):
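Written out per block-pair, a sketch following Dao (2023, Section 3.1), using the exp2 and delta conventions these kernels adopt (see the MFA patterns below); the exact scale placement in the PR's code may differ:

```math
\begin{aligned}
\text{Forward (2 matmuls):}\quad & S = QK^\top, \qquad O \mathrel{+}= P V, \qquad P = \mathrm{exp2}(S \cdot \mathrm{scale\_log2} - \mathrm{LSE}) \\
\text{Backward (5 matmuls):}\quad & S = QK^\top \ \text{(recomputed, with } P \text{ rebuilt from the saved LSE)} \\
& dV \mathrel{+}= P^\top dO, \qquad dP = dO\, V^\top \\
& dS = \mathrm{scale} \cdot P \odot (dP - \delta), \qquad \delta = \mathrm{rowsum}(dO \odot O) \\
& dQ \mathrel{+}= dS\, K, \qquad dK \mathrel{+}= dS^\top Q
\end{aligned}
```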
The key insight from Flash Attention 2 (Dao, 2023, Section 3.1): the backward pass requires ~2.5x the FLOPs of the forward pass because the attention matrix S must be recomputed from saved LSE in both the dQ kernel and the dKV kernel. This is the fundamental cost of O(N) memory.
Published backward/forward ratios
These ratios are well-documented across implementations:
The backward is inherently more expensive than forward in Flash Attention. This is not a bug — it's the price of O(N) memory.
Reference implementation: Metal Flash Attention (MFA)
Our implementation follows the architectural patterns from Metal Flash Attention by Philip Turner, the only prior fused attention backward for Apple Silicon:
- `delta = rowsum(dO * O)` computed once outside kernels (MFA pattern)
- `P = exp2(S * scale_log2 - LSE)` instead of `exp(S * scale - LSE)`, for Metal `exp2` efficiency (MFA pattern)

The main difference from MFA: we use MLX's JIT compilation to bake runtime constants (`gqa_factor`, `scale`, alignment, causal) as `#define` literals, enabling the Metal compiler to eliminate dead code paths. MFA achieves this via a pre-compiled metallib with `#define` constants baked at build time.

Architecture
Two-Kernel Design
The backward is split into two independent kernels that run concurrently:
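Both kernels consume the same saved LSE and the same precomputed delta term; per the commit message, delta is computed as lazy MLX graph ops outside the kernels. A minimal NumPy sketch of that rowsum (in MLX this would be the analogous lazy sum over the last axis):

```python
import numpy as np

def attention_vjp_delta(dO, O):
    """delta_i = rowsum(dO_i * O_i): one scalar per query row.

    NumPy sketch of the precomputation described above; in the PR this is
    expressed as lazy MLX graph ops so it runs outside the Metal kernels.
    """
    return np.sum(dO * O, axis=-1)

# Tiny usage example with shape [B, H, L, D] = [1, 2, 4, 8]
rng = np.random.default_rng(0)
dO = rng.standard_normal((1, 2, 4, 8)).astype(np.float32)
O = rng.standard_normal((1, 2, 4, 8)).astype(np.float32)
delta = attention_vjp_delta(dO, O)
print(delta.shape)  # (1, 2, 4)
```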
Per-Head-Dimension Configuration
D=128 dKV is the most register-constrained configuration. At BQ=32 WM=2, it would need ~338 regs/thread (spills at >256). BQ=16 reduces TQ from 2 to 1 per simdgroup, cutting register usage to ~202 at the cost of 2x more Q-tile iterations.
Threadgroup Memory Aliasing
The dKV kernel uses four smem buffers: `Q_smem` and `dO_smem` (iteration phase), `KV_smem` (iteration phase), and `red_smem` (reduction phase, WM>1 only). Since Q/dO and red are temporally disjoint (iterations end before reduction begins), `red_smem` is aliased over the `Q_smem` + `dO_smem` region.

A `static_assert` verifies the Q+dO region is large enough to alias red, and an explicit `threadgroup_barrier` before the reduction phase ensures correctness.

Sparse Block Masks
Both kernels support an optional `block_mask` buffer (`uint8_t[NQ_tiles * NK_tiles]`) for skipping tile pairs in sparse attention patterns (sliding window, local attention, block-sparse). The mask is gated by a function constant (`has_block_mask`, index 302) / JIT define (`VJP_HAS_BLOCK_MASK`):

- `has_block_mask = false` (default): all mask checks are dead-code eliminated, zero overhead
- `has_block_mask = true`: tile-skip check at the top of each loop iteration, before any smem loads; both kernels skip the tile pair where `block_mask[qb * NK_tiles + kb] == 0`

JIT Constant Baking
Both kernels are JIT-compiled with `#define` constants. For metallib (non-JIT) builds, the macros fall back to `params->` runtime reads and Metal function constants, ensuring zero behavior change.

The Metal shader compiler leverages these compile-time constants for three key optimizations:

- Dead-code elimination: with `VJP_GQA_FACTOR = 1` (the common MHA case), the entire GQA outer loop in the dKV kernel is removed: no loop counter, head offset math, or branch. `VJP_DO_CAUSAL = false` eliminates all causal masking code; `VJP_HAS_BLOCK_MASK = false` removes the sparse mask check and pointer dereference.
- Constant folding: `VJP_SCALE * P * (dP - delta)` with a literal float can be fused into a single FMA instruction. Without JIT, the GPU must load `params->scale` from device memory at each use site.
- Loop unrolling: with `VJP_GQA_FACTOR` as a literal, the compiler knows the exact trip count. For GQA=1 the loop is eliminated entirely; for power-of-2 GQA factors, the head index division becomes a bit shift.
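As an illustration of the baking step, the JIT path effectively prepends a header of literal defines to the kernel source. This is a hypothetical sketch, not the PR's actual emission code; the `VJP_*` names follow the defines above, and the `scale_log2 = scale * log2(e)` relation is the standard exp2-trick folding, assumed here:

```python
import math

def vjp_defines(gqa_factor: int, scale: float,
                do_causal: bool, has_block_mask: bool) -> str:
    # Hypothetical sketch: emit the baked #define header for one kernel
    # specialization. scale_log2 folds the exp2 trick's log2(e) factor into
    # the scale, matching P = exp2(S * scale_log2 - LSE).
    scale_log2 = scale * math.log2(math.e)
    return "\n".join([
        f"#define VJP_GQA_FACTOR {gqa_factor}",
        f"#define VJP_SCALE {scale}f",
        f"#define VJP_SCALE_LOG2 {scale_log2}f",
        f"#define VJP_DO_CAUSAL {int(do_causal)}",
        f"#define VJP_HAS_BLOCK_MASK {int(has_block_mask)}",
    ])

print(vjp_defines(1, 0.125, True, False))
```

Because the header is part of the source hashed for the cache key, each unique (gqa_factor, scale, flags) combination yields its own compiled specialization.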
Metallib builds partially close the gap via Metal function constants (PSO-level boolean specialization), but numeric values (`gqa_factor`, `scale`, `scale_log2`) must always be loaded from the params buffer at runtime; the metallib compiler never sees their values. Each unique configuration is compiled once and cached in a `library_map_`, so the JIT cost (~50-200ms) is paid only on first invocation.

Why Fused Backward Is Not Always Faster
The fundamental trade-off
The fused path pays two penalties:
When fused wins: causal masking
With causal attention, the attention matrix is lower-triangular. Fused kernels skip entire tile blocks in the upper triangle, eliminating ~50% of compute. Unfused still materializes the full N x N matrix then masks it:
This is why our benchmarks show fused causal at 1.17-1.37x speedup (fused faster) but fused dense at 0.36-0.70x (fused slower).
When fused wins: long sequences
At long sequence lengths, the O(L^2) attention matrix becomes a memory bandwidth bottleneck. At L=4096 with 32 heads, unfused needs 3.4 GB just for the backward intermediate. Even when fused is slower in raw compute, it avoids thrashing the memory system.
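A back-of-envelope check of that figure: a single materialized L x L float16 matrix across 32 heads is already 1 GiB at L=4096, and the unfused backward presumably holds several such intermediates (the recomputed probabilities, their gradient, and related temporaries) to reach the quoted ~3.4 GB:

```python
# One L x L attention intermediate at L=4096, H=32, float16 (2 bytes/element).
L, H, itemsize = 4096, 32, 2
one_matrix = H * L * L * itemsize
print(one_matrix / 2**30)  # 1.0 (GiB per materialized L x L matrix)
```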
How we deal with it: auto-dispatch
Rather than always using fused (slow for dense sequences) or always unfused (OOM risk for long sequences), we route per-configuration:
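The routing described above can be sketched as follows. This is a hypothetical illustration, not the PR's code: the 1 GB ceiling is from the description, while `causal_len_threshold` is a placeholder name and value, not the tuned thresholds in the actual implementation:

```python
def choose_vjp_path(batch: int, heads: int, seq_len: int, causal: bool,
                    itemsize: int = 2, causal_len_threshold: int = 512,
                    mem_ceiling: int = 1 << 30) -> str:
    # Hypothetical sketch of the auto-dispatch policy.
    # Unfused materializes the O(L^2) attention matrix per head.
    unfused_bytes = batch * heads * seq_len * seq_len * itemsize
    if unfused_bytes > mem_ceiling:
        return "fused"    # memory-driven: avoid materializing the matrix
    if causal and seq_len >= causal_len_threshold:
        return "fused"    # speed-driven: causal tile-skipping wins
    return "unfused"      # dense, short sequences: unfused matmuls are faster

print(choose_vjp_path(1, 32, 8192, causal=False))  # memory ceiling trips: fused
print(choose_vjp_path(1, 32, 1024, causal=False))  # dense, fits: unfused
```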
Users can override with `MLX_SDPA_VJP_MODE={fused|unfused}`.

Benchmark Results
All benchmarks on Apple M3 Max, B=1, H=32, float16. Speedup > 1.0x means fused is faster.
Performance: Fused vs Unfused (MLX_SDPA_VJP_MODE=fused)
Causal attention — fused is competitive or faster due to block skipping:
Dense attention — fused pays the full recomputation penalty:
Pattern analysis:
Why D=128 causal is slower despite block skipping
D=64/96 causal fused is 1.17-1.37x faster than unfused, but D=128 causal fused is only 0.67-0.76x. Both get the same ~50% causal tile-skip benefit, yet D=128 cannot overcome the recomputation penalty. This is caused by five compounding factors:
1. Occupancy collapse (dominant factor). Apple M3 Max has 32 KB threadgroup memory per GPU core. D=64 dKV uses ~7 KB smem, allowing 4 concurrent threadgroups per core — when one threadgroup stalls on a memory load, others keep the ALU busy. D=128 dKV originally used ~22.5 KB smem (Q: 4,352B + dO: 4,352B + KV: 6,144B + reduction: 8,192B), limiting occupancy to 1 threadgroup per core. After smem aliasing (red over Q+dO), this drops to 14.8 KB enabling 2 TGs/core — improving D=128 causal from 0.73x to 0.80x, but still below D=64's 4 TGs/core.
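The smem accounting above can be verified directly; the buffer sizes are the ones quoted in the paragraph, and the aliasing is valid because `red_smem` and the Q/dO buffers are live in disjoint phases:

```python
# Threadgroup memory accounting for the D=128 dKV kernel.
Q_smem, dO_smem, KV_smem, red_smem = 4352, 4352, 6144, 8192  # bytes
before = Q_smem + dO_smem + KV_smem + red_smem     # all four buffers distinct
after = max(Q_smem + dO_smem, red_smem) + KV_smem  # red aliased over Q+dO
budget = 32 * 1024                                 # 32 KB per GPU core (M3 Max)
print(before, budget // before)  # 23040 1  -> 1 threadgroup per core
print(after, budget // after)    # 14848 2  -> 2 threadgroups per core
```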
2. BQ=16 doubles iteration overhead. D=128 dKV uses BQ=16 tiles (half of D=64's BQ=32) to keep register pressure under 256 regs/thread. This means 2x more Q-tile iterations in the inner loop, each paying fixed costs: 2 global memory tile loads, 5 threadgroup barriers, and LSE/delta scalar reads. Total useful compute is the same, but the overhead-to-compute ratio doubles — from ~30% at D=64 to ~60% at D=128.
3. WM=2 threadgroup reduction. D=128 requires WM=2 (2 simdgroups, 64 threads) to split register pressure across simdgroups (~202 regs/thread vs ~364 at WM=1). This adds a threadgroup reduction with 4 extra barriers per Q-block iteration for dK/dV partial sum accumulation — overhead that D=64 (WM=1, single simdgroup) never pays.
4. Fewer MMAs per iteration. Each Q-tile iteration performs:
Less compute per iteration means the fixed overhead (barriers, loads) dominates a larger fraction of wall time.
5. Unfused baseline is stronger at D=128. The unfused path uses NAX-optimized matmuls with large tiles (BM=128, BN=128). At D=128, matmul shapes like `[L, L] @ [L, 128]` fill NAX tiles perfectly (N=128 = BN); at D=64, N=64 wastes half the tile width. The unfused baseline is proportionally faster at D=128, widening the gap.

Combined effect: The effective fused/unfused throughput ratio is ~18% at D=64 vs ~12% at D=128 (after smem aliasing, up from ~10%). Causal tile-skipping saves ~50% of the fused work at both head dimensions, doubling the effective ratio: 18% becomes 36% at D=64 (enough to beat unfused, hence 1.37x), but 12% only becomes 24% at D=128 (still slower, hence 0.76x). Even with 2 TGs/core after aliasing, D=128's combination of BQ=16, the WM=2 reduction, and a stronger NAX baseline keeps it below parity.
Memory: Peak Usage During VJP
The primary motivation for fused backward is memory efficiency at scale:
At L=4096, unfused requires 3.4 GB for backward alone — this is where fused becomes essential. With model weights, optimizer states, and activations competing for memory during training, the 95% reduction determines whether a training run fits in memory or not.
Auto-Dispatch Validation
The auto-dispatch policy separates causal (speed-driven) from dense (memory-driven) routing:
Dense attention — auto correctly selects unfused at all sequence lengths below 1 GB ceiling:
The ~1.00x ratios confirm that both the "unfused" and "fused" benchmark columns execute the same unfused code path — auto mode is not forcing fused for dense attention.
Causal attention — auto selects fused for speed (verified via the `MLX_SDPA_VJP_MODE=fused` benchmarks above).

References