diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py
index a3fd8201414..2ceae7ddd29 100644
--- a/examples/apple/coreml/llama/export_static_llm_coreml.py
+++ b/examples/apple/coreml/llama/export_static_llm_coreml.py
@@ -21,6 +21,7 @@
 import argparse
 import json
+from typing import Optional
 
 import coremltools as ct
 import torch
 
@@ -98,11 +99,42 @@ def remove_graph_break_(edge_manager):
     edge_manager.exported_program().graph_module.graph.eliminate_dead_code()
 
 
-def load_model(checkpoint_path: str, params_path: str, max_context_len: int):
-    """Load the model from checkpoint with static_mha attention type."""
+def load_model(
+    checkpoint_path: str,
+    params_path: str,
+    max_context_len: int,
+    adapter_checkpoint_path: Optional[str] = None,
+    adapter_config_path: Optional[str] = None,
+):
+    """Load the model from checkpoint with static_mha attention type.
+
+    Args:
+        checkpoint_path: Path to model checkpoint (.pth)
+        params_path: Path to params.json
+        max_context_len: Maximum context length
+        adapter_checkpoint_path: Optional path to LoRA adapter weights (adapter_model.safetensors)
+        adapter_config_path: Optional path to adapter config (adapter_config.json)
+    """
     with open(params_path, "r") as f:
         params = json.loads(f.read())
 
+    assert (adapter_config_path is None and adapter_checkpoint_path is None) or (
+        adapter_config_path is not None and adapter_checkpoint_path is not None
+    ), "Both adapter_config_path and adapter_checkpoint_path must be provided together, or neither."
+
+    # Load adapter config if provided
+    adapter_config = None
+    if adapter_config_path is not None:
+        with open(adapter_config_path, "r") as f:
+            adapter_config = json.loads(f.read())
+        print(f"Loaded adapter config: rank={adapter_config.get('r')}, alpha={adapter_config.get('lora_alpha')}")
+        print(f"Target modules: {adapter_config.get('target_modules')}")
+
+        # Merge adapter config into params
+        params["r"] = adapter_config.get("r")
+        params["lora_alpha"] = adapter_config.get("lora_alpha")
+        params["target_modules"] = adapter_config.get("target_modules")
+
     # TODO: to support lookahead decoding, the static model outputs
     # full logits, but if we are not using lookahead decoding, we can have a
     # more efficient model by setting generate_full_logits=False and supplying the last
@@ -124,8 +156,24 @@ def load_model(checkpoint_path: str, params_path: str, max_context_len: int):
     if "model" in checkpoint:
         checkpoint = checkpoint["model"]
 
+    # Load and merge adapter weights if provided
+    if adapter_checkpoint_path is not None:
+        print(f"Loading LoRA adapter from {adapter_checkpoint_path}...")
+        from safetensors.torch import load_file
+        from executorch.examples.models.llama.convert_weights import unsloth_to_meta
+
+        adapter_weights = load_file(adapter_checkpoint_path)
+        # Convert adapter weight keys to Meta format
+        adapter_weights = unsloth_to_meta(adapter_weights)
+        print(f"Loaded {len(adapter_weights)} adapter weights")
+
+        # Merge adapter weights into checkpoint
+        checkpoint.update(adapter_weights)
+    # Rename attention weight keys for static attention
+    # This handles both base weights and LoRA weights
     for i in range(len(model.layers)):
+        # Base weights
         if f"layers.{i}.attention.wq.weight" in checkpoint:
             checkpoint[f"layers.{i}.attention.wqs.0.weight"] = checkpoint.pop(
                 f"layers.{i}.attention.wq.weight"
             )
@@ -139,6 +187,21 @@ def load_model(checkpoint_path: str, params_path: str, max_context_len: int):
                 f"layers.{i}.attention.wv.weight"
             )
 
+        # LoRA weights (lora_a and lora_b)
+        for lora_suffix in ["lora_a.weight", "lora_b.weight"]:
+            if f"layers.{i}.attention.wq.{lora_suffix}" in checkpoint:
+                checkpoint[f"layers.{i}.attention.wqs.0.{lora_suffix}"] = checkpoint.pop(
+                    f"layers.{i}.attention.wq.{lora_suffix}"
+                )
+            if f"layers.{i}.attention.wk.{lora_suffix}" in checkpoint:
+                checkpoint[f"layers.{i}.attention.wks.0.{lora_suffix}"] = checkpoint.pop(
+                    f"layers.{i}.attention.wk.{lora_suffix}"
+                )
+            if f"layers.{i}.attention.wv.{lora_suffix}" in checkpoint:
+                checkpoint[f"layers.{i}.attention.wvs.0.{lora_suffix}"] = checkpoint.pop(
+                    f"layers.{i}.attention.wv.{lora_suffix}"
+                )
+
     missing, unexpected = model.load_state_dict(
         checkpoint,
         strict=False,
@@ -263,6 +326,20 @@ def main():
         help="Output filename for the .pte model",
     )
 
+    # LoRA adapter options
+    parser.add_argument(
+        "--adapter_checkpoint",
+        type=str,
+        default=None,
+        help="Path to LoRA adapter weights (adapter_model.safetensors)",
+    )
+    parser.add_argument(
+        "--adapter_config",
+        type=str,
+        default=None,
+        help="Path to adapter config (adapter_config.json)",
+    )
+
     # Model configuration
     parser.add_argument(
         "--max_context_len",
@@ -345,6 +422,8 @@ def main():
         args.checkpoint,
         args.params,
         args.max_context_len,
+        args.adapter_checkpoint,
+        args.adapter_config,
     )
 
     print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim")
@@ -362,6 +441,33 @@ def main():
         in_target_split_size=1,
         in_max_splits=1,
     )
+    try:
+        from executorch.examples.models.llama.lora import LoRALinear
+    except ImportError:
+        LoRALinear = None  # type: ignore[assignment]
+        print("LoRALinear import failed, will only quantize nn.Linear layers.")
+
+    def make_linear_filter_fn(group_size=0):
+        """Create a filter function for linear quantization.
+        Args:
+            group_size: Group size for quantization. 0 means per-axis (no constraint).
+        """
+        def filter_fn(m, fqn):
+            # Check if it's a regular nn.Linear
+            is_linear = isinstance(m, nn.Linear)
+            # Check if it's a LoRALinear (which has a base weight parameter to quantize)
+            is_lora_linear = LoRALinear is not None and isinstance(m, LoRALinear)
+            if not (is_linear or is_lora_linear):
+                return False
+
+            # For per-axis (group_size=0), no shape constraint
+            if group_size == 0:
+                return True
+
+            # Check if the weight shape is compatible with group size
+            return m.weight.shape[1] % group_size == 0
+
+        return filter_fn
 
     # Apply embedding quantization
     if args.embedding_quantize:
@@ -392,6 +498,7 @@ def main():
                 weight_dtype=torch.int4,
                 granularity=PerGroup(32),
             ),
+            filter_fn=make_linear_filter_fn(group_size=32),
         )
     elif args.linear_quantize == "c4w":
         print("\nQuantizing linear layers: 4-bit channelwise...")
@@ -401,6 +508,7 @@ def main():
                 weight_dtype=torch.int4,
                 granularity=PerAxis(0),
             ),
+            filter_fn=make_linear_filter_fn(group_size=0),
        )
 
     # Add graph breaks between transformer blocks
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index a0e9eb70498..33b15cd181d 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -117,7 +117,7 @@ class ModelArgs:
     lora_args: Optional[dict] = None
 
     # LoRA arguments to set up a LoRA inference model.
-    # These arguments come directly from a torchtune adapter_config.json file.
+    # These arguments come directly from a torchtune/unsloth adapter_config.json file.
     r: Optional[int] = None  # Rank.
     lora_alpha: Optional[int] = None  # Alpha.
     # Modules that we can apply lora adapters to.
diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 9eef4413a63..143e778b1f5 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -13,6 +13,7 @@
     ForwardOptions,
     register_attention,
 )
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 
@@ -784,22 +785,45 @@ def __init__(
         # Possibly disable in future, depending on bug fixes in Core ML runtime
         self.decompose_sdpa_in_mha: bool = kwargs.get("decompose_sdpa_in_mha", False)
 
+        # LoRA configuration
+        self.target_modules = config.target_modules
+        self.lora_rank = config.r
+        self.lora_alpha = config.lora_alpha
+        if self.target_modules:
+            assert (
+                self.lora_rank is not None and self.lora_alpha is not None
+            ), "LoRA rank and alpha must be specified when target_modules is provided"
+
+
+        def _make_linear(in_dim: int, out_dim: int, bias: bool, lora_target: str) -> nn.Module:
+            """Create a linear layer with optional LoRA support."""
+            if self.target_modules is not None and lora_target in self.target_modules:
+                return LoRALinear(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    rank=self.lora_rank,
+                    alpha=self.lora_alpha,
+                    dropout=0.0,
+                    use_bias=bias,
+                )
+            return nn.Linear(in_dim, out_dim, bias=bias)
+
         if self.split_mha:
             self.wqs = nn.ModuleList(
                 [
-                    nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
+                    _make_linear(self.dim, self.head_dim, self.attention_qkv_bias, "q_proj")
                     for _ in range(self.n_heads)
                 ]
             )
             self.wks = nn.ModuleList(
                 [
-                    nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
+                    _make_linear(self.dim, self.head_dim, self.attention_qkv_bias, "k_proj")
                    for _ in range(self.n_kv_heads)
                 ]
             )
             self.wvs = nn.ModuleList(
                 [
-                    nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
+                    _make_linear(self.dim, self.head_dim, self.attention_qkv_bias, "v_proj")
                     for _ in range(self.n_kv_heads)
                 ]
             )
@@ -813,28 +837,31 @@ def __init__(
         else:
             self.wqs = nn.ModuleList(
                 [
-                    nn.Linear(
+                    _make_linear(
                         self.dim,
                         self.head_dim * self.n_heads,
-                        bias=self.attention_qkv_bias,
+                        self.attention_qkv_bias,
+                        "q_proj",
                     )
                 ]
             )
             self.wks = nn.ModuleList(
                 [
-                    nn.Linear(
+                    _make_linear(
                         self.dim,
                         self.head_dim * self.n_kv_heads,
-                        bias=self.attention_qkv_bias,
+                        self.attention_qkv_bias,
+                        "k_proj",
                     )
                 ]
             )
             self.wvs = nn.ModuleList(
                 [
-                    nn.Linear(
+                    _make_linear(
                         self.dim,
                         self.head_dim * self.n_kv_heads,
-                        bias=self.attention_qkv_bias,
+                        self.attention_qkv_bias,
+                        "v_proj",
                     )
                 ]
             )
@@ -842,7 +869,7 @@ def __init__(
             self.k_caches = nn.ModuleList([StaticKCache(layer_id, 0)])
             self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)])
 
-        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
+        self.wo = _make_linear(self.n_heads * self.head_dim, self.dim, False, "o_proj")
         self.rope = _Rope(rope.params)
 
         self.layer_id = layer_id
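For reference, a minimal usage sketch of the adapter-aware load_model() added above, with and without a LoRA adapter. The checkpoint/adapter file names and the import path are illustrative assumptions, not part of this diff; the function is defined in examples/apple/coreml/llama/export_static_llm_coreml.py.

# Hypothetical smoke test (not part of this diff): exercise load_model() with
# and without a LoRA adapter. Paths are placeholders; point them at a real
# Llama checkpoint and a torchtune/unsloth adapter directory containing
# adapter_model.safetensors and adapter_config.json.
from executorch.examples.apple.coreml.llama.export_static_llm_coreml import load_model

# Base model only: both adapter arguments default to None.
base = load_model("consolidated.00.pth", "params.json", max_context_len=1024)

# Base model with LoRA adapter weights merged into the checkpoint before loading.
with_adapter = load_model(
    "consolidated.00.pth",
    "params.json",
    max_context_len=1024,
    adapter_checkpoint_path="adapter_model.safetensors",
    adapter_config_path="adapter_config.json",
)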