112 changes: 110 additions & 2 deletions examples/apple/coreml/llama/export_static_llm_coreml.py
@@ -21,6 +21,7 @@

import argparse
import json
from typing import Optional

import coremltools as ct
import torch
@@ -98,11 +99,42 @@ def remove_graph_break_(edge_manager):
edge_manager.exported_program().graph_module.graph.eliminate_dead_code()


def load_model(checkpoint_path: str, params_path: str, max_context_len: int):
"""Load the model from checkpoint with static_mha attention type."""
def load_model(
checkpoint_path: str,
params_path: str,
max_context_len: int,
adapter_checkpoint_path: Optional[str] = None,
adapter_config_path: Optional[str] = None,
):
"""Load the model from checkpoint with static_mha attention type.

Args:
checkpoint_path: Path to model checkpoint (.pth)
params_path: Path to params.json
max_context_len: Maximum context length
adapter_checkpoint_path: Optional path to LoRA adapter weights (adapter_model.safetensors)
adapter_config_path: Optional path to adapter config (adapter_config.json)
"""
with open(params_path, "r") as f:
params = json.loads(f.read())

assert (adapter_config_path is None and adapter_checkpoint_path is None) or (
adapter_config_path is not None and adapter_checkpoint_path is not None
), "Both adapter_config_path and adapter_checkpoint_path must be provided together, or neither."

# Load adapter config if provided
adapter_config = None
if adapter_config_path is not None:
with open(adapter_config_path, "r") as f:
adapter_config = json.loads(f.read())
print(f"Loaded adapter config: rank={adapter_config.get('r')}, alpha={adapter_config.get('lora_alpha')}")
print(f"Target modules: {adapter_config.get('target_modules')}")

# Merge adapter config into params
params["r"] = adapter_config.get("r")
params["lora_alpha"] = adapter_config.get("lora_alpha")
params["target_modules"] = adapter_config.get("target_modules")

# TODO: to support lookahead decoding, the static model outputs
# full logits, but if we are not using lookahead decoding, we can have a
# more efficient model by setting generate_full_logits=False and supplying the last
@@ -124,8 +156,24 @@ def load_model(checkpoint_path: str, params_path: str, max_context_len: int):
if "model" in checkpoint:
checkpoint = checkpoint["model"]

# Load and merge adapter weights if provided
if adapter_checkpoint_path is not None:
print(f"Loading LoRA adapter from {adapter_checkpoint_path}...")
from safetensors.torch import load_file
from executorch.examples.models.llama.convert_weights import unsloth_to_meta

adapter_weights = load_file(adapter_checkpoint_path)
# Convert adapter weight keys to Meta format
adapter_weights = unsloth_to_meta(adapter_weights)
print(f"Loaded {len(adapter_weights)} adapter weights")

# Merge adapter weights into checkpoint
checkpoint.update(adapter_weights)

# Rename attention weight keys for static attention
# This handles both base weights and LoRA weights
for i in range(len(model.layers)):
# Base weights
if f"layers.{i}.attention.wq.weight" in checkpoint:
checkpoint[f"layers.{i}.attention.wqs.0.weight"] = checkpoint.pop(
f"layers.{i}.attention.wq.weight"
@@ -139,6 +187,21 @@ def load_model(checkpoint_path: str, params_path: str, max_context_len: int):
f"layers.{i}.attention.wv.weight"
)

# LoRA weights (lora_a and lora_b)
for lora_suffix in ["lora_a.weight", "lora_b.weight"]:
if f"layers.{i}.attention.wq.{lora_suffix}" in checkpoint:
checkpoint[f"layers.{i}.attention.wqs.0.{lora_suffix}"] = checkpoint.pop(
f"layers.{i}.attention.wq.{lora_suffix}"
)
if f"layers.{i}.attention.wk.{lora_suffix}" in checkpoint:
checkpoint[f"layers.{i}.attention.wks.0.{lora_suffix}"] = checkpoint.pop(
f"layers.{i}.attention.wk.{lora_suffix}"
)
if f"layers.{i}.attention.wv.{lora_suffix}" in checkpoint:
checkpoint[f"layers.{i}.attention.wvs.0.{lora_suffix}"] = checkpoint.pop(
f"layers.{i}.attention.wv.{lora_suffix}"
)
Comment on lines +190 to +203

Copilot AI Jan 15, 2026

The LoRA weight renaming logic handles wq, wk, and wv but not wo (the output projection). According to convert_weights.py, o_proj LoRA weights are converted to layers.{}.attention.wo.lora_a.weight and layers.{}.attention.wo.lora_b.weight. Unlike wqs/wks/wvs, wo is not converted to a ModuleList in static attention, so these keys should keep their names rather than being renamed; still, they deserve explicit handling so it is clear they are loaded correctly instead of silently falling through.

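As a rough sketch of what explicit handling could look like (illustrative only, not part of this PR), the script could verify any wo adapter keys instead of leaving them implicit. It assumes the keys keep the layers.{i}.attention.wo.lora_a.weight / lora_b.weight naming produced by unsloth_to_meta and that wo remains a single module whose keys need no renaming:

    # Hypothetical check, not part of this PR: wo is not a ModuleList in static
    # attention, so its LoRA keys keep their names. Verify that any such keys in
    # the adapter match a real model parameter before the strict=False
    # load_state_dict call below silently ignores them.
    model_keys = set(model.state_dict().keys())
    for i in range(len(model.layers)):
        for lora_suffix in ["lora_a.weight", "lora_b.weight"]:
            wo_key = f"layers.{i}.attention.wo.{lora_suffix}"
            if wo_key in checkpoint and wo_key not in model_keys:
                raise KeyError(f"Adapter key {wo_key} has no matching model parameter")

Because the subsequent load_state_dict uses strict=False, a check like this keeps a missing wo LoRA parameter from being dropped without warning.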

missing, unexpected = model.load_state_dict(
checkpoint,
strict=False,
@@ -263,6 +326,20 @@ def main():
help="Output filename for the .pte model",
)

# LoRA adapter options
parser.add_argument(
"--adapter_checkpoint",
type=str,
default=None,
help="Path to LoRA adapter weights (adapter_model.safetensors)",
)
parser.add_argument(
"--adapter_config",
type=str,
default=None,
help="Path to adapter config (adapter_config.json)",
)

# Model configuration
parser.add_argument(
"--max_context_len",
Expand Down Expand Up @@ -345,6 +422,8 @@ def main():
args.checkpoint,
args.params,
args.max_context_len,
args.adapter_checkpoint,
args.adapter_config,
)
print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim")

@@ -362,6 +441,33 @@ def main():
in_target_split_size=1,
in_max_splits=1,
)
try:
from executorch.examples.models.llama.lora import LoRALinear
except ImportError:
LoRALinear = None # type: ignore[assignment]
print("LoRALinear import failed, will only quantize nn.Linear layers.")

def make_linear_filter_fn(group_size=0):
"""Create a filter function for linear quantization.
Args:
group_size: Group size for quantization. 0 means per-axis (no constraint).
"""
def filter_fn(m, fqn):
# Check if it's a regular nn.Linear
is_linear = isinstance(m, nn.Linear)
# Check if it's a LoRALinear (which has a base weight parameter to quantize)
is_lora_linear = LoRALinear is not None and isinstance(m, LoRALinear)
if not (is_linear or is_lora_linear):
return False

# For per-axis (group_size=0), no shape constraint
if group_size == 0:
return True

# Check if the weight shape is compatible with group size
return m.weight.shape[1] % group_size == 0

return filter_fn

# Apply embedding quantization
if args.embedding_quantize:
Expand Down Expand Up @@ -392,6 +498,7 @@ def main():
weight_dtype=torch.int4,
granularity=PerGroup(32),
),
filter_fn=make_linear_filter_fn(group_size=32),
)
elif args.linear_quantize == "c4w":
print("\nQuantizing linear layers: 4-bit channelwise...")
@@ -401,6 +508,7 @@ def main():
weight_dtype=torch.int4,
granularity=PerAxis(0),
),
filter_fn=make_linear_filter_fn(group_size=0),
)

# Add graph breaks between transformer blocks
2 changes: 1 addition & 1 deletion examples/models/llama/model_args.py
@@ -117,7 +117,7 @@ class ModelArgs:
lora_args: Optional[dict] = None

# LoRA arguments to set up a LoRA inference model.
# These arguments come directly from a torchtune adapter_config.json file.
# These arguments come directly from a torchtune/unsloth adapter_config.json file.
r: Optional[int] = None # Rank.
lora_alpha: Optional[int] = None # Alpha.
# Modules that we can apply lora adapters to.
47 changes: 37 additions & 10 deletions examples/models/llama/static_attention.py
@@ -13,6 +13,7 @@
ForwardOptions,
register_attention,
)
from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

@@ -784,22 +785,45 @@ def __init__(
# Possibly disable in future, depending on bug fixes in Core ML runtime
self.decompose_sdpa_in_mha: bool = kwargs.get("decompose_sdpa_in_mha", False)

# LoRA configuration
self.target_modules = config.target_modules
self.lora_rank = config.r
self.lora_alpha = config.lora_alpha
if self.target_modules:
assert (
self.lora_rank is not None and self.lora_alpha is not None
), "LoRA rank and alpha must be specified when target_modules is provided"


def _make_linear(in_dim: int, out_dim: int, bias: bool, lora_target: str) -> nn.Module:
"""Create a linear layer with optional LoRA support."""
if self.target_modules is not None and lora_target in self.target_modules:
return LoRALinear(
in_dim=in_dim,
out_dim=out_dim,
rank=self.lora_rank,
alpha=self.lora_alpha,
dropout=0.0,
use_bias=bias,
)
return nn.Linear(in_dim, out_dim, bias=bias)

if self.split_mha:
self.wqs = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
_make_linear(self.dim, self.head_dim, self.attention_qkv_bias, "q_proj")
for _ in range(self.n_heads)
]
)
self.wks = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
_make_linear(self.dim, self.head_dim, self.attention_qkv_bias, "k_proj")
for _ in range(self.n_kv_heads)
]
)
self.wvs = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
_make_linear(self.dim, self.head_dim, self.attention_qkv_bias, "v_proj")
for _ in range(self.n_kv_heads)
]
)
@@ -813,36 +837,39 @@ def __init__(
else:
self.wqs = nn.ModuleList(
[
nn.Linear(
_make_linear(
self.dim,
self.head_dim * self.n_heads,
bias=self.attention_qkv_bias,
self.attention_qkv_bias,
"q_proj",
)
]
)
self.wks = nn.ModuleList(
[
nn.Linear(
_make_linear(
self.dim,
self.head_dim * self.n_kv_heads,
bias=self.attention_qkv_bias,
self.attention_qkv_bias,
"k_proj",
)
]
)
self.wvs = nn.ModuleList(
[
nn.Linear(
_make_linear(
self.dim,
self.head_dim * self.n_kv_heads,
bias=self.attention_qkv_bias,
self.attention_qkv_bias,
"v_proj",
)
]
)

self.k_caches = nn.ModuleList([StaticKCache(layer_id, 0)])
self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)])

self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
self.wo = _make_linear(self.n_heads * self.head_dim, self.dim, False, "o_proj")
self.rope = _Rope(rope.params)
self.layer_id = layer_id
