diff --git a/examples/apple/coreml/llama/run_static_llm.py b/examples/apple/coreml/llama/run_static_llm.py
index 2cd526aec42..107448b3453 100644
--- a/examples/apple/coreml/llama/run_static_llm.py
+++ b/examples/apple/coreml/llama/run_static_llm.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 """
-Run script for static attention Llama models exported with coreml_static_llama.py.
+Run script for static attention LLM models exported with export_static_llm_coreml.py.
 
 Usage:
     python run_static_llm.py \
@@ -21,7 +21,6 @@
 import time
 from typing import Any, Dict, List, Tuple
 
-import sentencepiece as spm
 import torch
 import torch.utils._pytree as pytree
 
@@ -29,50 +28,14 @@
 from executorch.examples.models.llama.runner.generation import next_token
 from executorch.examples.models.llama.static_attention import StaticAttentionIOManager
 from executorch.runtime import Runtime
+from pytorch_tokenizers import get_tokenizer
 
 
-class Tokenizer:
-    """Wrapper to support both SentencePiece and Tiktoken tokenizers."""
-
-    def __init__(self, model_path: str):
-        try:
-            print("Trying to load sentencepiece")
-            sp = spm.SentencePieceProcessor()
-            sp.load(model_path)
-            self.tokenizer = sp
-            self._is_sentencepiece = True
-        except Exception:
-            print("Trying to load tiktoken")
-            from executorch.examples.models.llama.tokenizer import tiktoken
-
-            self.tokenizer = tiktoken.Tokenizer(model_path)
-            self._is_sentencepiece = False
-
-    def encode(self, text: str, bos: bool = True, eos: bool = False) -> List[int]:
-        if self._is_sentencepiece:
-            bos_string = "<s>" if bos else ""
-            eos_string = "</s>" if eos else ""
-            return self.tokenizer.encode(f"{bos_string}{text}{eos_string}")
-        return self.tokenizer.encode(text, bos=bos, eos=eos)
-
-    def decode(self, tokens: List[int]) -> str:
-        if self._is_sentencepiece:
-            return self.tokenizer.decode(tokens)
-        return self.tokenizer.decode(tokens)
-
-    def decode_token(self, token: int) -> str:
-        if self._is_sentencepiece:
-            return self.tokenizer.decode([token])
-        try:
-            return self.tokenizer.decode_token(token)
-        except UnicodeDecodeError:
-            return f"<{token}>"
-
-    @property
-    def stop_tokens(self) -> List[int]:
-        if self._is_sentencepiece:
-            return [self.tokenizer.eos_id()]
-        return self.tokenizer.stop_tokens
+def get_stop_tokens(tokenizer) -> List[int]:
+    """Get stop tokens from tokenizer, falling back to eos_id if not available."""
+    if hasattr(tokenizer, "stop_tokens"):
+        return tokenizer.stop_tokens
+    return [tokenizer.eos_id]
 
 
 def create_pte_wrapper(
@@ -143,6 +106,12 @@ def main():
         required=True,
         help="Path to tokenizer model",
     )
+    parser.add_argument(
+        "--tokenizer_config",
+        type=str,
+        default=None,
+        help="Path to tokenizer config (required for HuggingFace tokenizers)",
+    )
     parser.add_argument(
         "--prompt",
         type=str,
@@ -206,7 +175,8 @@ def main():
     args = parser.parse_args()
 
     # Load tokenizer
-    tokenizer = Tokenizer(args.tokenizer)
+    tokenizer = get_tokenizer(args.tokenizer, args.tokenizer_config)
+    stop_tokens = get_stop_tokens(tokenizer)
 
     # Load model params
     with open(args.params, "r") as f:
@@ -291,7 +261,7 @@ def main():
             ngram_size=args.ngram_size,
             window_size=args.window_size,
             n_verifications=args.n_verifications,
-            stop_tokens=tokenizer.stop_tokens,
+            stop_tokens=stop_tokens,
         )
     else:
         # Use standard autoregressive decoding
@@ -299,12 +269,12 @@ def main():
             model_fn,
             first_token,
             n=args.max_new_tokens - 1,  # -1 because first_token counts
-            stop_tokens=tokenizer.stop_tokens,
+            stop_tokens=stop_tokens,
         )
 
     # Print generated tokens (skip first as it's the init_token we already printed)
     for token in generated_tokens[1:]:
-        if token in tokenizer.stop_tokens:
+        if token in stop_tokens:
             break
        print(tokenizer.decode_token(token), end="", flush=True)
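
For reference, a minimal sketch of the new tokenizer path under the assumptions in this diff: the file path below is a placeholder, get_stop_tokens is the helper added above, and pytorch_tokenizers.get_tokenizer is called with the same two-argument form used in the patched main().

    from typing import List

    from pytorch_tokenizers import get_tokenizer


    def get_stop_tokens(tokenizer) -> List[int]:
        """Get stop tokens from tokenizer, falling back to eos_id if not available."""
        if hasattr(tokenizer, "stop_tokens"):
            return tokenizer.stop_tokens
        return [tokenizer.eos_id]


    # Placeholder path: a SentencePiece/Tiktoken model needs only the model file;
    # per the new --tokenizer_config flag, a HuggingFace tokenizer would also pass
    # its tokenizer config path as the second argument instead of None.
    tokenizer = get_tokenizer("tokenizer.model", None)
    stop_tokens = get_stop_tokens(tokenizer)

The hasattr fallback mirrors the removed wrapper class: tokenizers that expose a stop_tokens list use it directly, while tokenizers that only expose a single eos id fall back to a one-element list, and the same stop_tokens value now feeds both the lookahead and the autoregressive decoding paths.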