Use custom rmsnorm in static attention #16604
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16604
Note: Links to docs will display an error until the docs builds have been completed.
❌ As of commit bcd5510 with merge base 9510334: 1 New Failure, 1 Cancelled Job, 1 Unrelated Failure.
NEW FAILURE - one job failed.
CANCELLED JOB - one job was cancelled; please retry.
UNSTABLE - one job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This PR updates the static attention implementation to use the custom RMSNorm class instead of PyTorch's built-in torch.nn.RMSNorm. The custom RMSNorm explicitly casts to fp32 during normalization, which resolves CoreML fp16 conversion errors when exporting models like Qwen that use QK normalization.
Changes:
- Import the custom `RMSNorm` from `executorch.examples.models.llama.norm`
- Replace `torch.nn.RMSNorm` with the custom `RMSNorm` in `StaticAttention.__init__()` when `use_qk_norm=True` (see the sketch below)
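A hedged sketch of what the swap in `StaticAttention.__init__()` might look like; the class shape and attribute names (`head_dim`, `norm_eps`, `q_norm`, `k_norm`) are assumptions for illustration, not copied from the actual diff:

```python
import torch

from executorch.examples.models.llama.norm import RMSNorm  # custom fp32-casting norm


class StaticAttentionSketch(torch.nn.Module):
    """Illustrative stand-in for StaticAttention; only the QK-norm wiring is shown."""

    def __init__(self, config, use_qk_norm: bool = False):
        super().__init__()
        if use_qk_norm:
            # Previously torch.nn.RMSNorm(config.head_dim); now the custom RMSNorm,
            # which normalizes in fp32 and casts back to the input dtype.
            self.q_norm = RMSNorm(config.head_dim, eps=config.norm_eps)
            self.k_norm = RMSNorm(config.head_dim, eps=config.norm_eps)
```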
Comments suppressed due to low confidence (2)
examples/models/llama/static_attention.py:862
- The default parameter `rms_norm_class=torch.nn.RMSNorm` is inconsistent with the custom `RMSNorm` now used in `__init__`. When `from_attention_mha` is called without specifying `rms_norm_class`, it will create a `StaticAttention` instance with the custom `RMSNorm` (via `__init__`), but `load_weights_from_attention_mha` will then replace the norms with `torch.nn.RMSNorm` instances, defeating the purpose of this PR. Change the default to `RMSNorm` to maintain consistency.
`rms_norm_class=torch.nn.RMSNorm,`
examples/models/llama/static_attention.py:1101
- The default parameter `rms_norm_class=torch.nn.RMSNorm` should be changed to `RMSNorm` to be consistent with the custom `RMSNorm` now used in `__init__`. This ensures that when loading weights, the norms use the same custom implementation that handles the fp32 casting for CoreML compatibility (see the sketch below).
`self, other: AttentionMHA, rms_norm_class=torch.nn.RMSNorm`
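The reviewer's suggestion amounts to changing the default argument; a hedged sketch based only on the signature line quoted above (the real method lives on `StaticAttention` and its body is not shown here):

```python
from executorch.examples.models.llama.norm import RMSNorm

# Before (quoted above): rms_norm_class=torch.nn.RMSNorm
# After (suggested):     rms_norm_class=RMSNorm
def load_weights_from_attention_mha(self, other, rms_norm_class=RMSNorm):
    # Defaulting to the custom RMSNorm keeps the norms created here consistent
    # with the ones created in __init__, instead of silently reverting to
    # torch.nn.RMSNorm when the caller omits rms_norm_class.
    ...
```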
@sxu can you make sure these changes are OK?
Summary
Use custom RMSNorm for static attention. This is what's used in mha attention as well.
executorch/examples/models/llama/attention.py
Lines 365 to 366 in 9cbe754
Test plan
Exporting Qwen with CoreML. The Qwen config sets `use_qk_norm` to true, and the export then fails with a CoreML fp16 conversion error. Looks like the custom `RMSNorm` explicitly casts to fp32 (see the reference and sketch below).
executorch/examples/models/llama/norm.py
Line 54 in 9cbe754
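For orientation, the fp32-casting pattern referenced above typically looks like a standard Llama-style RMSNorm; this is a minimal sketch, not the exact contents of `norm.py`:

```python
import torch


class RMSNorm(torch.nn.Module):
    """Minimal Llama-style RMSNorm sketch: normalize in fp32, cast back to input dtype."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The explicit .float() keeps the normalization in fp32, which is what
        # sidesteps the CoreML fp16 conversion error described in the test plan.
        return self._norm(x.float()).type_as(x) * self.weight
```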