Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/input_pipeline/hf_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def preprocessing_pipeline(
)
operations = []
if use_sft:
input_pipeline_utils.verify_chat_template_generation_prompt_logic(tokenizer)
operations.append(
input_pipeline_utils.SFTPromptMasking(
text_column_name=data_column_names[0],
Expand Down
113 changes: 101 additions & 12 deletions src/maxtext/input_pipeline/input_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from threading import current_thread
from typing import Any, Iterable, TYPE_CHECKING

from jinja2 import TemplateError

if TYPE_CHECKING:
import datasets

Expand Down Expand Up @@ -172,6 +174,103 @@ def is_conversational(features, data_columns):
return False


def _extract_token_ids(tokens):
  """Extracts token IDs from various tokenizer output formats.

  Standardizes retrieval of tokenized integer IDs from the common return
  types of Hugging Face tokenizers: `BatchEncoding`-style objects,
  dictionaries, or plain lists.

  Args:
    tokens: The object containing token IDs. Supported types include:
      - A list of integers.
      - A dictionary containing the `INPUT_TOKENS_KEY`.
      - An object (e.g., `BatchEncoding`) with an attribute named `INPUT_TOKENS_KEY`.

  Returns:
    A list of integer token IDs.

  Raises:
    ValueError: If the input type is not supported or does not contain the expected key.
  """
  # NOTE: any attention masks carried alongside the IDs (e.g. in a
  # BatchEncoding) are effectively ignored; only the token IDs are returned.
  if hasattr(tokens, INPUT_TOKENS_KEY):
    return getattr(tokens, INPUT_TOKENS_KEY)
  if isinstance(tokens, dict) and INPUT_TOKENS_KEY in tokens:
    return tokens[INPUT_TOKENS_KEY]
  if isinstance(tokens, list):
    return tokens
  raise ValueError(f"Can't extract token_ids from type {type(tokens)}")


def verify_chat_template_generation_prompt_logic(tokenizer_model):
  """Verifies the tokenizer's chat template for correct SFT loss masking.

  This function ensures that the tokens added by `add_generation_prompt=True`
  are identical to the tokens that begin an assistant's turn in a complete
  conversation, which is critical for masking prompt tokens during SFT loss
  calculation.

  Example of a mismatch:
    A `ValueError` is raised if the generation prompt and the actual
    assistant prefix do not match. For example:

    - `add_generation_prompt=True` on a user message produces a prompt ending in:
      `...<|im_start|>generation\n`
    - A full turn with an assistant message starts the reply with:
      `...<|im_start|>assistant\n...`

    This function would fail because the tokens for "generation" do not
    match the tokens for "assistant".

  Args:
    tokenizer_model: The Hugging Face tokenizer instance to verify.

  Raises:
    ValueError: If the `add_generation_prompt` tokens do not exactly
      match the beginning of an assistant message in the template.
  """
  dummy_msgs = [{"role": "system", "content": "System message"}, {"role": "user", "content": "Test message"}]

  try:
    prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
  except TemplateError:
    # Some templates (e.g. Gemma-style) reject a 'system' role entirely;
    # retry the verification with the user message alone.
    max_logging.info(
        "Tokenizer failed to apply chat template with 'system' role. "
        "Falling back to 'user' role only for chat template verification."
    )
    dummy_msgs.pop(0)
    prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
  prompt_wo_gen_ids = _extract_token_ids(prompt_wo_gen_tokens)

  prompt_w_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=True, tokenize=True)
  prompt_w_gen_ids = _extract_token_ids(prompt_w_gen_tokens)

  # The generation-prompt rendering must be a pure suffix extension of the
  # non-generation rendering; otherwise the prefix tokens can't be isolated.
  if prompt_w_gen_ids[: len(prompt_wo_gen_ids)] != prompt_wo_gen_ids:
    raise ValueError("Unable to extract generation prompt tokens.")
  # Extract the tokenized generation prompt (the expected assistant prefix)
  assistant_prefix = prompt_w_gen_ids[len(prompt_wo_gen_ids) :]
  # Fix: the original extracted token IDs twice — once on the tokenizer output
  # and again on the already-extracted list (harmless only because a list
  # passes through `_extract_token_ids` unchanged). Extract exactly once.
  full_turn_ids = _extract_token_ids(
      tokenizer_model.apply_chat_template(
          dummy_msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=True
      )
  )
  # Extract the actual tokens that appear right after the user message in the full turn
  actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)]

  if actual_prefix_in_full_turn != assistant_prefix:
    expected_str = tokenizer_model.decode(assistant_prefix)
    actual_str = tokenizer_model.decode(actual_prefix_in_full_turn)
    raise ValueError(
        "Chat template generation prompt mismatch!\n"
        f"Expected assistant prefix tokens: {assistant_prefix} ('{expected_str}')\n"
        f"Actual prefix tokens found: {actual_prefix_in_full_turn} ('{actual_str}')\n"
        "This means the tokenizer's chat template will break the sft masking logic."
    )


