From 0d80d3d2178376ab9af5afe2389cbdf3343569e3 Mon Sep 17 00:00:00 2001 From: vlad-karp Date: Wed, 11 Mar 2026 23:36:32 +0000 Subject: [PATCH 1/3] fix sft with after a recent distillation train code refactor --- .../post_train/distillation/train_distill.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 4bdd7f99c8..f3ae92d51e 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -131,7 +131,8 @@ def create_forward_fn(config: pyconfig.HyperParameters) -> Callable[..., distill """ def model_forward_fn( - model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs + model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, + **kwargs ) -> distillation_utils.DistillationForwardOutput: """Forward pass wrapper adapted for raw MaxText models.""" del attention_mask # Unused @@ -141,8 +142,8 @@ def model_forward_fn( decoder_positions=positions, decoder_segment_ids=decoder_segment_ids, enable_dropout=config.enable_dropout, - decoder_target_tokens=kwargs.get("targets", None), - decoder_target_mask=kwargs.get("targets_segmentation", None), + decoder_target_tokens=kwargs.get("decoder_target_tokens", None), + decoder_target_mask=kwargs.get("decoder_target_mask", None), ) out_projection_activations = None if config.distill_beta > 0.0: @@ -214,7 +215,7 @@ def _train_step(self, model, optimizer, inputs): batch = self.gen_model_input_fn(inputs) - def loss_wrapper(student, teacher, batch): + def loss_wrapper(student, teacher, batch): if "teacher_output" in batch: teacher_output = batch["teacher_output"] else: @@ -224,6 +225,8 @@ def loss_wrapper(student, teacher, batch): positions=batch["positions"], attention_mask=batch.get("attention_mask"), decoder_segment_ids=batch.get("decoder_segment_ids"), + decoder_target_tokens=batch.get("targets", None), + decoder_target_mask=batch.get("targets_segmentation", None), cache=None, ) @@ -235,9 +238,12 @@ def loss_wrapper(student, teacher, batch): positions=batch["positions"], attention_mask=batch.get("attention_mask"), decoder_segment_ids=batch.get("decoder_segment_ids"), + decoder_target_tokens=batch.get("targets", None), + decoder_target_mask=batch.get("targets_segmentation", None), cache=None, ) - labels = self.strategy.labels_fn(batch["targets"]) + # we should apply a mask for labels to disable segment-separator tokens + labels = self.strategy.labels_fn(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None)) return self.strategy.compute_loss(student_output, teacher_output, labels) # Because student is the 0th argument, argnums=0 guarantees @@ -434,7 +440,7 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco # 3. Define Distillation Strategy def labels_fn(targets, targets_segmentation=None, **kwargs): - """Converts integer targets to masked one-hot vectors for hard label loss.""" + """Converts integer targets to masked one-hot vectors for hard label loss.""" del kwargs # Unused one_hot = jax.nn.one_hot(targets, student_config.vocab_size) mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None] From b6bea94d5aca38fb1d89a0b4d680c3e4fb57d772 Mon Sep 17 00:00:00 2001 From: vlad-karp Date: Thu, 12 Mar 2026 17:33:26 +0000 Subject: [PATCH 2/3] added a unit test + format --- .../post_train/distillation/train_distill.py | 7 +-- tests/unit/train_distill_test.py | 60 +++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index f3ae92d51e..85eb045bfe 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -131,8 +131,7 @@ def create_forward_fn(config: pyconfig.HyperParameters) -> Callable[..., distill """ def model_forward_fn( - model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, - **kwargs + model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs ) -> distillation_utils.DistillationForwardOutput: """Forward pass wrapper adapted for raw MaxText models.""" del attention_mask # Unused @@ -215,7 +214,7 @@ def _train_step(self, model, optimizer, inputs): batch = self.gen_model_input_fn(inputs) - def loss_wrapper(student, teacher, batch): + def loss_wrapper(student, teacher, batch): if "teacher_output" in batch: teacher_output = batch["teacher_output"] else: @@ -440,7 +439,7 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco # 3. Define Distillation Strategy def labels_fn(targets, targets_segmentation=None, **kwargs): - """Converts integer targets to masked one-hot vectors for hard label loss.""" + """Converts integer targets to masked one-hot vectors for hard label loss.""" del kwargs # Unused one_hot = jax.nn.one_hot(targets, student_config.vocab_size) mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None] diff --git a/tests/unit/train_distill_test.py b/tests/unit/train_distill_test.py index 72c2ba2e86..8c8c9bb437 100644 --- a/tests/unit/train_distill_test.py +++ b/tests/unit/train_distill_test.py @@ -249,6 +249,66 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a self.assertEqual(loss, mock_loss) self.assertEqual(aux, mock_aux) + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") + def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map): + """Verifies strategy callbacks receive decoder_target_tokens and decoder_target_mask.""" + # 1. Initialize Trainer + # pylint: disable=no-value-for-parameter + trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) + trainer.strategy = mock.Mock() + + # 2. Setup Batch WITH targets_segmentation + mock_targets_segmentation = jnp.array([[1, 1, 0]]) + mock_batch = { + "input_tokens": mock.Mock(), + "positions": mock.Mock(), + "attention_mask": mock.Mock(), + "decoder_segment_ids": mock.Mock(), + "targets": mock.Mock(), + "targets_segmentation": mock_targets_segmentation, + } + trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch) + + # 3. Setup Models & Inputs + teacher_model, student_model = mock.Mock(), mock.Mock() + model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) + optimizer, inputs = mock.Mock(), mock.Mock() + + # 4. Configure mocked nnx.value_and_grad + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), mock.Mock()), mock.Mock())) + mock_value_and_grad.return_value = mock_grad_fn + + # 5. Execute outer function & trigger inner loss_wrapper + trainer._train_step(model_bundle, optimizer, inputs) + loss_wrapper = mock_value_and_grad.call_args[0][0] + loss_wrapper(student_model, teacher_model, mock_batch) + + # 6. Assertions + trainer.strategy.labels_fn.assert_called_once_with( + mock_batch["targets"], targets_segmentation=mock_targets_segmentation + ) + trainer.strategy.student_forward_fn.assert_called_once_with( + model=student_model, + input_tokens=mock_batch["input_tokens"], + positions=mock_batch["positions"], + attention_mask=mock_batch["attention_mask"], + decoder_segment_ids=mock_batch["decoder_segment_ids"], + decoder_target_tokens=mock_batch["targets"], + decoder_target_mask=mock_targets_segmentation, + cache=None, + ) + trainer.strategy.teacher_forward_fn.assert_called_once_with( + model=teacher_model, + input_tokens=mock_batch["input_tokens"], + positions=mock_batch["positions"], + attention_mask=mock_batch["attention_mask"], + decoder_segment_ids=mock_batch["decoder_segment_ids"], + decoder_target_tokens=mock_batch["targets"], + decoder_target_mask=mock_targets_segmentation, + cache=None, + ) + def test_optimizer_factory(self): """Verifies the optimizer factory injects hyperparams and handles configs.""" # Mock config From 5d3683587aa4df7f9a1a6f65963e676a3c4e4c75 Mon Sep 17 00:00:00 2001 From: vlad-karp Date: Thu, 12 Mar 2026 18:17:22 +0000 Subject: [PATCH 3/3] fixed related test --- tests/unit/train_distill_test.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/unit/train_distill_test.py b/tests/unit/train_distill_test.py index 8c8c9bb437..6e84a914af 100644 --- a/tests/unit/train_distill_test.py +++ b/tests/unit/train_distill_test.py @@ -183,6 +183,8 @@ def test_train_step_skips_teacher_forward_when_output_present(self, mock_value_a positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], decoder_segment_ids=mock_batch["decoder_segment_ids"], + decoder_target_tokens=mock_batch.get("targets", None), + decoder_target_mask=mock_batch.get("targets_segmentation", None), cache=None, ) @@ -228,7 +230,9 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], decoder_segment_ids=mock_batch["decoder_segment_ids"], + decoder_target_tokens=mock_batch.get("targets", None), cache=None, + decoder_target_mask=None, ) trainer.strategy.student_forward_fn.assert_called_once_with( @@ -237,11 +241,13 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], decoder_segment_ids=mock_batch["decoder_segment_ids"], + decoder_target_tokens=mock_batch.get("targets", None), cache=None, + decoder_target_mask=None, ) # Verify loss computation and optimizer update - trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"]) + trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"], targets_segmentation=None) trainer.strategy.compute_loss.assert_called_once() optimizer.update.assert_called_once_with(student_model, mock_grads)