66 changes: 18 additions & 48 deletions examples/apple/coreml/llama/run_static_llm.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

"""
Run script for static attention Llama models exported with coreml_static_llama.py.
Run script for static attention LLM models exported with export_static_llm_coreml.py.

Usage:
python run_static_llm.py \
@@ -21,58 +21,21 @@
import time
from typing import Any, Dict, List, Tuple

import sentencepiece as spm
import torch
import torch.utils._pytree as pytree

from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.runner.generation import next_token
from executorch.examples.models.llama.static_attention import StaticAttentionIOManager
from executorch.runtime import Runtime
from pytorch_tokenizers import get_tokenizer


class Tokenizer:
    """Wrapper to support both SentencePiece and Tiktoken tokenizers."""

    def __init__(self, model_path: str):
        try:
            print("Trying to load sentencepiece")
            sp = spm.SentencePieceProcessor()
            sp.load(model_path)
            self.tokenizer = sp
            self._is_sentencepiece = True
        except Exception:
            print("Trying to load tiktoken")
            from executorch.examples.models.llama.tokenizer import tiktoken

            self.tokenizer = tiktoken.Tokenizer(model_path)
            self._is_sentencepiece = False

    def encode(self, text: str, bos: bool = True, eos: bool = False) -> List[int]:
        if self._is_sentencepiece:
            bos_string = "<s>" if bos else ""
            eos_string = "</s>" if eos else ""
            return self.tokenizer.encode(f"{bos_string}{text}{eos_string}")
        return self.tokenizer.encode(text, bos=bos, eos=eos)

    def decode(self, tokens: List[int]) -> str:
        if self._is_sentencepiece:
            return self.tokenizer.decode(tokens)
        return self.tokenizer.decode(tokens)

    def decode_token(self, token: int) -> str:
        if self._is_sentencepiece:
            return self.tokenizer.decode([token])
        try:
            return self.tokenizer.decode_token(token)
        except UnicodeDecodeError:
            return f"<{token}>"

    @property
    def stop_tokens(self) -> List[int]:
        if self._is_sentencepiece:
            return [self.tokenizer.eos_id()]
        return self.tokenizer.stop_tokens
def get_stop_tokens(tokenizer) -> List[int]:
    """Get stop tokens from tokenizer, falling back to eos_id if not available."""
    if hasattr(tokenizer, "stop_tokens"):
        return tokenizer.stop_tokens
    return [tokenizer.eos_id]


def create_pte_wrapper(
@@ -143,6 +106,12 @@ def main():
        required=True,
        help="Path to tokenizer model",
    )
    parser.add_argument(
        "--tokenizer_config",
        type=str,
        default=None,
        help="Path to tokenizer config (required for HuggingFace tokenizers)",
    )
    parser.add_argument(
        "--prompt",
        type=str,
@@ -206,7 +175,8 @@ def main():
    args = parser.parse_args()

    # Load tokenizer
    tokenizer = Tokenizer(args.tokenizer)
    tokenizer = get_tokenizer(args.tokenizer, args.tokenizer_config)
    stop_tokens = get_stop_tokens(tokenizer)

    # Load model params
    with open(args.params, "r") as f:
@@ -291,20 +261,20 @@ def main():
            ngram_size=args.ngram_size,
            window_size=args.window_size,
            n_verifications=args.n_verifications,
            stop_tokens=tokenizer.stop_tokens,
            stop_tokens=stop_tokens,
        )
    else:
        # Use standard autoregressive decoding
        generated_tokens = mgr.decode(
            model_fn,
            first_token,
            n=args.max_new_tokens - 1,  # -1 because first_token counts
            stop_tokens=tokenizer.stop_tokens,
            stop_tokens=stop_tokens,
        )

    # Print generated tokens (skip first as it's the init_token we already printed)
    for token in generated_tokens[1:]:
        if token in tokenizer.stop_tokens:
        if token in stop_tokens:
            break
        print(tokenizer.decode_token(token), end="", flush=True)

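For reference, a minimal sketch of the tokenizer path after this change, assuming only what the diff shows: get_tokenizer from pytorch_tokenizers is called with the tokenizer path and optional config path, get_stop_tokens falls back to the tokenizer's eos_id when no stop_tokens list is exposed, and generated token ids are streamed through decode_token. The stream_decode helper and the placeholder paths below are illustrative and not part of the PR.

from typing import List

from pytorch_tokenizers import get_tokenizer


def get_stop_tokens(tokenizer) -> List[int]:
    # Mirrors the helper added in this PR: prefer the tokenizer's own
    # stop_tokens list, otherwise fall back to its single eos_id.
    if hasattr(tokenizer, "stop_tokens"):
        return tokenizer.stop_tokens
    return [tokenizer.eos_id]


def stream_decode(tokenizer, token_ids: List[int]) -> None:
    # Hypothetical helper mirroring the print loop in run_static_llm.py:
    # stop at the first stop token, otherwise print each decoded piece.
    stop_tokens = get_stop_tokens(tokenizer)
    for token in token_ids:
        if token in stop_tokens:
            break
        print(tokenizer.decode_token(token), end="", flush=True)


# Example wiring (placeholder path; pass a config path only for HuggingFace tokenizers):
# tokenizer = get_tokenizer("/path/to/tokenizer.model", None)
# stream_decode(tokenizer, token_ids_from_model)

The hasattr fallback keeps the script working with both tokenizer families the old wrapper handled: SentencePiece-style tokenizers that only expose an eos_id, and Tiktoken or HuggingFace tokenizers that carry an explicit stop_tokens list.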