Implementation for soft offline distillation using saved top-k teacher logits#3382
Implementation for soft offline distillation using saved top-k teacher logits#3382ajkv-google wants to merge 5 commits intomainfrom
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
| def __init__(self, data_dir: str, epochs: int = 100): | ||
| # Check if the user passed a directory or a direct file path | ||
| if tf.io.gfile.isdir(data_dir): | ||
| self.filepath = os.path.join(data_dir, "teacher_top_k_global.array_record") |
There was a problem hiding this comment.
is it ok to hardcode this file as teacher_top_k_global.array_record?
There was a problem hiding this comment.
In the save_top_k_teacher logits file (from this PR), we are writing a single arrayrecord file from one host rather than having multiple hosts write their chunks of data. So, I just named the file as "teacher_top_k_global.arrayrecord". But, I believe not everyone running offline distillation will use the same file to save top-k teacher logits, so I will add this as a field to the config so that users can specify the filename of the saved top-k teacher logits to have it be dynamic.
|
|
||
| if __name__ == "__main__": | ||
| app.run(main) | ||
| parser = argparse.ArgumentParser() |
There was a problem hiding this comment.
I think these should go inside types.py to add them as part of the config.
There was a problem hiding this comment.
That makes sense, it would make the command less complex and make things more organized if it is in the config. I moved these to types.py and verified the training ran successfully after the change.
Description
This PR introduces an end-to-end offline distillation training pipeline. Previously, the distillation loop executed in an "online" mode, which required both the frozen Teacher model and the learning Student model to be loaded and executed simultaneously during training. This change allows the trainer to load pre-computed, top-K Teacher logits from .array_record files, which allows us to bybass the forward pass for the teacher model during the training loop.
Tests
Tested this code change by running the following command:
python3 src/maxtext/trainers/post_train/distillation/train_distill.py src/maxtext/configs/post_train/distillation.yml steps=100 tokenizer_path="/mnt/ajkv/disks/codebase/maxtext/src/maxtext/assets/tokenizers/tokenizer_llama3.tiktoken" --offline_distillation --offline_data_dir="/mnt/ajkv/disks/teacher_logits_output/teacher_top_k_global.array_record"Truncated output showing the successful run: https://paste.googleplex.com/6342987127848960#l=8.
Verified that the training happened sucessfully and finished the distillation run.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.