
win: re-enable and fix cuDNN performance #3242

Merged
zcbenz merged 3 commits into ml-explore:main from dhiltgen:win_cuDNN_fix
Mar 13, 2026

Conversation

@dhiltgen
Contributor

Proposed changes

Populating a fresh cuDNN CUDA graph for each layer, and adding that new graph to the overall MLX CUDA graph, is costly under WDDM. To resolve this we cache the graph (the first call performs the expensive populate_cuda_graph; subsequent calls only patch pointers via update_cuda_graph), and we also cache the subgraph key to avoid recomputing it (the kernel attribute queries involved incur WDDM round-trip overhead).
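The caching pattern described above can be sketched in isolation. This is a minimal illustration, not the actual MLX code: `CachedGraph`, `get_or_build`, and the counters are hypothetical stand-ins for the real cuDNN graph object and the populate_cuda_graph / update_cuda_graph calls.

```cpp
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a captured cuDNN CUDA graph plus counters
// that make the expensive vs. cheap paths observable.
struct CachedGraph {
  int handle;          // placeholder for the real graph handle
  int populate_calls;  // times the expensive build path ran
  int update_calls;    // times only device pointers were patched
};

// First call for a key pays the expensive populate; subsequent calls
// with the same key take the cheap pointer-update path.
CachedGraph& get_or_build(std::unordered_map<std::string, CachedGraph>& cache,
                          const std::string& key) {
  auto it = cache.find(key);
  if (it == cache.end()) {
    // Expensive path: analogous to populate_cuda_graph.
    it = cache.emplace(key, CachedGraph{42, 1, 0}).first;
  } else {
    // Cheap path: analogous to update_cuda_graph.
    it->second.update_calls++;
  }
  return it->second;
}
```

The win comes from the second branch: under WDDM, every fresh graph build and insertion pays driver round-trip costs, so amortizing the build across calls with the same shape is what recovers the prefill throughput.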

SDPACacheKey has bool fields adjacent to int64_t arrays, which introduces padding bytes for alignment. The BytesKey constructor memsets everything to zero, but the aggregate init cache_key.pod = {...} creates a stack temporary with uninitialized padding, and the compiler's trivial copy-assignment copies the entire struct, including the garbage padding, over the zeroed bytes. Since BytesKey uses memcmp for equality, every SDPA call produces a unique key and therefore a cache miss.
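A reduced sketch of the hazard and the safe pattern, assuming a hypothetical `PodKey` layout (not the real SDPACacheKey): a bool next to an int64_t array forces padding bytes, and any path that copies a whole temporary over the key can carry uninitialized padding with it. Zeroing the object first and then assigning fields individually keeps the padding deterministic, so memcmp-based equality is reliable.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical reduction of the layout described above: the bool is
// followed by 7 padding bytes so that the int64_t array is 8-byte
// aligned. memcmp compares those padding bytes too.
struct PodKey {
  bool causal;       // 1 byte + 7 padding bytes
  int64_t shape[4];  // 8-byte aligned
};

// Safe construction: zero the whole object (padding included), then
// write each field in place. Avoids the aggregate-init temporary whose
// padding bytes are indeterminate.
PodKey make_key(bool causal, const int64_t (&shape)[4]) {
  PodKey k;
  std::memset(&k, 0, sizeof(k));  // zeroes the padding bytes
  k.causal = causal;
  std::memcpy(k.shape, shape, sizeof(k.shape));
  return k;
}
```

With this pattern, two keys built from the same values are byte-identical, so a memcmp-keyed cache actually hits; with the aggregate-init temporary, identical logical keys could differ in their padding bytes and miss every time.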

Results (RTX 5090, Windows 11 WDDM, mlx_lm benchmark -p 2048 -g 128):

| Model | Metric | main | After fix | Change |
| --- | --- | --- | --- | --- |
| Llama-3.2-3B-4bit | Prefill (tok/s) | 1,228 | 2,436 | +98% |
| Llama-3.2-3B-4bit | Gen (tok/s) | 371 | 385 | +4% |
| Qwen3-8B-4bit | Prefill (tok/s) | 494 | 917 | +86% |
| Qwen3-8B-4bit | Gen (tok/s) | 220 | 229 | +4% |
| Llama-3.2-3B-bf16 | Prefill (tok/s) | 19,157 | 19,347 | +1% |
| Llama-3.2-3B-bf16 | Gen (tok/s) | 200 | 199 | ~0% |
| Qwen3-8B-bf16 | Prefill (tok/s) | 9,254 | 9,253 | ~0% |
| Qwen3-8B-bf16 | Gen (tok/s) | 91 | 91 | ~0% |

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Collaborator

@zcbenz zcbenz left a comment


This is not the first time we were bitten by the padding bytes, thanks a lot for the awesome fix!

Comment thread on mlx/backend/cuda/scaled_dot_product_attention.cpp (outdated)
@zcbenz zcbenz merged commit 7adfc83 into ml-explore:main Mar 13, 2026
16 checks passed


2 participants