win: re-enable and fix cuDNN performance#3242
Merged
zcbenz merged 3 commits intoml-explore:mainfrom Mar 13, 2026
Merged
Conversation
Populating a fresh cuDNN CUDA graph for each layer, and adding that new graph to the overall MLX CUDA graph is costly with WDDM. To resolve this we cache the graph (first call does the expensive populate_cuda_graph, subsequent calls only patch pointers via update_cuda_graph), and we cache the subgraph key to avoid the overhead of recomputing (kernel attribute queries that hit WDDM round-trip overhead.)
SDPACacheKey has bool fields adjacent to int64_t arrays which causes
padding bytes for alignment. The BytesKey constructor memsets
everything to zero, but the aggregate init cache_key.pod = {...} creates
a stack temporary with uninitialized padding, and the compiler's trivial
copy-assignment copies the entire struct — including the garbage padding
— over the zeroed bytes. Since BytesKey uses memcmp for equality, every
SDPA call produces a unique key and is a cache miss.
zcbenz
approved these changes
Mar 11, 2026
Collaborator
zcbenz
left a comment
There was a problem hiding this comment.
This is not the first time we were bitten by the padding bytes, thanks a lot for the awesome fix!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Proposed changes
Populating a fresh cuDNN CUDA graph for each layer, and adding that new graph to the overall MLX CUDA graph is costly with WDDM. To resolve this we cache the graph (first call does the expensive populate_cuda_graph, subsequent calls only patch pointers via update_cuda_graph), and we cache the subgraph key to avoid the overhead of recomputing (kernel attribute queries that hit WDDM round-trip overhead.)
SDPACacheKey has bool fields adjacent to int64_t arrays which causes padding bytes for alignment. The BytesKey constructor memsets everything to zero, but the aggregate init cache_key.pod = {...} creates a stack temporary with uninitialized padding, and the compiler's trivial copy-assignment copies the entire struct — including the garbage padding — over the zeroed bytes. Since BytesKey uses memcmp for equality, every SDPA call produces a unique key and is a cache miss.
Results (RTX 5090, Windows 11 WDDM,
mlx_lm benchmark -p 2048 -g 128):Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes