Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,14 @@ class Distillation(BaseModel):
description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64}).",
)

# --- Offline Distillation Fields ---
offline_distillation: bool = Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This flag looks redundant: providing the `offline_data_dir` parameter can itself serve as the signal to switch to offline processing, so a separate boolean should not be needed.

False, description="If True, enables offline distillation using pre-generated teacher data."
)
offline_data_dir: Optional[str] = Field(
None, description="GCS or local path to the pre-generated ArrayRecord teacher data."
)

# --- Loss Params ---
distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.")
distill_temperature: float = Field(1.0, description="Temperature for distillation softening.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
model structures with Tunix's training interfaces.
"""

import pickle
import tensorflow as tf
from array_record.python import array_record_module

from typing import Any, Iterator, Optional, List, Callable

import flax
Expand Down Expand Up @@ -63,13 +67,60 @@ class MaxTextTrainingInput(peft_trainer.TrainingInput):
targets_position: jax.Array = None
#: Segment IDs for packed target tokens.
targets_segmentation: jax.Array = None
#: Top-K logits from the teacher model.
top_k_logits: jax.Array = None
top_k_indices: jax.Array = None


# -----------------------------------------------------------------------------
# Data Loading Adapter
# -----------------------------------------------------------------------------


class OfflineArrayRecordIterator:
  """Iterates over a pre-generated ArrayRecord file of teacher top-k logits.

  Yields dictionaries shaped like MaxText data-pipeline batches, replaying the
  file for up to ``epochs`` full passes before raising ``StopIteration``.
  """

  def __init__(self, data_dir: str, epochs: int = 100):
    """Opens the ArrayRecord file and initializes epoch/record bookkeeping.

    Args:
      data_dir: GCS or local path to the pre-generated ArrayRecord file.
      epochs: Maximum number of full passes to make over the file.

    Raises:
      FileNotFoundError: If no file exists at ``data_dir``.
    """
    self.filepath = data_dir

    if not tf.io.gfile.exists(self.filepath):
      raise FileNotFoundError(f"Offline distillation file not found: {self.filepath}")

    self.reader = array_record_module.ArrayRecordReader(self.filepath)
    self.num_records = self.reader.num_records()
    self.epochs = epochs
    self.current_epoch = 0
    self.record_index = 0

  def __iter__(self):
    return self

  def __next__(self):
    """Returns the next batch dict, rolling over to a new epoch at file end."""
    if self.record_index >= self.num_records:
      self.current_epoch += 1
      if self.current_epoch >= self.epochs:
        raise StopIteration

      # Rewind for the next epoch. Close the exhausted reader before
      # re-opening so we do not leak one open file handle per epoch.
      self.record_index = 0
      self.reader.close()
      self.reader = array_record_module.ArrayRecordReader(self.filepath)

    # NOTE(review): confirm this against the array_record Python API —
    # depending on the version, sequential reads may require an explicit
    # index (e.g. reader.read([self.record_index])) rather than a no-arg
    # read() call.
    record = self.reader.read()
    self.record_index += 1
    # SECURITY: pickle.loads must only ever see trusted, self-generated
    # teacher data; never point offline_data_dir at an untrusted file.
    data = pickle.loads(record)

    # Map the stored arrays to the keys MaxText's training input expects.
    batch = {
        "inputs": data["tokens"],
        "top_k_logits": data["top_k_logits"],
        "top_k_indices": data["top_k_indices"],
    }
    # Pass through optional packing/position metadata when present.
    for key in ["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"]:
      if key in data:
        batch[key] = data[key]

    return batch


class MaxTextToTunixIterator:
"""Adapts the raw dictionary output of MaxText's data loader to Tunix objects.

Expand Down Expand Up @@ -123,6 +174,8 @@ def __next__(self) -> MaxTextTrainingInput:
targets=batch["targets"],
targets_position=targets_position,
targets_segmentation=targets_segmentation,
top_k_logits=batch.get("top_k_logits"),
top_k_indices=batch.get("top_k_indices"),
)


Expand Down
82 changes: 55 additions & 27 deletions src/maxtext/trainers/post_train/distillation/train_distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
3. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose
a standard interface (call signature) that the Tunix `DistillationTrainer` expects.
"""

from typing import Sequence, Callable
from absl import app
from flax import nnx
Expand Down Expand Up @@ -292,6 +291,8 @@ def _prepare_inputs(
targets=input_data.targets,
targets_position=input_data.targets_position,
targets_segmentation=input_data.targets_segmentation,
top_k_logits=input_data.top_k_logits,
top_k_indices=input_data.top_k_indices,
)

def _post_process_train_step(self, aux: dict[str, jax.Array]) -> None:
Expand Down Expand Up @@ -390,7 +391,12 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
# -----------------------------------------------------------------------------


def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters) -> None:
def train_distill(
student_config: pyconfig.HyperParameters,
teacher_config: pyconfig.HyperParameters,
is_offline: bool = False,
offline_data_dir: str | None = None,
) -> None:
"""Main distillation training loop.
Orchestrates the loading of both student and teacher models, configures the
Expand Down Expand Up @@ -426,9 +432,15 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco
_log_config_details(student_config, "Student")
student_model = get_maxtext_model(student_config, mesh)

max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...")
_log_config_details(teacher_config, "Teacher")
teacher_model = get_maxtext_model(teacher_config, mesh)
# Skip teacher model loading if offline
if is_offline:
max_logging.log("Offline Distillation: Skipping Teacher Model loading.")
teacher_model = None
else:
max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...")
_log_config_details(teacher_config, "Teacher")
teacher_model = get_maxtext_model(teacher_config, mesh)
teacher_model.eval()

# 3. Define Distillation Strategy
def labels_fn(targets, targets_segmentation=None, **kwargs):
Expand Down Expand Up @@ -489,13 +501,15 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
)

# 5. Data Iterators (Init BEFORE Trainer)
# We use MaxText's native create_data_iterator which creates both train and eval iterators
max_logging.log("Initializing Data Iterators via MaxText pipeline...")
raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh)
if is_offline:
max_logging.log(f"Loading Offline Dataset from {offline_data_dir}...")
raw_train_iter = distillation_utils.OfflineArrayRecordIterator(offline_data_dir)
raw_eval_iter = None
else:
max_logging.log("Initializing Data Iterators via MaxText pipeline...")
raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh)

teacher_model.eval()
student_model.train()

model_bundle = ModelBundle(teacher_model, student_model)

# 6. Initialize Trainer
Expand All @@ -513,18 +527,35 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
raw_train_iter = _setup_and_restore_input_pipeline(trainer, raw_train_iter, student_config, train_config)

# 8. Configure Input Mapping
trainer = trainer.with_gen_model_input_fn(
lambda batch: {
"input_tokens": batch.input_tokens,
"positions": batch.positions,
"attention_mask": batch.input_mask,
"decoder_segment_ids": batch.decoder_segment_ids,
"targets": batch.targets, # Passed to strategy (labels_fn)
"targets_position": batch.targets_position, # Passed to strategy (labels_fn)
"targets_segmentation": batch.targets_segmentation, # Passed to strategy (labels_fn)
"cache": None,
}
)
def custom_gen_model_input_fn(batch):
    """Maps a training batch onto the model-call kwargs Tunix expects.

    In offline mode (the batch carries pre-computed teacher top-k logits),
    the sparse logits are scattered into a dense vocab-sized tensor and
    injected as ``teacher_output`` so the trainer skips the teacher forward
    pass entirely. In online mode the plain input dict is returned.
    """
    inputs_dict = {
        "input_tokens": batch.input_tokens,
        "positions": batch.positions,
        "attention_mask": batch.input_mask,
        "decoder_segment_ids": batch.decoder_segment_ids,
        "targets": batch.targets,  # Passed to strategy (labels_fn)
        "targets_position": batch.targets_position,  # Passed to strategy (labels_fn)
        "targets_segmentation": batch.targets_segmentation,  # Passed to strategy (labels_fn)
        "cache": None,
    }

    # Online mode: no offline teacher logits attached, nothing more to do.
    if getattr(batch, "top_k_logits", None) is None:
        return inputs_dict

    # Scatter the sparse top-k logits into a dense (batch..., vocab) tensor
    # pre-filled with a large negative value so that non-top-k entries
    # contribute ~zero probability after the softmax.
    dense_shape = batch.input_tokens.shape + (student_config.vocab_size,)
    dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32)
    # inplace=False is required here: JAX arrays are immutable, so
    # jnp.put_along_axis cannot mutate its operand and must return a new
    # array (inplace=True raises ValueError under JAX).
    dense_logits = jnp.put_along_axis(dense_logits, batch.top_k_indices, batch.top_k_logits, axis=-1, inplace=False)

    # Inject as teacher_output so the trainer skips the teacher forward pass.
    inputs_dict["teacher_output"] = distillation_utils.DistillationForwardOutput(
        logits=dense_logits, out_projection_activations=None
    )

    return inputs_dict

trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn)

# 9. Create Iterator Wrappers (Use Utils)
train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter)
Expand Down Expand Up @@ -576,9 +607,6 @@ def main(argv: Sequence[str]) -> None:
Parses configuration, isolates Student and Teacher overrides, and triggers the
training loop.
Args:
argv: List of command-line arguments. Expects [script_name, config_file, ...].
"""
# 1. Parse Global Config to extract Overrides
global_config = pyconfig.initialize(argv)
Expand All @@ -593,7 +621,7 @@ def main(argv: Sequence[str]) -> None:
teacher_overrides = global_config.teacher_overrides

# Ensure load_parameters_path is set in overrides
if not teacher_overrides.get("load_parameters_path"):
if not global_config.offline_distillation and not teacher_overrides.get("load_parameters_path"):
raise ValueError(
"Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
"in your config or arguments."
Expand All @@ -605,7 +633,7 @@ def main(argv: Sequence[str]) -> None:
teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides)

# 4. Run Training
train_distill(student_config, teacher_config)
train_distill(student_config, teacher_config, global_config.offline_distillation, global_config.offline_data_dir)


if __name__ == "__main__":
Expand Down
Loading