Skip to content

Conversation

@lucylq
Copy link
Contributor

@lucylq lucylq commented Jan 14, 2026

Summary

Use custom RMSNorm for static attention. This is what's used in mha attention as well.

self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

Test plan

Exporting qwen with coreml. Qwen config sets use_qk_norm to true in the config, and then fails because of CoreML fp16 conversion error. Looks like the custom RMSNorm explicitly casts to fp32.

output = self._norm(x.float()).type_as(x)

@pytorch-bot
Copy link

pytorch-bot bot commented Jan 14, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16604

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job, 1 Unrelated Failure

As of commit bcd5510 with merge base 9510334 (image):

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 14, 2026
@github-actions
Copy link

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@lucylq lucylq requested a review from metascroy January 14, 2026 22:07
@lucylq lucylq marked this pull request as ready for review January 14, 2026 22:07
Copilot AI review requested due to automatic review settings January 14, 2026 22:07
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the static attention implementation to use the custom RMSNorm class instead of PyTorch's built-in torch.nn.RMSNorm. The custom RMSNorm explicitly casts to fp32 during normalization, which resolves CoreML fp16 conversion errors when exporting models like Qwen that use QK normalization.

Changes:

  • Import the custom RMSNorm from executorch.examples.models.llama.norm
  • Replace torch.nn.RMSNorm with custom RMSNorm in StaticAttention.__init__() when use_qk_norm=True
Comments suppressed due to low confidence (2)

examples/models/llama/static_attention.py:862

  • The default parameter rms_norm_class=torch.nn.RMSNorm is inconsistent with the custom RMSNorm now used in __init__. When from_attention_mha is called without specifying rms_norm_class, it will create a StaticAttention instance with custom RMSNorm (via __init__), but then load_weights_from_attention_mha will replace them with torch.nn.RMSNorm instances, defeating the purpose of this PR. Change the default to RMSNorm to maintain consistency.
        rms_norm_class=torch.nn.RMSNorm,

examples/models/llama/static_attention.py:1101

  • The default parameter rms_norm_class=torch.nn.RMSNorm should be changed to RMSNorm to be consistent with the custom RMSNorm now used in __init__. This ensures that when loading weights, the norms use the same custom implementation that handles fp32 casting for CoreML compatibility.
        self, other: AttentionMHA, rms_norm_class=torch.nn.RMSNorm

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@metascroy metascroy requested a review from sxu January 14, 2026 22:20
@metascroy
Copy link
Contributor

@sxu can you make sure these changes are OK?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants