diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index c9e9475883..6d087f9256 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1077,6 +1077,11 @@ class Distillation(BaseModel): description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64}).", ) + # --- Offline Distillation Field --- + offline_data_dir: Optional[str] = Field( + None, description="GCS or local path to the pre-generated ArrayRecord teacher data." + ) + # --- Loss Params --- distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.") distill_temperature: float = Field(1.0, description="Temperature for distillation softening.") diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index 71a63c1ce2..92ffbb4e4e 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -18,6 +18,10 @@ model structures with Tunix's training interfaces. """ +import pickle +import tensorflow as tf +from array_record.python import array_record_module + from typing import Any, Iterator, Optional, List, Callable import flax @@ -63,6 +67,9 @@ class MaxTextTrainingInput(peft_trainer.TrainingInput): targets_position: jax.Array = None #: Segment IDs for packed target tokens. targets_segmentation: jax.Array = None + #: Top-K logits from the teacher model. 
+ top_k_logits: jax.Array = None + top_k_indices: jax.Array = None # ----------------------------------------------------------------------------- @@ -70,6 +77,50 @@ class MaxTextTrainingInput(peft_trainer.TrainingInput): # ----------------------------------------------------------------------------- +class OfflineArrayRecordIterator: + """Reads the pre-generated global top-k logits file.""" + + def __init__(self, data_dir: str, epochs: int = 100): + self.filepath = data_dir + + if not tf.io.gfile.exists(self.filepath): + raise FileNotFoundError(f"Offline distillation file not found: {self.filepath}") + + self.reader = array_record_module.ArrayRecordReader(self.filepath) + self.num_records = self.reader.num_records() + self.epochs = epochs + self.current_epoch = 0 + self.record_index = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.record_index >= self.num_records: + self.current_epoch += 1 + if self.current_epoch >= self.epochs: + raise StopIteration + + self.record_index = 0 + self.reader = array_record_module.ArrayRecordReader(self.filepath) + + record = self.reader.read([self.record_index])[0] + self.record_index += 1 + data = pickle.loads(record) + + # Map the arrays to match MaxText's expected dictionary + batch = { + "inputs": data["tokens"], + "top_k_logits": data["top_k_logits"], + "top_k_indices": data["top_k_indices"], + } + for key in ["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"]: + if key in data: + batch[key] = data[key] + + return batch + + class MaxTextToTunixIterator: """Adapts the raw dictionary output of MaxText's data loader to Tunix objects. 
@@ -123,6 +174,8 @@ def __next__(self) -> MaxTextTrainingInput: targets=batch["targets"], targets_position=targets_position, targets_segmentation=targets_segmentation, + top_k_logits=batch.get("top_k_logits"), + top_k_indices=batch.get("top_k_indices"), ) diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 85eb045bfe..e0f5e65ae3 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -32,7 +32,6 @@ 3. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose a standard interface (call signature) that the Tunix `DistillationTrainer` expects. """ - from typing import Sequence, Callable from absl import app from flax import nnx @@ -299,6 +298,8 @@ def _prepare_inputs( targets=input_data.targets, targets_position=input_data.targets_position, targets_segmentation=input_data.targets_segmentation, + top_k_logits=input_data.top_k_logits, + top_k_indices=input_data.top_k_indices, ) def _post_process_train_step(self, aux: dict[str, jax.Array]) -> None: @@ -397,7 +398,12 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh) # ----------------------------------------------------------------------------- -def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters) -> None: +def train_distill( + student_config: pyconfig.HyperParameters, + teacher_config: pyconfig.HyperParameters, + is_offline: bool = False, + offline_data_dir: str | None = None, +) -> None: """Main distillation training loop. 
Orchestrates the loading of both student and teacher models, configures the @@ -433,9 +439,15 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco _log_config_details(student_config, "Student") student_model = get_maxtext_model(student_config, mesh) - max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...") - _log_config_details(teacher_config, "Teacher") - teacher_model = get_maxtext_model(teacher_config, mesh) + # Skip teacher model loading if offline + if is_offline: + max_logging.log("Offline Distillation: Skipping Teacher Model loading.") + teacher_model = None + else: + max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...") + _log_config_details(teacher_config, "Teacher") + teacher_model = get_maxtext_model(teacher_config, mesh) + teacher_model.eval() # 3. Define Distillation Strategy def labels_fn(targets, targets_segmentation=None, **kwargs): @@ -498,13 +510,15 @@ def labels_fn(targets, targets_segmentation=None, **kwargs): ) # 5. Data Iterators (Init BEFORE Trainer) - # We use MaxText's native create_data_iterator which creates both train and eval iterators - max_logging.log("Initializing Data Iterators via MaxText pipeline...") - raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh) + if is_offline: + max_logging.log(f"Loading Offline Dataset from {offline_data_dir}...") + raw_train_iter = distillation_utils.OfflineArrayRecordIterator(offline_data_dir) + raw_eval_iter = None + else: + max_logging.log("Initializing Data Iterators via MaxText pipeline...") + raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh) - teacher_model.eval() student_model.train() - model_bundle = ModelBundle(teacher_model, student_model) # 6. 
Initialize Trainer @@ -522,18 +536,35 @@ def labels_fn(targets, targets_segmentation=None, **kwargs): raw_train_iter = _setup_and_restore_input_pipeline(trainer, raw_train_iter, student_config, train_config) # 8. Configure Input Mapping - trainer = trainer.with_gen_model_input_fn( - lambda batch: { - "input_tokens": batch.input_tokens, - "positions": batch.positions, - "attention_mask": batch.input_mask, - "decoder_segment_ids": batch.decoder_segment_ids, - "targets": batch.targets, # Passed to strategy (labels_fn) - "targets_position": batch.targets_position, # Passed to strategy (labels_fn) - "targets_segmentation": batch.targets_segmentation, # Passed to strategy (labels_fn) - "cache": None, - } - ) + def custom_gen_model_input_fn(batch): + inputs_dict = { + "input_tokens": batch.input_tokens, + "positions": batch.positions, + "attention_mask": batch.input_mask, + "decoder_segment_ids": batch.decoder_segment_ids, + "targets": batch.targets, + "targets_position": batch.targets_position, + "targets_segmentation": batch.targets_segmentation, + "cache": None, + } + + # If we are in online mode then we exit + if getattr(batch, "top_k_logits", None) is None: + return inputs_dict + + # Scatter the offline arrays into a dense tensor of -10000s + dense_shape = batch.input_tokens.shape + (student_config.vocab_size,) + dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32) + dense_logits = jnp.put_along_axis(dense_logits, batch.top_k_indices, batch.top_k_logits, axis=-1, inplace=False) + + # Inject it as teacher_output so the trainer skips the teacher forward pass + inputs_dict["teacher_output"] = distillation_utils.DistillationForwardOutput( + logits=dense_logits, out_projection_activations=None + ) + + return inputs_dict + + trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn) # 9. 
Create Iterator Wrappers (Use Utils) train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter) @@ -585,9 +616,6 @@ def main(argv: Sequence[str]) -> None: Parses configuration, isolates Student and Teacher overrides, and triggers the training loop. - - Args: - argv: List of command-line arguments. Expects [script_name, config_file, ...]. """ # 1. Parse Global Config to extract Overrides global_config = pyconfig.initialize(argv) @@ -597,12 +625,14 @@ def main(argv: Sequence[str]) -> None: student_overrides = global_config.student_overrides student_config = pyconfig.initialize(argv, **student_overrides) + is_offline = bool(global_config.offline_data_dir) + # 3. Initialize TEACHER Config # We isolate the Teacher from Student CLI arguments (like pruning params). teacher_overrides = global_config.teacher_overrides # Ensure load_parameters_path is set in overrides - if not teacher_overrides.get("load_parameters_path"): + if not is_offline and not teacher_overrides.get("load_parameters_path"): raise ValueError( "Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' " "in your config or arguments." @@ -614,7 +644,7 @@ def main(argv: Sequence[str]) -> None: teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides) # 4. 
Run Training - train_distill(student_config, teacher_config) + train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir) if __name__ == "__main__": diff --git a/tests/unit/train_distill_test.py b/tests/unit/train_distill_test.py index 6e84a914af..dc27d83daa 100644 --- a/tests/unit/train_distill_test.py +++ b/tests/unit/train_distill_test.py @@ -604,6 +604,162 @@ def test_post_process_train_step(self): values_list = mock_buffer.additional_metrics["distill/kl_div"][0] self.assertEqual(values_list[0], 0.5) + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.distillation_utils.OfflineArrayRecordIterator") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.input_pipeline_interface.create_data_iterator") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.get_maxtext_model") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.maxtext_utils.create_device_mesh") + @mock.patch("maxtext.configs.pyconfig.initialize") + def test_main_offline_mode_skips_teacher_loading( + self, + mock_pyconfig_init, + mock_create_mesh, + mock_build_tokenizer, + mock_get_model, + mock_create_iterator, + mock_trainer_cls, + mock_offline_iter_cls, + ): + """Verifies offline mode (offline_data_dir is set) skips teacher model loading.""" + # 1. 
Configs + mock_global = mock.Mock() + mock_global.student_overrides = {} + mock_global.teacher_overrides = {} # No checkpoint needed + mock_global.offline_data_dir = "gs://bucket/data" # Triggers offline mode + + mock_student_cfg = mock.Mock() + mock_student_cfg.vocab_size = 32000 + mock_student_cfg.mesh_axes = ("data",) + mock_student_cfg.dataset_type = "grain" + + # Add dummy numbers for optimizer math + mock_student_cfg.learning_rate = 1e-4 + mock_student_cfg.warmup_steps_fraction = 0.1 + mock_student_cfg.learning_rate_final_fraction = 0.1 + mock_student_cfg.steps = 100 + mock_student_cfg.checkpoint_period = 10 + mock_student_cfg.gradient_clipping_threshold = 0.0 + mock_student_cfg.eval_interval = -1 + + # Add dummy numbers for strategy math/logic + mock_student_cfg.distill_temperature = 1.0 + mock_student_cfg.distill_alpha = 0.5 + mock_student_cfg.distill_beta = 0.0 + mock_student_cfg.distill_layer_indices = None + mock_student_cfg.use_sft = False + mock_student_cfg.enable_dropout = False + + # Add dummy variables for Checkpointer and Logger + mock_student_cfg.max_num_checkpoints_to_keep = 1 + mock_student_cfg.async_checkpointing = False + mock_student_cfg.profiler = "none" + mock_student_cfg.tensorboard_dir = "" + mock_student_cfg.checkpoint_dir = "" + mock_student_cfg.log_period = 10 + mock_student_cfg.save_checkpoint_on_completion = False + mock_student_cfg.logical_axis_rules = [] + + mock_teacher_cfg = mock.Mock() + mock_teacher_cfg.vocab_size = 32000 + mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg] + + # 2. Model Loading + mock_student_model = mock.Mock() + mock_get_model.return_value = mock_student_model + + # 3. Tokenizer & Data Iterator + mock_build_tokenizer.return_value = mock.Mock(pad_id=0) + mock_create_iterator.return_value = (None, None) + + train_distill.main(["train_distill.py", "config.yml"]) + + # 4. 
Assertions + # checking to ensure get_maxtext_model is only called once for student and not for teacher + mock_get_model.assert_called_once_with(mock_student_cfg, mock.ANY) + + trainer_init_kwargs = mock_trainer_cls.call_args.kwargs + model_bundle = trainer_init_kwargs["model"] + # check that student model is set but teacher model is None since offline mode should skip loading teacher + self.assertIs(model_bundle.student_model, mock_student_model) + self.assertIsNone(model_bundle.teacher_model) + + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.input_pipeline_interface.create_data_iterator") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.get_maxtext_model") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.maxtext_utils.create_device_mesh") + @mock.patch("maxtext.configs.pyconfig.initialize") + def test_main_online_mode_loads_teacher( + self, + mock_pyconfig_init, + mock_create_mesh, + mock_build_tokenizer, + mock_get_model, + mock_create_iterator, + mock_trainer_cls, + ): + """Verifies online mode (offline_data_dir is None) loads both student and teacher models.""" + mock_global = mock.Mock() + mock_global.student_overrides = {} + mock_global.teacher_overrides = {"load_parameters_path": "gs://ckpt"} + mock_global.offline_data_dir = None # Triggers online mode + + mock_student_cfg = mock.Mock() + mock_student_cfg.vocab_size = 32000 + mock_student_cfg.mesh_axes = ("data",) + mock_student_cfg.dataset_type = "grain" + + # Add dummy numbers for optimizer math + mock_student_cfg.learning_rate = 1e-4 + mock_student_cfg.warmup_steps_fraction = 0.1 + mock_student_cfg.learning_rate_final_fraction = 0.1 + mock_student_cfg.steps = 100 + mock_student_cfg.checkpoint_period = 10 + 
mock_student_cfg.gradient_clipping_threshold = 0.0 + mock_student_cfg.eval_interval = -1 + + # Add dummy numbers for strategy math/logic + mock_student_cfg.distill_temperature = 1.0 + mock_student_cfg.distill_alpha = 0.5 + mock_student_cfg.distill_beta = 0.0 + mock_student_cfg.distill_layer_indices = None + mock_student_cfg.use_sft = False + mock_student_cfg.enable_dropout = False + + # Add dummy variables for Checkpointer and Logger + mock_student_cfg.max_num_checkpoints_to_keep = 1 + mock_student_cfg.async_checkpointing = False + mock_student_cfg.profiler = "none" + mock_student_cfg.tensorboard_dir = "" + mock_student_cfg.checkpoint_dir = "" + mock_student_cfg.log_period = 10 + mock_student_cfg.save_checkpoint_on_completion = False + mock_student_cfg.logical_axis_rules = [] + + mock_teacher_cfg = mock.Mock() + mock_teacher_cfg.vocab_size = 32000 + mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg] + + mock_student_model = mock.Mock() + mock_teacher_model = mock.Mock() + mock_get_model.side_effect = [mock_student_model, mock_teacher_model] + + mock_build_tokenizer.return_value = mock.Mock(pad_id=0) + mock_create_iterator.return_value = (mock.Mock(), mock.Mock()) + + train_distill.main(["train_distill.py", "config.yml"]) + + # checking to ensure get_maxtext_model is called for both student and teacher since online mode should load both + self.assertEqual(mock_get_model.call_count, 2) + mock_get_model.assert_any_call(mock_student_cfg, mock.ANY) + mock_get_model.assert_any_call(mock_teacher_cfg, mock.ANY) + + trainer_init_kwargs = mock_trainer_cls.call_args.kwargs + model_bundle = trainer_init_kwargs["model"] + # check that both student and teacher models are set since online mode should load both + self.assertIs(model_bundle.student_model, mock_student_model) + self.assertIs(model_bundle.teacher_model, mock_teacher_model) def test_gradient_accumulation_requires_k_passes_for_update(self): """Verifies that weights only update after k 
distinct forward passes."""