-
Notifications
You must be signed in to change notification settings - Fork 483
Implementation for soft offline distillation using saved top-k teacher logits #3382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ajkv-google
wants to merge
5
commits into
main
Choose a base branch
from
ajkv/offline-distillation-soft
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+116
−27
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
dc4d964
Added train script for offline distillation training
ajkv-google f8bb608
updated code formatting and style
ajkv-google 02931fb
updated iterator to ensure weight updates when training student model
ajkv-google 6bb64d0
moved cmd args into the distillation config to make command easier to…
ajkv-google 1fc0699
removed the need for hardcoding arrayrecord file and read directly fr…
ajkv-google File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,7 +32,6 @@ | |
| 3. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose | ||
| a standard interface (call signature) that the Tunix `DistillationTrainer` expects. | ||
| """ | ||
|
|
||
| from typing import Sequence, Callable | ||
| from absl import app | ||
| from flax import nnx | ||
|
|
@@ -292,6 +291,8 @@ def _prepare_inputs( | |
| targets=input_data.targets, | ||
| targets_position=input_data.targets_position, | ||
| targets_segmentation=input_data.targets_segmentation, | ||
| top_k_logits=input_data.top_k_logits, | ||
| top_k_indices=input_data.top_k_indices, | ||
| ) | ||
|
|
||
| def _post_process_train_step(self, aux: dict[str, jax.Array]) -> None: | ||
|
|
@@ -390,7 +391,12 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh) | |
| # ----------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters) -> None: | ||
| def train_distill( | ||
| student_config: pyconfig.HyperParameters, | ||
| teacher_config: pyconfig.HyperParameters, | ||
| is_offline: bool = False, | ||
| offline_data_dir: str | None = None, | ||
| ) -> None: | ||
| """Main distillation training loop. | ||
| Orchestrates the loading of both student and teacher models, configures the | ||
|
|
@@ -426,9 +432,15 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco | |
| _log_config_details(student_config, "Student") | ||
| student_model = get_maxtext_model(student_config, mesh) | ||
|
|
||
| max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...") | ||
| _log_config_details(teacher_config, "Teacher") | ||
| teacher_model = get_maxtext_model(teacher_config, mesh) | ||
| # Skip teacher model loading if offline | ||
| if is_offline: | ||
| max_logging.log("Offline Distillation: Skipping Teacher Model loading.") | ||
| teacher_model = None | ||
| else: | ||
| max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...") | ||
| _log_config_details(teacher_config, "Teacher") | ||
| teacher_model = get_maxtext_model(teacher_config, mesh) | ||
| teacher_model.eval() | ||
|
|
||
| # 3. Define Distillation Strategy | ||
| def labels_fn(targets, targets_segmentation=None, **kwargs): | ||
|
|
@@ -489,13 +501,15 @@ def labels_fn(targets, targets_segmentation=None, **kwargs): | |
| ) | ||
|
|
||
| # 5. Data Iterators (Init BEFORE Trainer) | ||
| # We use MaxText's native create_data_iterator which creates both train and eval iterators | ||
| max_logging.log("Initializing Data Iterators via MaxText pipeline...") | ||
| raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh) | ||
| if is_offline: | ||
| max_logging.log(f"Loading Offline Dataset from {offline_data_dir}...") | ||
| raw_train_iter = distillation_utils.OfflineArrayRecordIterator(offline_data_dir) | ||
| raw_eval_iter = None | ||
| else: | ||
| max_logging.log("Initializing Data Iterators via MaxText pipeline...") | ||
| raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh) | ||
|
|
||
| teacher_model.eval() | ||
| student_model.train() | ||
|
|
||
| model_bundle = ModelBundle(teacher_model, student_model) | ||
|
|
||
| # 6. Initialize Trainer | ||
|
|
@@ -513,18 +527,35 @@ def labels_fn(targets, targets_segmentation=None, **kwargs): | |
| raw_train_iter = _setup_and_restore_input_pipeline(trainer, raw_train_iter, student_config, train_config) | ||
|
|
||
| # 8. Configure Input Mapping | ||
| trainer = trainer.with_gen_model_input_fn( | ||
| lambda batch: { | ||
| "input_tokens": batch.input_tokens, | ||
| "positions": batch.positions, | ||
| "attention_mask": batch.input_mask, | ||
| "decoder_segment_ids": batch.decoder_segment_ids, | ||
| "targets": batch.targets, # Passed to strategy (labels_fn) | ||
| "targets_position": batch.targets_position, # Passed to strategy (labels_fn) | ||
| "targets_segmentation": batch.targets_segmentation, # Passed to strategy (labels_fn) | ||
| "cache": None, | ||
| } | ||
| ) | ||
| def custom_gen_model_input_fn(batch): | ||
| inputs_dict = { | ||
| "input_tokens": batch.input_tokens, | ||
| "positions": batch.positions, | ||
| "attention_mask": batch.input_mask, | ||
| "decoder_segment_ids": batch.decoder_segment_ids, | ||
| "targets": batch.targets, | ||
| "targets_position": batch.targets_position, | ||
| "targets_segmentation": batch.targets_segmentation, | ||
| "cache": None, | ||
| } | ||
|
|
||
| # If we are in online mode then we exit | ||
| if getattr(batch, "top_k_logits", None) is None: | ||
| return inputs_dict | ||
|
|
||
| # Scatter the offline arrays into a dense tensor of -10000s | ||
| dense_shape = batch.input_tokens.shape + (student_config.vocab_size,) | ||
| dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32) | ||
| dense_logits = jnp.put_along_axis(dense_logits, batch.top_k_indices, batch.top_k_logits, axis=-1, inplace=False) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. why inplace=False? |
||
|
|
||
| # Inject it as teacher_output so the trainer skips the teacher forward pass | ||
| inputs_dict["teacher_output"] = distillation_utils.DistillationForwardOutput( | ||
| logits=dense_logits, out_projection_activations=None | ||
| ) | ||
|
|
||
| return inputs_dict | ||
|
|
||
| trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn) | ||
|
|
||
| # 9. Create Iterator Wrappers (Use Utils) | ||
| train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter) | ||
|
|
@@ -576,9 +607,6 @@ def main(argv: Sequence[str]) -> None: | |
| Parses configuration, isolates Student and Teacher overrides, and triggers the | ||
| training loop. | ||
| Args: | ||
| argv: List of command-line arguments. Expects [script_name, config_file, ...]. | ||
| """ | ||
| # 1. Parse Global Config to extract Overrides | ||
| global_config = pyconfig.initialize(argv) | ||
|
|
@@ -593,7 +621,7 @@ def main(argv: Sequence[str]) -> None: | |
| teacher_overrides = global_config.teacher_overrides | ||
|
|
||
| # Ensure load_parameters_path is set in overrides | ||
| if not teacher_overrides.get("load_parameters_path"): | ||
| if not global_config.offline_distillation and not teacher_overrides.get("load_parameters_path"): | ||
| raise ValueError( | ||
| "Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' " | ||
| "in your config or arguments." | ||
|
|
@@ -605,7 +633,7 @@ def main(argv: Sequence[str]) -> None: | |
| teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides) | ||
|
|
||
| # 4. Run Training | ||
| train_distill(student_config, teacher_config) | ||
| train_distill(student_config, teacher_config, global_config.offline_distillation, global_config.offline_data_dir) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks redundant:
if the offline_data_dir parameter is specified, that alone can serve as the signal to switch to offline processing.