[Metal] Fused Flash Attention backward (VJP) kernels #3241

Open
Brooooooklyn wants to merge 1 commit into ml-explore:main from mlx-node:flash-attn

Conversation


Brooooooklyn commented Mar 11, 2026

Summary

This PR adds fused Flash Attention backward pass (VJP) kernels for Apple Silicon GPUs, implementing the two-kernel architecture from Flash Attention 2 (Dao, 2023) using MLX's STEEL tiling framework. The fused backward eliminates the O(L^2) attention matrix materialization, reducing peak memory by 70-95% at the cost of additional recomputation. An auto-dispatch policy routes between fused and unfused paths based on sequence length and memory pressure.

Key additions:

  • Two new Metal kernels: steel_attention_vjp_dq (dQ gradients) and steel_attention_vjp_dkv (dK/dV gradients)
  • JIT compilation with baked constants (gqa_factor, scale, scale_log2, alignment flags, causal, block mask) for dead-code elimination
  • Delta precomputation as lazy MLX graph ops (MFA pattern)
  • Threadgroup memory aliasing: red_smem over Q_smem+dO_smem (temporally disjoint), reducing D=128 dKV from 23 KB to 14.8 KB (enables 2 TGs/core)
  • Sparse block mask support in both backward kernels (zero overhead when unused via function constant gating)
  • Auto-dispatch with per-config L thresholds and memory ceiling
  • Support for D={64, 96, 128}, float16/bfloat16, causal masking, GQA

Background

The Flash Attention Backward Recomputation Trade-off

Flash Attention (Dao et al., NeurIPS 2022) achieves O(N) memory for attention by never materializing the full N x N attention matrix. The forward pass tiles the computation with no extra FLOPs — it's the same work, just blocked differently. But the backward pass fundamentally changes the compute profile.

Forward pass (2 matmuls per block-pair):

S = Q @ K^T       // attention scores
P = softmax(S)    // attention weights
O = P @ V         // output

Backward pass (5 matmuls per block-pair):

// Must recompute S and P from the saved LSE (logsumexp); this is the cost of not storing the O(N^2) matrix
S = Q @ K^T                         // recompute (matmul 1)
P = exp2(S * scale_log2 - LSE)      // recompute from saved LSE (elementwise)
dP = dO @ V^T                       // (matmul 2)
dS = scale * P * (dP - delta)       // elementwise
dQ = dS @ K                         // (matmul 3)
dK = dS^T @ Q                       // (matmul 4)
dV = P^T @ dO                       // (matmul 5, in dKV kernel)
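
The backward formulas above can be checked end-to-end in plain NumPy. This is an illustrative reference sketch (not MLX or kernel code; shapes and the base-e softmax are assumptions, since the exp/exp2 base does not change the gradients). It verifies the delta identity rowsum(dO * O) == rowsum(dP * P) that lets delta be precomputed once, and checks dQ against a finite difference:

```python
import numpy as np

rng = np.random.default_rng(0)
L, D = 8, 4
scale = 1.0 / np.sqrt(D)
Q, K, V, dO = (rng.standard_normal((L, D)) for _ in range(4))

def attention(Q):
    S = Q @ K.T * scale
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)          # softmax(scale * Q K^T)
    return P, P @ V

P, O = attention(Q)

# Backward, exactly as in the listing above.
dP = dO @ V.T
delta = (dO * O).sum(axis=-1, keepdims=True)    # precomputed outside the kernels
dS = scale * P * (dP - delta)
dQ, dK, dV = dS @ K, dS.T @ Q, P.T @ dO

# Identity that justifies precomputing delta from dO and O alone:
assert np.allclose(delta, (dP * P).sum(axis=-1, keepdims=True))

# Finite-difference check of dQ for the scalar loss sum(O * dO).
eps = 1e-5
E = np.zeros_like(Q); E[0, 0] = eps
num = (np.sum(attention(Q + E)[1] * dO) - np.sum(attention(Q - E)[1] * dO)) / (2 * eps)
assert np.allclose(num, dQ[0, 0], atol=1e-4)
```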

The key insight from Flash Attention 2 (Dao, 2023, Section 3.1): the backward pass requires ~2.5x the FLOPs of the forward pass because the attention matrix S must be recomputed from saved LSE in both the dQ kernel and the dKV kernel. This is the fundamental cost of O(N) memory.

Published backward/forward ratios

These ratios are well-documented across implementations:

| Implementation | Hardware | Backward / Forward | Source |
|---|---|---|---|
| Flash Attention 2 | A100 (CUDA) | ~2.0-2.5x (dense), ~1.7-2.0x (causal) | Dao 2023, arXiv:2307.08691 |
| Flash Attention 2 | A100 | Forward: 73% peak TFLOPS, Backward: 63% peak | FA2 paper, Table 2 |
| Metal Flash Attention | M1 Max / M4 (Metal) | FWD+BWD: 62-71% ALU utilization at large N | MFA benchmarks |

The backward is inherently more expensive than forward in Flash Attention. This is not a bug — it's the price of O(N) memory.

Reference implementation: Metal Flash Attention (MFA)

Our implementation follows the architectural patterns from Metal Flash Attention by Philip Turner, the only prior fused attention backward for Apple Silicon:

  • Two-kernel split: Separate dQ and dKV kernels with no atomics (MFA pattern)
  • Delta precomputation: delta = rowsum(dO * O) computed once outside kernels (MFA pattern)
  • Log2 domain: P = exp2(S * scale_log2 - LSE) instead of exp(S * scale - LSE) for Metal exp2 efficiency (MFA pattern)
  • Register pressure management: WM=2 (64 threads) for D=128 dKV with BQ=16 tiles to avoid register spilling
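
The log2-domain rewrite is a pure change of base, which a short NumPy sketch can confirm. The exact LSE storage convention is an assumption here (LSE kept in the exp2 domain, matching the kernel expression P = exp2(S * scale_log2 - LSE)):

```python
import numpy as np

rng = np.random.default_rng(1)
S = rng.standard_normal((4, 6))
scale = 0.125
scale_log2 = scale * np.log2(np.e)   # fold log2(e) into the softmax scale

# Reference: ordinary base-e softmax of the scaled scores.
ref = np.exp(scale * S)
ref /= ref.sum(axis=-1, keepdims=True)

# Log2 domain: LSE stored as log2 of the row partition function.
LSE2 = np.log2(np.exp(scale * S).sum(axis=-1, keepdims=True))
P = np.exp2(S * scale_log2 - LSE2)   # exp2 is the cheap transcendental on Metal

assert np.allclose(P, ref)
```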

The main difference from MFA: we use MLX's JIT compilation to bake runtime constants (gqa_factor, scale, alignment, causal) as #define literals, enabling the Metal compiler to eliminate dead code paths. MFA achieves this via pre-compiled metallib with #define constants baked at build time.

Architecture

Two-Kernel Design

The backward is split into two independent kernels that run concurrently:

                    ┌─────────────────────────────────┐
                    │       Delta Precomputation      │
                    │   delta = rowsum(dO * O) [f32]  │
                    └──────────┬──────────────────────┘
                               │
                    ┌──────────┴────────────────────────┐
                    │                                   │
              ┌─────┴──────┐                      ┌─────┴──────┐
              │  dQ Kernel │                      │ dKV Kernel │
              │            │                      │            │
              │ Grid:      │                      │ Grid:      │
              │ [NQ,H,B]   │                      │ [NK,kvH,B] │
              │            │                      │            │
              │ WM=4       │                      │ WM=1 (D64) │
              │ 128 threads│                      │ WM=2 (D96+)│
              │ 4 simdgrps │                      │ 32-64 thrd │
              │            │                      │            │
              │ Loop: KV   │                      │ Loop: GQA  │
              │ blocks     │                      │ then Q blks│
              └─────┬──────┘                      └─────┬──────┘
                    │                                   │
                    │  dQ                          dK, dV
                    └───────────────┬───────────────────┘
                                    │
                              Output grads

Per-Head-Dimension Configuration

| D | Kernel | BQ | BK | WM | Threads | Regs/thread | Notes |
|---|---|---|---|---|---|---|---|
| 64 | dQ | 32 | 32 | 4 | 128 | ~180 | Full occupancy |
| 64 | dKV | 32 | 32 | 1 | 32 | ~220 | Single simdgroup, no reduction |
| 96 | dQ | 32 | 32 (M3+) / 16 | 4 | 128 | ~200 | BK=32 on M3+ for bandwidth |
| 96 | dKV | 32 | 16 | 2 | 64 | ~280 | Threadgroup reduction for dK/dV |
| 128 | dQ | 32 | 32 (M3+) / 16 | 4 | 128 | ~220 | BK=32 on M3+ |
| 128 | dKV | 16 | 16 | 2 | 64 | ~202 | BQ=16 halves register tiles, avoids spilling |

D=128 dKV is the most register-constrained configuration. At BQ=32 WM=2, it would need ~338 regs/thread (spills at >256). BQ=16 reduces TQ from 2 to 1 per simdgroup, cutting register usage to ~202 at the cost of 2x more Q-tile iterations.

Threadgroup Memory Aliasing

The dKV kernel uses four smem buffers: Q_smem, dO_smem (iteration phase), KV_smem (iteration phase), and red_smem (reduction phase, WM>1 only). Since Q/dO and red are temporally disjoint (iterations end before reduction begins), red_smem is aliased over the Q_smem+dO_smem region:

Before aliasing (D=128 BQ=16 WM=2):
  Q_smem(4,352) + dO_smem(4,352) + KV_smem(6,144) + red_smem(8,192) = 23,040 bytes
  → 1 threadgroup/core (32KB limit)

After aliasing:
  max(Q+dO(8,704), red(8,192)) + KV(6,144) = 14,848 bytes
  → 2 threadgroups/core (latency hiding doubles)

A static_assert verifies the Q+dO region is large enough to alias red, and an explicit threadgroup_barrier before the reduction phase ensures correctness.
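
The occupancy arithmetic behind the aliasing can be reproduced in a few lines. The per-buffer byte counts come from the layout above; the 32 KB per-core threadgroup-memory budget is the assumed Apple Silicon limit cited in this PR:

```python
TG_MEM_LIMIT = 32 * 1024
q, do, kv, red = 4352, 4352, 6144, 8192   # bytes: Q_smem, dO_smem, KV_smem, red_smem

before = q + do + kv + red                # four separate buffers
after = max(q + do, red) + kv             # red_smem aliased over the Q+dO region

print(before, TG_MEM_LIMIT // before)     # 23040 bytes -> 1 threadgroup/core
print(after, TG_MEM_LIMIT // after)       # 14848 bytes -> 2 threadgroups/core
```

The aliasing is safe only because the iteration phase (Q/dO) finishes before the reduction phase (red) begins, which is what the explicit threadgroup_barrier enforces.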

Sparse Block Masks

Both kernels support an optional block_mask buffer (uint8_t[NQ_tiles * NK_tiles]) for skipping tile pairs in sparse attention patterns (sliding window, local attention, block-sparse). The mask is gated by a function constant (has_block_mask, index 302) / JIT define (VJP_HAS_BLOCK_MASK):

  • When has_block_mask = false (default): all mask checks are dead-code eliminated — zero overhead
  • When has_block_mask = true: tile-skip check at the top of each loop iteration, before any smem loads
  • dQ kernel: skips K-blocks where block_mask[qb * NK_tiles + kb] == 0
  • dKV kernel: skips Q-blocks where block_mask[qb * NK_tiles + kb] == 0
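
A hypothetical host-side sketch of building such a mask for a sliding-window pattern (the tile size and window below are illustrative; only the flat `block_mask[qb * NK_tiles + kb]` layout is taken from the kernels):

```python
import numpy as np

def sliding_window_block_mask(L, tile, window):
    """uint8 tile-pair mask: 1 = compute the (qb, kb) tile pair, 0 = skip."""
    n = (L + tile - 1) // tile            # NQ_tiles == NK_tiles for self-attention
    mask = np.zeros((n, n), dtype=np.uint8)
    for qb in range(n):
        for kb in range(n):
            # Keep the tile pair if any (q, k) inside it satisfies |q - k| < window.
            q_lo, q_hi = qb * tile, min((qb + 1) * tile, L) - 1
            k_lo, k_hi = kb * tile, min((kb + 1) * tile, L) - 1
            if k_lo <= q_hi + window - 1 and k_hi >= q_lo - window + 1:
                mask[qb, kb] = 1
    return mask.reshape(-1)               # flat [NQ_tiles * NK_tiles] buffer

mask = sliding_window_block_mask(L=256, tile=32, window=64)
```

Because the check runs before any smem loads, skipped tile pairs cost only the mask read.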

JIT Constant Baking

Both kernels are JIT-compiled with #define constants:

#define VJP_GQA_FACTOR 4          // eliminates GQA loop for GQA=1
#define VJP_SCALE 0.0883883461f   // constant-folds into FMA
#define VJP_SCALE_LOG2 0.127552539f
#define VJP_BAKED_FC 1            // signals JIT mode
#define VJP_ALIGN_Q true          // dead-code eliminates bounds checks
#define VJP_ALIGN_K true
#define VJP_DO_CAUSAL false       // dead-code eliminates causal branches
#define VJP_HAS_BLOCK_MASK false  // dead-code eliminates sparse mask checks

For metallib (non-JIT) builds, macros fall back to params-> runtime reads and Metal function constants, ensuring zero behavior change.

The Metal shader compiler leverages these compile-time constants for three key optimizations:

  1. Dead code elimination: When VJP_GQA_FACTOR = 1 (common MHA case), the entire GQA outer loop in the dKV kernel is removed — no loop counter, head offset math, or branch. VJP_DO_CAUSAL = false eliminates all causal masking code; VJP_HAS_BLOCK_MASK = false removes the sparse mask check and pointer dereference.
  2. Constant folding: VJP_SCALE * P * (dP - delta) with a literal float can be fused into a single FMA instruction. Without JIT, the GPU must load params->scale from device memory at each use site.
  3. Loop unrolling: With VJP_GQA_FACTOR as a literal, the compiler knows the exact trip count. For GQA=1 the loop is eliminated entirely; for power-of-2 GQA factors, the head index division becomes a bit shift.

Metallib builds partially close the gap via Metal function constants (PSO-level boolean specialization), but numeric values (gqa_factor, scale, scale_log2) must always be loaded from the params buffer at runtime — the metallib compiler never sees their values. Each unique configuration is compiled once and cached in a library_map_, so the JIT cost (~50-200ms) is paid only on first invocation.
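
As a hedged sketch of how this kind of baked-constant generation can work (this is not MLX's actual JIT API; the function and cache-key scheme below are illustrative, while the macro names match the listing above), the host formats the runtime configuration into a `#define` preamble and keys the compiled-library cache on those exact values:

```python
import math

def vjp_defines(gqa_factor, D, causal, align_q, align_k, block_mask):
    """Build the #define preamble and a cache key for one kernel configuration."""
    scale = 1.0 / math.sqrt(D)
    defines = {
        "VJP_GQA_FACTOR": gqa_factor,
        "VJP_SCALE": f"{scale:.10f}f",
        "VJP_SCALE_LOG2": f"{scale * math.log2(math.e):.10f}f",
        "VJP_BAKED_FC": 1,
        "VJP_ALIGN_Q": str(align_q).lower(),
        "VJP_ALIGN_K": str(align_k).lower(),
        "VJP_DO_CAUSAL": str(causal).lower(),
        "VJP_HAS_BLOCK_MASK": str(block_mask).lower(),
    }
    preamble = "\n".join(f"#define {k} {v}" for k, v in defines.items())
    cache_key = tuple(defines.values())   # one compile per unique configuration
    return preamble, cache_key

preamble, key = vjp_defines(4, 128, causal=False, align_q=True,
                            align_k=True, block_mask=False)
```

The cache key is what bounds the JIT cost: each distinct tuple compiles once, so the ~50-200 ms compile is amortized across all subsequent calls with the same shape/flag combination.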

Why Fused Backward Is Not Always Faster

The fundamental trade-off

| Path | Compute | Memory |
|---|---|---|
| Unfused | S computed once, reused from O(L^2) buffer. Uses NAX large-tile matmuls (~10.7 TFLOPS) | Materializes full L x L attention matrix per head |
| Fused | S recomputed in both dQ and dKV kernels (~2.5x forward FLOPs). Small BQ=16-32 STEEL tiles (~1.9 TFLOPS) | Only O(L) memory; no attention matrix |

The fused path pays two penalties:

  1. Recomputation: S = Q @ K^T computed twice instead of once
  2. Smaller tiles: STEEL 32x32 tiles achieve lower MMA utilization than NAX 64x64+ tiles

When fused wins: causal masking

With causal attention, the attention matrix is lower-triangular. Fused kernels skip entire tile blocks in the upper triangle, eliminating ~50% of compute. Unfused still materializes the full N x N matrix then masks it:

Dense attention matrix:        Causal attention matrix:
┌─────────────┐                ┌─────────────┐
│ █ █ █ █ █ █ │                │ █ · · · · · │
│ █ █ █ █ █ █ │                │ █ █ · · · · │
│ █ █ █ █ █ █ │   ──────►      │ █ █ █ · · · │  ~50% tiles skipped
│ █ █ █ █ █ █ │                │ █ █ █ █ · · │
│ █ █ █ █ █ █ │                │ █ █ █ █ █ · │
│ █ █ █ █ █ █ │                │ █ █ █ █ █ █ │
└─────────────┘                └─────────────┘
  All tiles computed              Only lower triangle

This is why our benchmarks show fused causal at 1.17-1.37x speedup (fused faster) but fused dense at 0.36-0.70x (fused slower).

When fused wins: long sequences

At long sequence lengths, the O(L^2) attention matrix becomes a memory bandwidth bottleneck. At L=4096 with 32 heads, unfused needs 3.4 GB just for the backward intermediate. Even when fused is slower in raw compute, it avoids thrashing the memory system.
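
The arithmetic behind those figures is straightforward. One L x L fp16 attention matrix per head is L * L * H * 2 bytes; the unfused backward holds several such intermediates live at once (that the peak is roughly three of them is an inference from the 3.4 GB figure, not something the PR states explicitly):

```python
# Size of a single materialized L x L attention intermediate, fp16, 32 heads.
L, H, bytes_per_elem = 4096, 32, 2
attn_matrix = L * L * H * bytes_per_elem

print(attn_matrix / 2**20)   # 1024.0 MiB per L x L intermediate at L=4096
```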

How we deal with it: auto-dispatch

Rather than always using fused (slow for dense sequences) or always unfused (OOM risk for long sequences), we route per-configuration:

auto_dispatch(D, L, B, H, causal, gqa):
  if causal:
    # Fused is competitive or faster due to ~50% tile skipping.
    # Speed-driven L thresholds:
    D=64/96 MHA:   always fused (1.17-1.37x faster at all L)
    D=64/96 GQA:   fused if L >= 1024 (GQA serial loop erodes advantage)
    D=128 MHA:     fused if L >= 1024 (0.67-0.76x but 70-82% memory savings)
    D=128 GQA:     fused if L >= 2048
  else (dense):
    # Fused is always slower (0.36-0.70x) — never chosen for speed.
    # Only a memory ceiling prevents OOM:
    if attn_bytes >= 1 GB → fused (prevent OOM)
    else → unfused

Users can override with MLX_SDPA_VJP_MODE={fused|unfused}.
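
The policy pseudocode above can be written as a runnable sketch. This is a reconstruction from the PR description, not the PR's actual dispatch code; the thresholds and the 1 GB ceiling are taken verbatim from the policy listed above:

```python
GB = 1 << 30

def use_fused(D, L, B, H, causal, gqa_factor, dtype_bytes=2):
    """Return True if the fused VJP path should be used for this config."""
    if causal:
        if D in (64, 96):
            # MHA: always fused; GQA: serial loop erodes the advantage below L=1024.
            return True if gqa_factor == 1 else L >= 1024
        # D == 128: accept a speed loss for memory savings only at longer L.
        return L >= (1024 if gqa_factor == 1 else 2048)
    # Dense: fused is never faster, so only the memory ceiling forces it.
    attn_bytes = B * H * L * L * dtype_bytes
    return attn_bytes >= 1 * GB

assert use_fused(64, 512, 1, 32, causal=True, gqa_factor=1)        # always fused
assert not use_fused(64, 512, 1, 32, causal=False, gqa_factor=1)   # 17 MB << 1 GB
assert use_fused(64, 4096, 1, 32, causal=False, gqa_factor=1)      # hits 1 GB ceiling
```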

Benchmark Results

All benchmarks on Apple M3 Max, B=1, H=32, float16. Speedup > 1.0x means fused is faster.

Performance: Fused vs Unfused (MLX_SDPA_VJP_MODE=fused)

Causal attention — fused is competitive or faster due to block skipping:

| D | L | Unfused (s) | Fused (s) | Speedup |
|---|---|---|---|---|
| 64 | 512 | 0.012 | 0.011 | 1.17x |
| 64 | 1024 | 0.044 | 0.033 | 1.37x |
| 64 | 2048 | 0.165 | 0.131 | 1.25x |
| 64 | 4096 | 0.746 | 0.553 | 1.35x |
| 96 | 512 | 0.017 | 0.014 | 1.20x |
| 96 | 1024 | 0.068 | 0.057 | 1.18x |
| 96 | 2048 | 0.244 | 0.207 | 1.18x |
| 128 | 512 | 0.018 | 0.024 | 0.76x |
| 128 | 1024 | 0.068 | 0.091 | 0.75x |
| 128 | 2048 | 0.261 | 0.392 | 0.67x |

Dense attention — fused pays the full recomputation penalty:

| D | L | Unfused (s) | Fused (s) | Speedup |
|---|---|---|---|---|
| 64 | 512 | 0.011 | 0.016 | 0.67x |
| 64 | 1024 | 0.040 | 0.057 | 0.70x |
| 64 | 2048 | 0.139 | 0.223 | 0.62x |
| 64 | 4096 | 0.620 | 0.943 | 0.66x |
| 96 | 512 | 0.015 | 0.023 | 0.67x |
| 96 | 1024 | 0.060 | 0.090 | 0.66x |
| 96 | 2048 | 0.216 | 0.365 | 0.59x |
| 128 | 512 | 0.016 | 0.038 | 0.41x |
| 128 | 1024 | 0.062 | 0.148 | 0.42x |
| 128 | 2048 | 0.236 | 0.652 | 0.36x |

Pattern analysis:

  • D=64/96 causal: Fused is 1.17-1.37x faster. The ~50% causal tile skip more than compensates for the 2.5x recomputation.
  • D=64/96 dense: Fused is 0.59-0.70x (1.4-1.7x slower). Full recomputation penalty, but manageable.
  • D=128 dense: Fused is 0.36-0.42x (2.4-2.8x slower). Larger D means more computation per tile; the recomputation penalty scales with D.
  • D=128 causal: Fused is 0.67-0.76x — causal helps but cannot fully overcome the D=128 overhead. Smem aliasing improved this from 0.73x (see analysis below).

Why D=128 causal is slower despite block skipping

D=64/96 causal fused is 1.17-1.37x faster than unfused, but D=128 causal fused is only 0.67-0.76x. Both get the same ~50% causal tile-skip benefit, yet D=128 cannot overcome the recomputation penalty. This is caused by five compounding factors:

1. Occupancy collapse (dominant factor). Apple M3 Max has 32 KB threadgroup memory per GPU core. D=64 dKV uses ~7 KB smem, allowing 4 concurrent threadgroups per core — when one threadgroup stalls on a memory load, others keep the ALU busy. D=128 dKV originally used ~22.5 KB smem (Q: 4,352B + dO: 4,352B + KV: 6,144B + reduction: 8,192B), limiting occupancy to 1 threadgroup per core. After smem aliasing (red over Q+dO), this drops to 14.8 KB enabling 2 TGs/core — improving D=128 causal from 0.73x to 0.80x, but still below D=64's 4 TGs/core.

2. BQ=16 doubles iteration overhead. D=128 dKV uses BQ=16 tiles (half of D=64's BQ=32) to keep register pressure under 256 regs/thread. This means 2x more Q-tile iterations in the inner loop, each paying fixed costs: 2 global memory tile loads, 5 threadgroup barriers, and LSE/delta scalar reads. Total useful compute is the same, but the overhead-to-compute ratio doubles — from ~30% at D=64 to ~60% at D=128.

3. WM=2 threadgroup reduction. D=128 requires WM=2 (2 simdgroups, 64 threads) to split register pressure across simdgroups (~202 regs/thread vs ~364 at WM=1). This adds a threadgroup reduction with 4 extra barriers per Q-block iteration for dK/dV partial sum accumulation — overhead that D=64 (WM=1, single simdgroup) never pays.

4. Fewer MMAs per iteration. Each Q-tile iteration performs:

  • D=64: TQ=4, TK=4, TD=8 → 128 MMAs per iteration (good compute density)
  • D=128: TQ=1, TK=2, TD=16 → 32 MMAs per iteration (4x less)

Less compute per iteration means the fixed overhead (barriers, loads) dominates a larger fraction of wall time.

5. Unfused baseline is stronger at D=128. The unfused path uses NAX-optimized matmuls with large tiles (BM=128, BN=128). At D=128, matmul shapes like [L, L] @ [L, 128] fill NAX tiles perfectly (N=128 = BN). At D=64, N=64 wastes half the tile width. The unfused baseline is proportionally faster at D=128, widening the gap.

Combined effect: The effective fused/unfused throughput ratio is ~18% at D=64 vs ~12% at D=128 (after smem aliasing, up from ~10%). Causal tile-skipping saves ~50% of work for both: 50% of 18% = 36% effective throughput (enough to beat unfused → 1.37x), but 50% of 12% = 24% (still slower → 0.76x). Even with 2 TGs/core after aliasing, D=128's combination of BQ=16, WM=2 reduction, and stronger NAX baseline keeps it below parity.

Memory: Peak Usage During VJP

The primary motivation for fused backward is memory efficiency at scale:

| Config | Unfused Peak | Fused Peak | Savings | Attn Matrix Size |
|---|---|---|---|---|
| D=64 L=512 | 69 MB | 21 MB | 70% | 17 MB |
| D=64 L=1024 | 239 MB | 42 MB | 82% | 67 MB |
| D=64 L=2048 causal | 894 MB | 84 MB | 91% | 268 MB |
| D=64 L=4096 | 3,373 MB | 169 MB | 95% | 1,074 MB |
| D=96 L=1024 | 258 MB | 63 MB | 76% | 67 MB |
| D=96 L=2048 | 919 MB | 126 MB | 86% | 268 MB |
| D=128 L=1024 | 277 MB | 84 MB | 70% | 67 MB |
| D=128 L=2048 causal | 969 MB | 168 MB | 83% | 268 MB |
| D=128 L=2048 GQA(32/8) | 940 MB | 118 MB | 87% | 268 MB |

At L=4096, unfused requires 3.4 GB for backward alone — this is where fused becomes essential. With model weights, optimizer states, and activations competing for memory during training, the 95% reduction determines whether a training run fits in memory or not.

Auto-Dispatch Validation

The auto-dispatch policy separates causal (speed-driven) from dense (memory-driven) routing:

Dense attention — auto correctly selects unfused at all sequence lengths below 1 GB ceiling:

| D | L | B | Auto ratio | Expected | Notes |
|---|---|---|---|---|---|
| 64 | 512 | 1 | 1.00x | Unfused | 17 MB << 1 GB |
| 64 | 1024 | 1 | 1.00x | Unfused | 67 MB << 1 GB |
| 64 | 2048 | 1 | 1.01x | Unfused | 268 MB < 1 GB |
| 128 | 512 | 1 | 0.99x | Unfused | 17 MB << 1 GB |
| 128 | 1024 | 1 | 1.01x | Unfused | 67 MB << 1 GB |
| 128 | 2048 | 1 | 0.96x | Unfused | 268 MB < 1 GB |
| 128 | 512 | 4 | 1.02x | Unfused | 67 MB << 1 GB |
| 128 | 512 | 8 | 1.01x | Unfused | 134 MB < 1 GB |

The ~1.00x ratios confirm that both the "unfused" and "fused" benchmark columns execute the same unfused code path — auto mode is not forcing fused for dense attention.

Causal attention — auto selects fused for speed (verified via MLX_SDPA_VJP_MODE=fused benchmarks above):

  • D=64/96 MHA: always fused → 1.17-1.37x faster
  • D=128 MHA at L≥1024: fused → 0.67-0.76x speed with 70-82% memory savings (acceptable trade-off for training)
  • D=128 MHA at L<1024: unfused (fused is 0.76x with only 70% memory savings — not worth it for short sequences)

Brooooooklyn force-pushed the flash-attn branch 4 times, most recently from 3e2965b to f2f1380 on March 11, 2026 at 14:42
Add fused Flash Attention backward pass (VJP) kernels for Apple Silicon
GPUs, implementing the two-kernel architecture from Flash Attention 2
(Dao, 2023). The fused backward eliminates O(L^2) attention matrix
materialization, reducing peak memory by 70-95% with auto-dispatch
routing between fused and unfused paths.

Key additions:
- Two Metal kernels: steel_attention_vjp_dq and steel_attention_vjp_dkv
- JIT compilation with baked constants for dead-code elimination
- Delta precomputation as lazy MLX graph ops
- Threadgroup memory aliasing (23KB -> 14.8KB for D=128 dKV)
- Sparse block mask support via function constant gating
- Auto-dispatch: causal L thresholds + 1GB memory ceiling
- Support for D={64,96,128}, float16/bfloat16, causal, GQA
- 2-pass vector kernel LSE output for VJP logsumexp

Performance (M3 Max, B=1 H=32 causal float16, fused vs unfused):
  D=64:  1.22-1.40x faster
  D=96:  1.29-1.35x faster
  D=128: 0.77-0.81x (memory trade-off, 70-82% savings)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Brooooooklyn (Author) commented:

I know that submitting a PR of this size to a framework like mlx is very challenging, and I don’t want to create additional review burden for the reviewers here.

The background is that I'm developing https://github.com/mlx-node/mlx-node, and I'm experimenting with local RL on small-parameter LLMs such as qwen 3.5 0.8b on macOS via Node.js.

The reason for building on Node.js is that I think it makes it simpler for JavaScript developers to use this framework to write reward functions for different RL tasks.

My development is still at a very early stage. When attempting RL (GRPO), I found that even 128 GB of memory is too small, so I identified this potential optimization. If optimizing MLX training performance on macOS is not currently a major goal for the MLX team, feel free to set this PR aside for now; I will post updates here as I make progress.
