diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 4bdd7f99c8..85eb045bfe 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -141,8 +141,8 @@ def model_forward_fn( decoder_positions=positions, decoder_segment_ids=decoder_segment_ids, enable_dropout=config.enable_dropout, - decoder_target_tokens=kwargs.get("targets", None), - decoder_target_mask=kwargs.get("targets_segmentation", None), + decoder_target_tokens=kwargs.get("decoder_target_tokens", None), + decoder_target_mask=kwargs.get("decoder_target_mask", None), ) out_projection_activations = None if config.distill_beta > 0.0: @@ -224,6 +224,8 @@ def loss_wrapper(student, teacher, batch): positions=batch["positions"], attention_mask=batch.get("attention_mask"), decoder_segment_ids=batch.get("decoder_segment_ids"), + decoder_target_tokens=batch.get("targets", None), + decoder_target_mask=batch.get("targets_segmentation", None), cache=None, ) @@ -235,9 +237,12 @@ def loss_wrapper(student, teacher, batch): positions=batch["positions"], attention_mask=batch.get("attention_mask"), decoder_segment_ids=batch.get("decoder_segment_ids"), + decoder_target_tokens=batch.get("targets", None), + decoder_target_mask=batch.get("targets_segmentation", None), cache=None, ) - labels = self.strategy.labels_fn(batch["targets"]) + # we should apply a mask for labels to disable segment-separator tokens + labels = self.strategy.labels_fn(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None)) return self.strategy.compute_loss(student_output, teacher_output, labels) # Because student is the 0th argument, argnums=0 guarantees diff --git a/tests/unit/train_distill_test.py b/tests/unit/train_distill_test.py index 72c2ba2e86..6e84a914af 100644 --- a/tests/unit/train_distill_test.py +++ b/tests/unit/train_distill_test.py @@ -183,6 +183,8 @@ def test_train_step_skips_teacher_forward_when_output_present(self, mock_value_a positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], decoder_segment_ids=mock_batch["decoder_segment_ids"], + decoder_target_tokens=mock_batch.get("targets", None), + decoder_target_mask=mock_batch.get("targets_segmentation", None), cache=None, ) @@ -228,7 +230,9 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], decoder_segment_ids=mock_batch["decoder_segment_ids"], + decoder_target_tokens=mock_batch.get("targets", None), cache=None, + decoder_target_mask=None, ) trainer.strategy.student_forward_fn.assert_called_once_with( @@ -237,11 +241,13 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], decoder_segment_ids=mock_batch["decoder_segment_ids"], + decoder_target_tokens=mock_batch.get("targets", None), cache=None, + decoder_target_mask=None, ) # Verify loss computation and optimizer update - trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"]) + trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"], targets_segmentation=None) trainer.strategy.compute_loss.assert_called_once() optimizer.update.assert_called_once_with(student_model, mock_grads) @@ -249,6 +255,66 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a self.assertEqual(loss, mock_loss) self.assertEqual(aux, mock_aux) + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") + def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map): + """Verifies strategy callbacks receive decoder_target_tokens and decoder_target_mask.""" + # 1. Initialize Trainer + # pylint: disable=no-value-for-parameter + trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) + trainer.strategy = mock.Mock() + + # 2. Setup Batch WITH targets_segmentation + mock_targets_segmentation = jnp.array([[1, 1, 0]]) + mock_batch = { + "input_tokens": mock.Mock(), + "positions": mock.Mock(), + "attention_mask": mock.Mock(), + "decoder_segment_ids": mock.Mock(), + "targets": mock.Mock(), + "targets_segmentation": mock_targets_segmentation, + } + trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch) + + # 3. Setup Models & Inputs + teacher_model, student_model = mock.Mock(), mock.Mock() + model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) + optimizer, inputs = mock.Mock(), mock.Mock() + + # 4. Configure mocked nnx.value_and_grad + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), mock.Mock()), mock.Mock())) + mock_value_and_grad.return_value = mock_grad_fn + + # 5. Execute outer function & trigger inner loss_wrapper + trainer._train_step(model_bundle, optimizer, inputs) + loss_wrapper = mock_value_and_grad.call_args[0][0] + loss_wrapper(student_model, teacher_model, mock_batch) + + # 6. Assertions + trainer.strategy.labels_fn.assert_called_once_with( + mock_batch["targets"], targets_segmentation=mock_targets_segmentation + ) + trainer.strategy.student_forward_fn.assert_called_once_with( + model=student_model, + input_tokens=mock_batch["input_tokens"], + positions=mock_batch["positions"], + attention_mask=mock_batch["attention_mask"], + decoder_segment_ids=mock_batch["decoder_segment_ids"], + decoder_target_tokens=mock_batch["targets"], + decoder_target_mask=mock_targets_segmentation, + cache=None, + ) + trainer.strategy.teacher_forward_fn.assert_called_once_with( + model=teacher_model, + input_tokens=mock_batch["input_tokens"], + positions=mock_batch["positions"], + attention_mask=mock_batch["attention_mask"], + decoder_segment_ids=mock_batch["decoder_segment_ids"], + decoder_target_tokens=mock_batch["targets"], + decoder_target_mask=mock_targets_segmentation, + cache=None, + ) + def test_optimizer_factory(self): """Verifies the optimizer factory injects hyperparams and handles configs.""" # Mock config