Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/maxtext/trainers/post_train/distillation/train_distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def model_forward_fn(
decoder_positions=positions,
decoder_segment_ids=decoder_segment_ids,
enable_dropout=config.enable_dropout,
decoder_target_tokens=kwargs.get("targets", None),
decoder_target_mask=kwargs.get("targets_segmentation", None),
decoder_target_tokens=kwargs.get("decoder_target_tokens", None),
decoder_target_mask=kwargs.get("decoder_target_mask", None),
)
out_projection_activations = None
if config.distill_beta > 0.0:
Expand Down Expand Up @@ -224,6 +224,8 @@ def loss_wrapper(student, teacher, batch):
positions=batch["positions"],
attention_mask=batch.get("attention_mask"),
decoder_segment_ids=batch.get("decoder_segment_ids"),
decoder_target_tokens=batch.get("targets", None),
decoder_target_mask=batch.get("targets_segmentation", None),
cache=None,
)

Expand All @@ -235,9 +237,12 @@ def loss_wrapper(student, teacher, batch):
positions=batch["positions"],
attention_mask=batch.get("attention_mask"),
decoder_segment_ids=batch.get("decoder_segment_ids"),
decoder_target_tokens=batch.get("targets", None),
decoder_target_mask=batch.get("targets_segmentation", None),
cache=None,
)
labels = self.strategy.labels_fn(batch["targets"])
# Apply a mask to the labels to disable segment-separator tokens.
labels = self.strategy.labels_fn(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
return self.strategy.compute_loss(student_output, teacher_output, labels)

# Because student is the 0th argument, argnums=0 guarantees
Expand Down
68 changes: 67 additions & 1 deletion tests/unit/train_distill_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def test_train_step_skips_teacher_forward_when_output_present(self, mock_value_a
positions=mock_batch["positions"],
attention_mask=mock_batch["attention_mask"],
decoder_segment_ids=mock_batch["decoder_segment_ids"],
decoder_target_tokens=mock_batch.get("targets", None),
decoder_target_mask=mock_batch.get("targets_segmentation", None),
cache=None,
)

Expand Down Expand Up @@ -228,7 +230,9 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
positions=mock_batch["positions"],
attention_mask=mock_batch["attention_mask"],
decoder_segment_ids=mock_batch["decoder_segment_ids"],
decoder_target_tokens=mock_batch.get("targets", None),
cache=None,
decoder_target_mask=None,
)

trainer.strategy.student_forward_fn.assert_called_once_with(
Expand All @@ -237,18 +241,80 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
positions=mock_batch["positions"],
attention_mask=mock_batch["attention_mask"],
decoder_segment_ids=mock_batch["decoder_segment_ids"],
decoder_target_tokens=mock_batch.get("targets", None),
cache=None,
decoder_target_mask=None,
)

# Verify loss computation and optimizer update
trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"])
trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
trainer.strategy.compute_loss.assert_called_once()
optimizer.update.assert_called_once_with(student_model, mock_grads)

# Verify the final returns match what grad_fn produced
self.assertEqual(loss, mock_loss)
self.assertEqual(aux, mock_aux)

@mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad")
def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map):
    """Verifies strategy callbacks receive decoder_target_tokens and decoder_target_mask."""
    # Build a trainer without running __init__ so no real config/model is needed.
    # pylint: disable=no-value-for-parameter
    trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
    trainer.strategy = mock.Mock()

    # A batch that DOES carry targets_segmentation, so the mask must propagate
    # through to both forward functions and the labels callback.
    segmentation = jnp.array([[1, 1, 0]])
    batch = {
        key: mock.Mock()
        for key in ("input_tokens", "positions", "attention_mask", "decoder_segment_ids", "targets")
    }
    batch["targets_segmentation"] = segmentation
    trainer.gen_model_input_fn = mock.Mock(return_value=batch)

    # Models, bundle, and the remaining train-step inputs.
    teacher, student = mock.Mock(), mock.Mock()
    bundle = train_distill.ModelBundle(teacher_model=teacher, student_model=student)
    optimizer, inputs = mock.Mock(), mock.Mock()

    # nnx.value_and_grad yields ((loss, aux), grads); stub it with mocks.
    grad_fn = mock.Mock(return_value=((mock.Mock(), mock.Mock()), mock.Mock()))
    mock_value_and_grad.return_value = grad_fn

    # Run the step, then capture and invoke the inner loss_wrapper directly.
    trainer._train_step(bundle, optimizer, inputs)
    loss_wrapper = mock_value_and_grad.call_args[0][0]
    loss_wrapper(student, teacher, batch)

    # The segmentation array must reach labels_fn as a keyword argument.
    trainer.strategy.labels_fn.assert_called_once_with(
        batch["targets"], targets_segmentation=segmentation
    )
    # Both forward functions receive identical kwargs apart from the model.
    expected_forward_kwargs = dict(
        input_tokens=batch["input_tokens"],
        positions=batch["positions"],
        attention_mask=batch["attention_mask"],
        decoder_segment_ids=batch["decoder_segment_ids"],
        decoder_target_tokens=batch["targets"],
        decoder_target_mask=segmentation,
        cache=None,
    )
    trainer.strategy.student_forward_fn.assert_called_once_with(model=student, **expected_forward_kwargs)
    trainer.strategy.teacher_forward_fn.assert_called_once_with(model=teacher, **expected_forward_kwargs)

def test_optimizer_factory(self):
"""Verifies the optimizer factory injects hyperparams and handles configs."""
# Mock config
Expand Down
Loading