from jinja2 import TemplateError


def _extract_token_ids(tokens):
  """Extracts token IDs from various tokenizer output formats.

  This helper function standardizes the extraction of tokenized integer IDs
  from common return types of Hugging Face tokenizers, including
  `BatchEncoding` objects, dictionaries, or simple lists.

  Args:
    tokens: The object containing token IDs. Supported types include:
      - A list of integers.
      - A dictionary containing the `INPUT_TOKENS_KEY`.
      - An object (e.g., `BatchEncoding`) with an attribute named `INPUT_TOKENS_KEY`.

  Returns:
    A list of integer token IDs.

  Raises:
    ValueError: If the input type is not supported or does not contain the expected key.
  """
  # attention masks in BatchEncoding are effectively ignored
  if hasattr(tokens, INPUT_TOKENS_KEY):
    return getattr(tokens, INPUT_TOKENS_KEY)
  if isinstance(tokens, dict) and INPUT_TOKENS_KEY in tokens:
    return tokens[INPUT_TOKENS_KEY]
  if isinstance(tokens, list):
    return tokens
  raise ValueError(f"Can't extract token_ids from type {type(tokens)}")


def verify_chat_template_generation_prompt_logic(tokenizer_model):
  """Verifies the tokenizer's chat template for correct SFT loss masking.

  This function ensures that the tokens added by `add_generation_prompt=True`
  are identical to the tokens that begin an assistant's turn in a complete
  conversation, which is critical for masking prompt tokens during SFT loss
  calculation.

  Example of a mismatch:
    A `ValueError` is raised if the generation prompt and the actual
    assistant prefix do not match. For example:

    - `add_generation_prompt=True` on a user message produces a prompt ending in:
      `...<|im_start|>generation\\n`
    - A full turn with an assistant message starts the reply with:
      `...<|im_start|>assistant\\n...`

    This function would fail because the tokens for "generation" do not
    match the tokens for "assistant".

  Args:
    tokenizer_model: The Hugging Face tokenizer instance to verify.

  Raises:
    ValueError: If the `add_generation_prompt` tokens do not exactly
      match the beginning of an assistant message in the template.
  """
  dummy_msgs = [{"role": "system", "content": "System message"}, {"role": "user", "content": "Test message"}]

  try:
    prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
  except TemplateError:
    # Some chat templates reject a "system" role outright; retry with the user turn only.
    max_logging.info(
        "Tokenizer failed to apply chat template with 'system' role. "
        "Falling back to 'user' role only for chat template verification."
    )
    dummy_msgs.pop(0)
    prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
  prompt_wo_gen_ids = _extract_token_ids(prompt_wo_gen_tokens)

  prompt_w_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=True, tokenize=True)
  prompt_w_gen_ids = _extract_token_ids(prompt_w_gen_tokens)

  # The generation prompt must be a pure suffix appended to the un-prompted tokens;
  # otherwise there is no well-defined boundary to mask on.
  if prompt_w_gen_ids[: len(prompt_wo_gen_ids)] != prompt_wo_gen_ids:
    raise ValueError("Unable to extract generation prompt tokens.")
  # Extract the tokenized generation prompt (the expected assistant prefix)
  assistant_prefix = prompt_w_gen_ids[len(prompt_wo_gen_ids) :]
  # Fix: the original applied _extract_token_ids twice (once to the tokenizer output and
  # again to its own result). That only worked via the list pass-through branch and would
  # raise ValueError for any extracted type that is not itself a plain list. Extract once.
  full_turn_ids = _extract_token_ids(
      tokenizer_model.apply_chat_template(
          dummy_msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=True
      )
  )
  # Extract the actual tokens that appear right after the user message in the full turn
  actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)]

  if actual_prefix_in_full_turn != assistant_prefix:
    expected_str = tokenizer_model.decode(assistant_prefix)
    actual_str = tokenizer_model.decode(actual_prefix_in_full_turn)
    raise ValueError(
        "Chat template generation prompt mismatch!\n"
        f"Expected assistant prefix tokens: {assistant_prefix} ('{expected_str}')\n"
        f"Actual prefix tokens found: {actual_prefix_in_full_turn} ('{actual_str}')\n"
        "This means the tokenizer's chat template will break the sft masking logic."
    )
