CUDA: fix Gemma E4B MTP FlashAttention #25148

Merged
JohannesGaessler merged 2 commits into ggml-org:master from JohannesGaessler:cuda-fa-gemma-mtp
Jun 30, 2026

@JohannesGaessler

Contributor

Fixes #24400.

For the original Gemma 4 PR I had insisted that compilation of GQA ratios 1 and 2 be disabled, because compiling template specializations that are never used results in unnecessarily high compilation time and binary size. Due to a bug the compilation was not actually disabled. And because the compilation was not done in the dedicated .cu files for template specializations, the overall compilation was not being parallelized. I fixed that in #21768. I was not aware that the template specialization with a GQA ratio of 2 is actually in use, so that PR inadvertently broke Gemma 4 E4B MTP. This PR now provides a proper fix.
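To illustrate the trade-off described above, here is a minimal plain C++ sketch, not the actual ggml CUDA code. The names `launch_fattn_sketch` and `dispatch_fattn_sketch`, the set of ratios, and the `-1` fallback are all hypothetical. It only shows why dropping a dispatch case saves compile time and binary size but breaks any model that needs the dropped ratio.

```cpp
// Hypothetical stand-in for a FlashAttention launcher templated on the GQA ratio
// (number of query heads that share one K/V head). Not the real ggml kernel.
template <int gqa_ratio>
constexpr int launch_fattn_sketch(int n_q_heads) {
    // A real kernel would process gqa_ratio query heads per K/V head.
    // Here we only return the resulting number of K/V heads.
    return n_q_heads / gqa_ratio;
}

// Runtime dispatch: every case forces one template specialization to be compiled.
// Removing a case means that specialization is never instantiated, which is
// cheaper to build but leaves no code path for models using that ratio.
constexpr int dispatch_fattn_sketch(int gqa_ratio, int n_q_heads) {
    switch (gqa_ratio) {
        case 2: return launch_fattn_sketch<2>(n_q_heads); // assumed to be needed by Gemma 4 E4B MTP
        case 4: return launch_fattn_sketch<4>(n_q_heads);
        case 8: return launch_fattn_sketch<8>(n_q_heads);
        default: return -1; // hypothetical error value: no specialization compiled
    }
}
```

In this sketch, `dispatch_fattn_sketch(2, 8)` reaches the ratio-2 specialization, while an unlisted ratio such as 3 falls through to the error value. In a real build the equivalent failure would more likely be an abort or a crash.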

JohannesGaessler requested a review from a team as a code owner on June 29, 2026 21:19
github-actions (bot) added the ggml and CUDA labels on Jun 29, 2026

@ServeurpersoCom ServeurpersoCom left a comment

Contributor

Correct attention -> better draft -> higher acceptance -> faster. Confirmed on Blackwell:

[45141] 1.14.626.605 I cmn  common_reaso: activated, budget=2147483647 tokens
[45141] 1.15.265.214 I cmn  common_reaso: deactivated (natural end)
[45141] 1.15.289.527 I slot print_timing: id  3 | task 10 | prompt eval time =      28.99 ms /    11 tokens (    2.64 ms per token,   379.43 tokens per second)
[45141] 1.15.289.529 I slot print_timing: id  3 | task 10 |        eval time =     671.71 ms /   189 tokens (    3.55 ms per token,   281.37 tokens per second)
[45141] 1.15.289.529 I slot print_timing: id  3 | task 10 |       total time =     700.70 ms /   200 tokens
[45141] 1.15.289.530 I slot print_timing: id  3 | task 10 |    graphs reused =         89
[45141] 1.15.289.532 I slot print_timing: id  3 | task 10 | draft acceptance = 0.61905 (  104 accepted /   168 generated), mean len =  2.24
[45141] 1.15.289.532 I slot print_timing: id  3 | task 10 |      acc per pos = (0.714, 0.524)
[45141] 1.15.289.545 I spec common_specu: statistics        draft-mtp: #calls(b,g,a) =    2     90     90, #gen drafts =     90, #acc drafts =    64, #gen tokens =    180, #acc tokens =   110, #mean acc len = 2.22, #acc rate/pos = (0.711, 0.511), dur(b,g,a) = 0.003, 109.452, 0.020 ms
[45141] 1.15.289.556 I slot      release: id  3 | task 10 | stop processing: n_tokens = 206, truncated = 0
[45141] 1.15.289.566 I srv  update_slots: all slots are idle
[45141] 1.15.289.654 I srv         close: stream_pipe close: skip drain (done=1 cancelled=0) conv=7fd99914-2d4d-4cdd-b267-ce2df6e469b6::MoE-Vision-Gemma-4-E4B-IT-MTP
542.16.061.812 I srv    operator(): operator(): cleaning up before exit...
[51925] 0.23.292.527 I cmn  common_reaso: activated, budget=2147483647 tokens
[51925] 0.23.910.288 I cmn  common_reaso: deactivated (natural end)
[51925] 0.23.933.600 I slot print_timing: id  3 | task 127 | prompt eval time =      18.27 ms /    10 tokens (    1.83 ms per token,   547.32 tokens per second)
[51925] 0.23.933.602 I slot print_timing: id  3 | task 127 |        eval time =     647.59 ms /   264 tokens (    2.45 ms per token,   407.66 tokens per second)
[51925] 0.23.933.602 I slot print_timing: id  3 | task 127 |       total time =     665.86 ms /   274 tokens
[51925] 0.23.933.602 I slot print_timing: id  3 | task 127 |    graphs reused =        232
[51925] 0.23.933.604 I slot print_timing: id  3 | task 127 | draft acceptance = 0.67857 (  152 accepted /   224 generated), mean len =  2.36
[51925] 0.23.933.605 I slot print_timing: id  3 | task 127 |      acc per pos = (0.750, 0.607)
[51925] 0.23.933.614 I spec common_specu: statistics        draft-mtp: #calls(b,g,a) =    2    235    235, #gen drafts =    235, #acc drafts =   180, #gen tokens =    470, #acc tokens =   321, #mean acc len = 2.37, #acc rate/pos = (0.766, 0.600), dur(b,g,a) = 0.002, 292.445, 0.033 ms
[51925] 0.23.933.623 I slot      release: id  3 | task 127 | stop processing: n_tokens = 281, truncated = 0
[51925] 0.23.933.630 I srv  update_slots: all slots are idle
[51925] 0.23.933.740 I srv         close: stream_pipe close: skip drain (done=1 cancelled=0) conv=7cc03a3c-41b5-4a43-ac48-c2e005dda91f::MoE-Vision-Gemma-4-E2B-IT-MTP
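
The derived figures in the `print_timing` lines above can be recomputed from the raw counts. The following is a small C++ sketch with hypothetical helper names (`acceptance_rate`, `tokens_per_second`). It assumes the logged rates are plain ratios of the counts shown, which is consistent with the numbers in the log.

```cpp
// Hypothetical helper: fraction of drafted tokens that were accepted.
constexpr double acceptance_rate(int accepted, int generated) {
    return static_cast<double>(accepted) / generated;
}

// Hypothetical helper: throughput from a token count and a duration in milliseconds.
constexpr double tokens_per_second(int n_tokens, double elapsed_ms) {
    return n_tokens / (elapsed_ms / 1000.0);
}
```

For the first run, `acceptance_rate(104, 168)` is about 0.61905 and `tokens_per_second(189, 671.71)` is about 281.37. For the second run, `acceptance_rate(152, 224)` is about 0.67857 and `tokens_per_second(264, 647.59)` is about 407.66.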

@EntityDeleter

Fixes it for me. I do not see any noticeable drop in token generation speed compared to disabling flash attention. Thank you!

@JohannesGaessler

Contributor Author

Can I please get a re-approval? I had accidentally pushed one superfluous template declaration with 2 * 2 = 4 columns, which is never actually used because the minimum number of columns for the mma kernel is 8.
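
The column arithmetic can be written as a compile-time check. This is only a sketch: the parameter names `ncols1` and `ncols2`, the constant `MMA_MIN_COLS_SKETCH`, and the helper names are assumptions for illustration. The only facts taken from the comment above are the product 2 * 2 = 4 and the minimum of 8.

```cpp
// Minimum number of columns the mma kernel handles, per the comment above.
// The constant name is hypothetical.
constexpr int MMA_MIN_COLS_SKETCH = 8;

// Total column count as the product of two factors (as in "2 * 2 = 4").
constexpr int total_cols(int ncols1, int ncols2) {
    return ncols1 * ncols2;
}

// A template declaration can only ever be used if its column count
// meets the kernel's minimum; otherwise it is dead weight.
constexpr bool is_reachable(int ncols1, int ncols2) {
    return total_cols(ncols1, ncols2) >= MMA_MIN_COLS_SKETCH;
}
```

Here `is_reachable(2, 2)` is false, which is why that declaration could be removed without changing behavior.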

@JohannesGaessler JohannesGaessler merged commit e495d1e into ggml-org:master Jun 30, 2026
24 checks passed
turbo-tan pushed a commit to turbo-tan/llama.cpp-tq3 that referenced this pull request Jul 1, 2026
* CUDA: fix Gemma E4B MTP FlashAttention

* remove unused template declaration

Labels

CUDA: Related to the CUDA backend
ggml: changes relating to the ggml tensor library for machine learning

Development

Successfully merging this pull request may close these issues.

Eval bug: Gemma4 E4B crashes with --flash-attn on

4 participants