def _get_completion_in_chat_template(tokenizer_model, round_msgs):
"""
Calculates the completion part of a conversation turn when formatted with a chat template.
Expand All @@ -190,18 +289,8 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs):
# include generation_prompt as part of the prompt tokens
prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True)

# attention masks in BatchEncoding are effectively ignored
if hasattr(prompt_completion_tokens, INPUT_TOKENS_KEY):
prompt_completion_ids = getattr(prompt_completion_tokens, INPUT_TOKENS_KEY)
prompt_ids = getattr(prompt_tokens, INPUT_TOKENS_KEY)
elif isinstance(prompt_completion_tokens, dict) and INPUT_TOKENS_KEY in prompt_completion_tokens:
prompt_completion_ids = prompt_completion_tokens[INPUT_TOKENS_KEY]
prompt_ids = prompt_tokens[INPUT_TOKENS_KEY]
elif isinstance(prompt_completion_tokens, list):
prompt_completion_ids = prompt_completion_tokens
prompt_ids = prompt_tokens
else:
raise ValueError(f"Can't handle the chat template output of type {type(prompt_completion_tokens)}")
prompt_completion_ids = _extract_token_ids(prompt_completion_tokens)
prompt_ids = _extract_token_ids(prompt_tokens)

completion_tokens = prompt_completion_ids[len(prompt_ids) :]
completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False)
Expand Down
50 changes: 49 additions & 1 deletion tests/unit/sft_data_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@
import pytest
import numpy as np
import jax
import re
from jax.sharding import Mesh
from jax.experimental import mesh_utils
from datasets import Dataset
import transformers
from parameterized import parameterized_class

from unittest.mock import patch
from maxtext.configs import pyconfig
from maxtext.utils.globals import MAXTEXT_PKG_DIR, MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT
from maxtext.input_pipeline import hf_data_processing
from maxtext.input_pipeline import input_pipeline_interface
from maxtext.input_pipeline.hf_data_processing import _get_pad_id
from maxtext.input_pipeline.input_pipeline_utils import verify_chat_template_generation_prompt_logic

PROMPT_DATA = [
[
Expand Down Expand Up @@ -480,5 +482,51 @@ def test_system_message_not_at_beginning(self):
self.get_data_iterator(dataset, ["messages"])


@pytest.mark.external_training
class SFTChatTemplateLogicTest(unittest.TestCase):
  """Exercises verify_chat_template_generation_prompt_logic against real tokenizers."""

  LLAMA_TOKENIZER_PATH = os.path.join(MAXTEXT_ASSETS_ROOT, "llama2-chat-tokenizer")

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    # Download the Llama-2 chat tokenizer assets once per class if absent.
    if not os.path.exists(cls.LLAMA_TOKENIZER_PATH):
      download_cmd = [
          "gsutil",
          "cp",
          "-r",
          "gs://maxtext-dataset/hf/llama2-chat-tokenizer",
          os.path.join(MAXTEXT_ASSETS_ROOT, ""),
      ]
      if subprocess.call(download_cmd) != 0:
        raise ValueError("Failed to download llama tokenizer")

  def setUp(self):
    super().setUp()
    # Qwen3 has an explicit add_generation_prompt branch; Llama-2 does not.
    self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
    self.llama2_tokenizer = transformers.AutoTokenizer.from_pretrained(self.LLAMA_TOKENIZER_PATH)

  def test_tokenizer_w_generation_prompt(self):
    verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)

  def test_tokenizer_wo_generation_prompt(self):
    verify_chat_template_generation_prompt_logic(self.llama2_tokenizer)

  def test_failure_path_with_modified_template(self):
    """Verifies the function correctly raises a ValueError on a bad template."""
    # Replace the role within the existing add_generation_prompt block with a
    # deliberately faulty one so the rendered prefix no longer matches.
    fault_chat_template = re.sub(
        r"(\{%-?\s*if add_generation_prompt\s*%\}.*?<\|im_start\|>)assistant(.*?\{%-?\s*endif\s*%\})",
        r"\1wrong_role\2",
        self.qwen3_tokenizer.chat_template,
        flags=re.DOTALL,
    )
    with patch.object(self.qwen3_tokenizer, "chat_template", fault_chat_template):
      # The mismatch between the generation prompt and the real assistant
      # prefix must surface as the dedicated ValueError.
      with self.assertRaisesRegex(ValueError, "Chat template generation prompt mismatch!"):
        verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)


# Allow running this test module directly without the pytest runner.
if __name__ == "__main__":
  unittest.main()
Loading