@@ -190,18 +289,8 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs): # include generation_prompt as part of the prompt tokens prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True) - # attention masks in BatchEncoding are effectively ignored - if hasattr(prompt_completion_tokens, INPUT_TOKENS_KEY): - prompt_completion_ids = getattr(prompt_completion_tokens, INPUT_TOKENS_KEY) - prompt_ids = getattr(prompt_tokens, INPUT_TOKENS_KEY) - elif isinstance(prompt_completion_tokens, dict) and INPUT_TOKENS_KEY in prompt_completion_tokens: - prompt_completion_ids = prompt_completion_tokens[INPUT_TOKENS_KEY] - prompt_ids = prompt_tokens[INPUT_TOKENS_KEY] - elif isinstance(prompt_completion_tokens, list): - prompt_completion_ids = prompt_completion_tokens - prompt_ids = prompt_tokens - else: - raise ValueError(f"Can't handle the chat template output of type {type(prompt_completion_tokens)}") + prompt_completion_ids = _extract_token_ids(prompt_completion_tokens) + prompt_ids = _extract_token_ids(prompt_tokens) completion_tokens = prompt_completion_ids[len(prompt_ids) :] completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False) diff --git a/tests/unit/sft_data_processing_test.py b/tests/unit/sft_data_processing_test.py index bd8092c12f..37e54d510c 100644 --- a/tests/unit/sft_data_processing_test.py +++ b/tests/unit/sft_data_processing_test.py @@ -19,17 +19,19 @@ import pytest import numpy as np import jax +import re from jax.sharding import Mesh from jax.experimental import mesh_utils from datasets import Dataset import transformers from parameterized import parameterized_class - +from unittest.mock import patch from maxtext.configs import pyconfig from maxtext.utils.globals import MAXTEXT_PKG_DIR, MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT from maxtext.input_pipeline import hf_data_processing from maxtext.input_pipeline import input_pipeline_interface from 
@pytest.mark.external_training
class SFTChatTemplateLogicTest(unittest.TestCase):
  """Exercises verify_chat_template_generation_prompt_logic against real tokenizers."""

  # Local cache location for the llama2 chat tokenizer assets.
  LLAMA_TOKENIZER_PATH = os.path.join(MAXTEXT_ASSETS_ROOT, "llama2-chat-tokenizer")

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    # Fetch the llama2 tokenizer once for the whole class if it is not cached locally.
    # NOTE(review): assumes `subprocess` and `os` are imported at the module top — confirm.
    if os.path.exists(cls.LLAMA_TOKENIZER_PATH):
      return
    download_cmd = [
        "gsutil",
        "cp",
        "-r",
        "gs://maxtext-dataset/hf/llama2-chat-tokenizer",
        os.path.join(MAXTEXT_ASSETS_ROOT, ""),
    ]
    if subprocess.call(download_cmd) != 0:
      raise ValueError("Failed to download llama tokenizer")

  def setUp(self):
    super().setUp()
    # Fresh tokenizer instances per test keep the tests independent of each other.
    self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
    self.llama2_tokenizer = transformers.AutoTokenizer.from_pretrained(self.LLAMA_TOKENIZER_PATH)

  def test_tokenizer_w_generation_prompt(self):
    # Qwen3's template emits a non-empty generation prompt; verification must not raise.
    verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)

  def test_tokenizer_wo_generation_prompt(self):
    # Llama2's template adds no generation prompt; verification must still pass.
    verify_chat_template_generation_prompt_logic(self.llama2_tokenizer)

  def test_failure_path_with_modified_template(self):
    """Verifies the function correctly raises a ValueError on a bad template."""
    # Replace the role within the existing add_generation_prompt block with a deliberately faulty one.
    fault_chat_template = re.sub(
        r"(\{%-?\s*if add_generation_prompt\s*%\}.*?<\|im_start\|>)assistant(.*?\{%-?\s*endif\s*%\})",
        r"\1wrong_role\2",
        self.qwen3_tokenizer.chat_template,
        flags=re.DOTALL,
    )
    # patch.object restores the genuine template after the block exits.
    with patch.object(self.qwen3_tokenizer, "chat_template", fault_chat_template):
      # Verify that our function catches the mismatch and raises the expected error
      with self.assertRaisesRegex(ValueError, "Chat template generation prompt mismatch!"):
        verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)


if __name__ == "__main__":
  unittest.main()