diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 9eef4413a63..d33f12d9caa 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -14,6 +14,7 @@
     register_attention,
 )
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -847,8 +848,8 @@ def __init__(
         self.layer_id = layer_id
 
         if self.use_qk_norm:
-            self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
-            self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+            self.q_norm = RMSNorm(self.head_dim, eps=config.norm_eps)
+            self.k_norm = RMSNorm(self.head_dim, eps=config.norm_eps)
         else:
             self.q_norm = torch.nn.Identity()
             self.k_norm = torch.nn.Identity()
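
Context for the change above: both q/k norms are now built as RMSNorm(head_dim, eps=config.norm_eps) instead of torch.nn.RMSNorm. The following is a minimal sketch of a module with that call signature, shown only to illustrate what the swapped-in class is expected to do; it is not the actual implementation in executorch.examples.models.llama.norm, and head_dim/norm_eps below are placeholder values.

import torch


class RMSNorm(torch.nn.Module):
    # Illustrative sketch only; the real module lives in
    # executorch.examples.models.llama.norm and may differ
    # (e.g. in dtype handling or how the scale is applied).
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable per-channel scale, initialized to ones.
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize each vector by its root-mean-square over the last
        # dimension, then apply the learned per-channel scale.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight


# Usage matching the call sites in the diff: the norm is applied per
# attention head, so dim is head_dim. Placeholder values:
head_dim, norm_eps = 64, 1e-5
q_norm = RMSNorm(head_dim, eps=norm_eps)
q = torch.randn(1, 8, 16, head_dim)  # (batch, heads, seq, head_dim)
print(q_norm(q).shape)  # torch.Size([1, 8, 16, 64])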