From 0e85947a0593e8e12db3089fe473c939bd606d7e Mon Sep 17 00:00:00 2001 From: Gagik Amirkhanyan Date: Tue, 19 May 2026 01:57:11 +0000 Subject: [PATCH] Qwen3-30b learn to init --- QWEN3_LTI_SETUP.md | 223 ++++++++++++++++++ .../post_train/distillation_qwen3_30b_lti.yml | 104 ++++++++ src/maxtext/layers/learn_to_init_layer.py | 72 +++++- src/maxtext/models/qwen3.py | 2 + .../distillation/tools/derive_lti_copy_map.py | 128 ++++++++++ .../post_train/distillation/train_distill.py | 7 +- 6 files changed, 526 insertions(+), 10 deletions(-) create mode 100644 QWEN3_LTI_SETUP.md create mode 100644 src/maxtext/configs/post_train/distillation_qwen3_30b_lti.yml create mode 100644 src/maxtext/trainers/post_train/distillation/tools/derive_lti_copy_map.py diff --git a/QWEN3_LTI_SETUP.md b/QWEN3_LTI_SETUP.md new file mode 100644 index 0000000000..9398c4024d --- /dev/null +++ b/QWEN3_LTI_SETUP.md @@ -0,0 +1,223 @@ +# Qwen3-30B-A3B Learn-to-Init Distillation — Setup + +End-to-end steps to run Learn-to-Init (LTI) soft distillation from a +Qwen3-30B-A3B-base teacher to a custom student (half query/KV heads, doubled +head_dim) in MaxText. + +- **Teacher**: `qwen3-30b-a3b-base` (converted MaxText checkpoint) +- **Student**: custom variant — `learn_to_init_mode: True` +- **Recipe**: `src/maxtext/configs/post_train/distillation_qwen3_30b_lti.yml` +- **Hardware tested**: TPU v7x-8 (8 chips, 96 GB HBM/device) + +--- + +## 1. Environment + +Editable MaxText install plus TPU extras: + +```bash +python3 -m venv .venv +.venv/bin/python -m pip install --upgrade pip +.venv/bin/python -m pip install -e '.[tpu]' +.venv/bin/install_tpu_pre_train_extra_deps +``` + +Verify TPU visibility: + +```bash +.venv/bin/python -c "import jax; print(len(jax.devices()), jax.devices()[0])" +``` + +### Add tunix (required for the distillation trainer) + +The `[tpu]` extra does not include tunix. Distillation lives under +`maxtext.trainers.post_train.distillation`, which imports `tunix`. 
+Install +the pinned tunix sha matching MaxText's canonical XPK image +(`run_distill_xpk.sh:prep_image`), then re-pin libtpu/jax to keep +image/libtpu/jax compatibility: + +```bash +.venv/bin/python -m pip install --no-cache-dir --force-reinstall \ + "git+https://github.com/google/tunix@<pinned-tunix-sha>" + +.venv/bin/python -m pip install --no-cache-dir --force-reinstall --no-deps \ + jax==0.10.0 jaxlib==0.10.0 libtpu==0.0.39 +``` + +> The `tpu-post-train` extra pulls in vLLM and tpu-inference (large) and +> downgrades `flax`/`optax`. The two commands above install only what the +> distillation trainer needs. + +--- + +## 2. Convert the teacher checkpoint + +Convert the HF Qwen3-30B-A3B-base weights to MaxText format using the +unified `to_maxtext` script: + +```bash +python -m maxtext.checkpoint_conversion.to_maxtext \ + src/maxtext/configs/base.yml \ + model_name=qwen3-30b-a3b-base \ + load_parameters_path=<load_parameters_path> \ + base_output_directory=<base_output_directory> \ + hardware=cpu skip_jax_distributed_system=True scan_layers=True +``` + +The custom student is materialized at training time from this teacher +checkpoint via the recipe's `student_overrides` and the copy_map (Section 5); +no separate student conversion is needed. + +--- + +## 3. The distillation config + +`src/maxtext/configs/post_train/distillation_qwen3_30b_lti.yml` enables LTI +distillation. Key fields: + +### Custom student shape + +```yaml +student_overrides: + model_name: "qwen3-30b-a3b-base" + override_model_config: True + base_num_query_heads: 16 # teacher: 32 + head_dim: 256 # teacher: 128 + base_num_kv_heads: 2 # teacher: 4 +``` + +`rope_max_timescale` is inherited from the base model config (1e7), applied +at the new head_dim. The A,B bridges learn to adapt to whatever RoPE +frequencies are present. 
+ +### LTI mode + +```yaml +learn_to_init_mode: True +attn_module_name: "self_attention" +lti_use_general_linear_map: False # bilinear bridge; cheaper HBM +``` + +### YAML top-level requirement + +Batch-shape fields (`per_device_batch_size`, `gradient_accumulation_steps`, +`max_target_length`) must be set at the YAML top level — the trainer +rebuilds the teacher config from the YAML only, and a shape mismatch trips +a validator at startup. + +--- + +## 4. Smoke test + +End-to-end pipeline check (LTI swap + forward + loss + ckpt) at small +batch/seq: + +```bash +.venv/bin/python -m maxtext.trainers.post_train.distillation.train_distill \ + src/maxtext/configs/post_train/distillation_qwen3_30b_lti.yml \ + run_name=smoke-lti-$(date +%Y%m%d-%H%M%S) \ + base_output_directory=<base_output_directory> \ + max_target_length=2048 \ + steps=20 checkpoint_period=10 +``` + +With an empty `distill_weights_copy_map`, expect: + +``` +total_loss ~ 10 soft_loss = kl_div_T1 ~ 10 hard_loss ~ 12 +student_perplexity ~ 2.5e+05 teacher_perplexity ~ single-digit +moe_lb_loss ~ 0.02 +``` + +- `student_perplexity` near vocab size = near-random — expected because + non-attention weights are randomly initialized. +- `total_loss ≈ soft_loss` because `distill_alpha=1.0` (pure KD). + +Compile is ~10–15 min cold, ~1 min if `~/workspace/jax_cache` is warm. + +--- + +## 5. Derive `distill_weights_copy_map` + +`distill_weights_copy_map` tells `lti_utils.prepare_student_weights` which +teacher tensors to copy into the student at init. Without it, not only LTI's +internal bridges but every non-attention weight is randomly initialized, +and the loss starts far above the floor. 
+ +A helper script +`src/maxtext/trainers/post_train/distillation/tools/derive_lti_copy_map.py` +uses `nnx.eval_shape` (no weights materialized) to walk both abstract +graphs and emit a copy_map for every path whose shape exactly matches: + +```bash +.venv/bin/python -m maxtext.trainers.post_train.distillation.tools.derive_lti_copy_map \ + src/maxtext/configs/post_train/distillation_qwen3_30b_lti.yml \ + > /tmp/copy_map.yml +``` + +The skips the script reports are expected: attention q/k/v/out projections +(wrapped in `LearnToInitDense`) and q_norm/k_norm (their shape depends on +head_dim, which differs between teacher and student). Paste the `distill_weights_copy_map: ...` block into the YAML. + +Critically, the copy map must also copy teacher attention kernels into +the student's frozen `C` buffer: + +```yaml +distill_weights_copy_map: + "decoder/layers/self_attention/query/kernel": "decoder/layers/self_attention/query/C" + "decoder/layers/self_attention/key/kernel": "decoder/layers/self_attention/key/C" + "decoder/layers/self_attention/value/kernel": "decoder/layers/self_attention/value/C" + "decoder/layers/self_attention/out/kernel": "decoder/layers/self_attention/out/C" +``` + +Without this, `C` stays at `jnp.empty()` (≈zero) and the LTI bridges +compute `A · 0 · B = 0`, so attention output is zero. + +--- + +## 6. 
+Run + +After the copy map is in the YAML: + +```bash +.venv/bin/python -m maxtext.trainers.post_train.distillation.train_distill \ + src/maxtext/configs/post_train/distillation_qwen3_30b_lti.yml \ + run_name=qwen3-30b-lti-$(date +%Y%m%d-%H%M%S) \ + base_output_directory=<base_output_directory> +``` + +### Expected timings (TPU v7x-8) + +- Teacher checkpoint load: ~4 min +- Student init + LTI weight injection: ~5 s (after teacher is loaded) +- XLA compile: ~1 min warm cache; ~10–15 min cold +- Step time (per_device=1, grad_accum=1, seq=4096): **~1.8 s/step** +- 64 000 steps: **~32 h wall-clock** +- Checkpoint save to GCS: ~5 min per save (async — overlaps with training) + +### Memory expectations (TPU v7x-8, 96 GB HBM/device) + +Per-device rough budget (FSDP shards across 8 devices): + +- Teacher params (bf16, frozen): ~7.5 GB +- Student params (bf16): ~7.5 GB +- Adam optimizer state (fp32 m + fp32 nu, student only): ~30 GB +- **Static state per device: ~45 GB** +- Activations (seq 4096, batch 1, fp32 logits): ~20 GB peak +- **Total per device: ~65 GB / 96 GB cap → ~30 GB headroom** + +--- + +## 7. Outputs + +For each run, the trainer writes under `<base_output_directory>/<run_name>/`: + +- `distillation.yml` — verbatim copy of the source YAML +- `command.sh` — pasteable command with CLI overrides +- `checkpoints/<step>/` — Orbax model_params + iter +- `tensorboard/` — TensorBoard event files + +To resume a crashed run, re-launch with the same `run_name` and +`base_output_directory`; the trainer auto-restores from the latest +checkpoint. 
+ diff --git a/src/maxtext/configs/post_train/distillation_qwen3_30b_lti.yml b/src/maxtext/configs/post_train/distillation_qwen3_30b_lti.yml new file mode 100644 index 0000000000..8ae44a5d75 --- /dev/null +++ b/src/maxtext/configs/post_train/distillation_qwen3_30b_lti.yml @@ -0,0 +1,104 @@ +base_config: "base.yml" + +# --- Student Specifics --- +student_overrides: + model_name: "qwen3-30b-a3b-base" + override_model_config: True + base_num_query_heads: 16 + head_dim: 256 + base_num_kv_heads: 2 + +# --- Teacher Specifics --- +# learn_to_init_mode: False so the teacher keeps its original attention +# (LTI bridges only wrap the student). +teacher_overrides: + model_name: "qwen3-30b-a3b-base" + load_parameters_path: "" # path to the converted Qwen3-30B-A3B-base MaxText checkpoint + learn_to_init_mode: False + +# --- Distillation Loss --- +distill_alpha: 1.0 +distill_temperature: 1.0 +distill_beta: 0.0 + +# --- Learn-to-Init --- +learn_to_init_mode: True +attn_module_name: "self_attention" +lti_use_general_linear_map: False + +# Copy teacher non-attention weights directly; copy teacher attention kernels +# into the student's frozen C buffer (LTI bridges learn A,B around it). 
+distill_weights_copy_map: + "token_embedder/embedding": "token_embedder/embedding" + "decoder/decoder_norm/scale": "decoder/decoder_norm/scale" + "decoder/logits_dense/kernel": "decoder/logits_dense/kernel" + "decoder/layers/pre_self_attention_layer_norm/scale": "decoder/layers/pre_self_attention_layer_norm/scale" + "decoder/layers/post_self_attention_layer_norm/scale": "decoder/layers/post_self_attention_layer_norm/scale" + "decoder/layers/moe_block/gate/kernel": "decoder/layers/moe_block/gate/kernel" + "decoder/layers/moe_block/wi_0": "decoder/layers/moe_block/wi_0" + "decoder/layers/moe_block/wi_1": "decoder/layers/moe_block/wi_1" + "decoder/layers/moe_block/wo": "decoder/layers/moe_block/wo" + "decoder/layers/self_attention/query/kernel": "decoder/layers/self_attention/query/C" + "decoder/layers/self_attention/key/kernel": "decoder/layers/self_attention/key/C" + "decoder/layers/self_attention/value/kernel": "decoder/layers/self_attention/value/C" + "decoder/layers/self_attention/out/kernel": "decoder/layers/self_attention/out/C" + +# Freeze everything except LTI bridges A,B (~10M of ~30B params train). +student_params_to_update: + - "self_attention/(query|key|value|out)/(A|B)$" + +# --- MoE --- +# Qwen3-30B-A3B (128 experts, top-8) needs MoE LB loss > 0 to avoid router +# collapse. 0.001 matches HF router_aux_loss_coef. +load_balance_loss_weight: 0.001 + +# --- Dataset & Tokenizer --- +dataset_type: "grain" +grain_file_type: "arrayrecord" +grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord" +grain_worker_count: 16 +grain_ram_budget_mb: 4096 +grain_per_worker_buffer_size: 8 +grain_prefetch_buffer_size: 64 +num_epoch: 10 + +tokenizer_path: "src/maxtext/assets/tokenizers/qwen3-tokenizer" +tokenizer_type: "huggingface" + +# Batch-shape fields must be at YAML top level (trainer rebuilds teacher +# config from YAML only, ignoring CLI overrides). 
+max_target_length: 4096 + +# --- Training Loop --- +# Front-loaded schedule: warmup + cosine decay over the first 3200 steps, +# then constant min_lr. +steps: 64000 +learning_rate_schedule_steps: 3200 +checkpoint_period: 2000 +log_period: 10 +save_checkpoint_on_completion: True + +# --- Batch Size Strategy --- +# Global Batch Size = per_device_batch_size * num_devices * gradient_accumulation_steps +# per_device=1 keeps Adam state + teacher params in 96 GB HBM. +per_device_batch_size: 1 +gradient_accumulation_steps: 1 + +# --- Learning Rate Schedule --- +# 200-step warmup is sufficient for LTI-only training (~10M trainable params, +# smooth loss surface). +learning_rate: 3.0e-5 +learning_rate_final_fraction: 0.0333 # = 1e-6 / 3e-5 +warmup_steps_fraction: 0.0625 # = 200 / 3200 + +# --- Optimizer --- +adam_b1: 0.9 +adam_b2: 0.95 +adam_eps: 1.e-5 +adam_weight_decay: 0.01 +adamw_mask: [] # uniform WD across all params + +# --- Numerics --- +# fp32 logits: KL is precision-sensitive (vocab=151k) and fits at per_device=1. +z_loss_multiplier: 1.0e-5 +float32_logits: True diff --git a/src/maxtext/layers/learn_to_init_layer.py b/src/maxtext/layers/learn_to_init_layer.py index 2530c17336..840b8adc12 100644 --- a/src/maxtext/layers/learn_to_init_layer.py +++ b/src/maxtext/layers/learn_to_init_layer.py @@ -33,12 +33,62 @@ LTI_ORIGINAL_ATTENTION_PARAMS_NAME = "kernel" LTI_LAYER_PATH_PREFIXES = ("layers_", "dense_layers_", "moe_layers_") +# Fallback teacher config for the LTI augment hook. The `_flat_config` dict +# injection in train_distill is lost when HyperParameters is deep-copied during +# nnx lazy_init, so the trainer also stashes the teacher config here. 
+_TEACHER_CONFIG: Config | None = None -def apply_lti_modification(module: nnx.Module, module_name: str | None = None): + +def set_teacher_config_for_lti(teacher_config: Config | None) -> None: + """Stashes the teacher config for `apply_lti_modification` to read.""" + global _TEACHER_CONFIG + _TEACHER_CONFIG = teacher_config + + +# Small noise on the structured warm-start so gradients can flow into the +# off-structure entries. +LTI_BRIDGE_NOISE_SCALE = 0.01 + + +def _warmstart_head_bridge(shape, dtype, rng_key, noise_scale=LTI_BRIDGE_NOISE_SCALE): + """Initializes a head-axis bridge of shape `(teacher_heads, student_heads)`. + + Group-mean when teacher_heads is divisible by student_heads; otherwise + identity-prefix. A small Gaussian noise is added. """ - Applies Learn-To-Init structural modifications to an instantiated NNX module. - Checks the config to determine if LTI is enabled. + x, u = shape + if x >= u and x % u == 0: + g = x // u + base = jnp.zeros((x, u), dtype=dtype) + rows = jnp.arange(x) + cols = rows // g + base = base.at[rows, cols].set(jnp.asarray(1.0 / g, dtype=dtype)) + else: + base = jnp.eye(x, u, dtype=dtype) + noise = nnx.initializers.lecun_normal()(rng_key, shape, dtype) * noise_scale + return base + noise + + +def _warmstart_dim_bridge(shape, dtype, rng_key, noise_scale=LTI_BRIDGE_NOISE_SCALE): + """Initializes a head_dim-axis bridge of shape `(in_dim, out_dim)`. + + Identity-prefix on the overlap region, `lecun_normal` random elsewhere. + This keeps teacher signal flowing through matching components while + preserving gradient through the expansion/reduction region. 
""" + in_dim, out_dim = shape + random_full = nnx.initializers.lecun_normal()(rng_key, shape, dtype) + overlap = min(in_dim, out_dim) + identity = jnp.eye(in_dim, out_dim, dtype=dtype) + if out_dim > in_dim: + out = random_full.at[:, :overlap].set(identity[:, :overlap] + random_full[:, :overlap] * noise_scale) + else: + out = random_full.at[:overlap, :].set(identity[:overlap, :] + random_full[:overlap, :] * noise_scale) + return out + + +def apply_lti_modification(module: nnx.Module, module_name: str | None = None): + """Applies LTI structural modifications to an instantiated NNX module if enabled in the config.""" config = getattr(module, "config", None) if not config or not getattr(config, "learn_to_init_mode", False): @@ -84,7 +134,12 @@ def _customize_attention_modules(config: Config, attn_module_name: str, module: target_names = LTI_MODIFIED_ATTENTION_PARAM_NAMES use_general_linear_map = config.lti_use_general_linear_map - teacher_config = config.teacher_config + teacher_config = getattr(config, "teacher_config", None) or _TEACHER_CONFIG + if teacher_config is None: + raise ValueError( + "LTI: teacher_config not set. Call set_teacher_config_for_lti(...) " + "from the trainer before building the student model." + ) for name in target_names: child = getattr(attention_module, name, None) @@ -193,17 +248,18 @@ def __init__( x, y, b_t = self.C.value.shape assert b_s == b_t, f"Embedding dimension mismatch for output projection: {b_s} != {b_t}" if self.use_general_linear_map: + # General-map mode has no structured warm-start; prefer A/B mode for that. 
self.W = nnx.Param( nnx.initializers.lecun_normal()(rngs.params(), (x, y, u, v), self.weight_dtype), sharding=(None, None, None, None), ) else: self.A = nnx.Param( - nnx.initializers.lecun_normal()(rngs.params(), (x, u), self.weight_dtype), + _warmstart_head_bridge((x, u), self.weight_dtype, rngs.params()), sharding=(None, None), ) self.B = nnx.Param( - nnx.initializers.lecun_normal()(rngs.params(), (v, y), self.weight_dtype), + _warmstart_dim_bridge((v, y), self.weight_dtype, rngs.params()), sharding=(None, None), ) else: @@ -219,11 +275,11 @@ def __init__( ) else: self.A = nnx.Param( - nnx.initializers.lecun_normal()(rngs.params(), (x, u), self.weight_dtype), + _warmstart_head_bridge((x, u), self.weight_dtype, rngs.params()), sharding=(None, None), ) self.B = nnx.Param( - nnx.initializers.lecun_normal()(rngs.params(), (y, v), self.weight_dtype), + _warmstart_dim_bridge((y, v), self.weight_dtype, rngs.params()), sharding=(None, None), ) diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index bd65f04438..63d5c2a116 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -35,6 +35,7 @@ from maxtext.layers import moe from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations +from maxtext.layers.learn_to_init_layer import apply_lti_modification from maxtext.layers.embeddings import Qwen3OmniMoeVisionPosEmbedInterpolate, PositionalEmbedding from maxtext.layers.normalizations import RMSNorm, l2norm, Qwen3NextRMSNorm, Qwen3NextRMSNormGated from maxtext.layers.quantizations import AqtQuantization as Quant @@ -2251,6 +2252,7 @@ def qwen3omni_audioprojector_as_linen(config: Config, mesh: Mesh): Qwen3MoeDecoderLayerToLinen = nnx_wrappers.to_linen_class( Qwen3MoeDecoderLayer, base_metadata_fn=max_initializers.variable_to_logically_partitioned, + nnx_module_augment_fn=apply_lti_modification, ) Qwen3NextDecoderLayerToLinen = nnx_wrappers.to_linen_class( diff --git 
a/src/maxtext/trainers/post_train/distillation/tools/derive_lti_copy_map.py b/src/maxtext/trainers/post_train/distillation/tools/derive_lti_copy_map.py new file mode 100644 index 0000000000..3a1666fde6 --- /dev/null +++ b/src/maxtext/trainers/post_train/distillation/tools/derive_lti_copy_map.py @@ -0,0 +1,128 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Derive distill_weights_copy_map for an LTI distillation run. + +Loads the same student/teacher abstract configs that train_distill.py builds, +walks both graphs with nnx.graph.iter_graph, and emits a copy_map YAML snippet +listing every teacher path whose shape exactly matches the student path. Uses +nnx.eval_shape so no weights are materialized. 
+ +Usage: + python -m maxtext.trainers.post_train.distillation.tools.derive_lti_copy_map +""" + +import sys +from collections import defaultdict + +from flax import nnx +from flax.linen import partitioning as nn_partitioning + +from maxtext.configs import pyconfig +from maxtext.utils import model_creation_utils + + +def _abstract_state_paths(config): + """Return {path -> shape} for every parameter in an abstract model.""" + _, abs_model = model_creation_utils.create_nnx_abstract_model(config) + paths = {} + for path, node in nnx.graph.iter_graph(abs_model): + if not isinstance(node, nnx.Variable): + continue + try: + val = node.value + except Exception: # pylint: disable=broad-exception-caught + continue + shape = getattr(val, "shape", None) + if shape is None: + continue + paths["/".join(map(str, path))] = tuple(shape) + return paths + + +def main(argv): + if len(argv) != 2: + sys.stderr.write(f"usage: {argv[0]} \n") + sys.exit(2) + + config_path = argv[1] + + global_config = pyconfig.initialize([argv[0], config_path]) + student_overrides = dict(global_config.student_overrides or {}) + teacher_overrides = dict(global_config.teacher_overrides or {}) + + # Teacher built from a sanitized argv (no CLI flags) -- mirrors train_distill.py:898. 
+ teacher_argv = [argv[0], config_path] + with nn_partitioning.axis_rules(global_config.logical_axis_rules): + student_config = pyconfig.initialize([argv[0], config_path], **student_overrides) + teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides) + + print("Building abstract student model...", file=sys.stderr) + student_paths = _abstract_state_paths(student_config) + print(f" student params: {len(student_paths)}", file=sys.stderr) + + print("Building abstract teacher model...", file=sys.stderr) + teacher_paths = _abstract_state_paths(teacher_config) + print(f" teacher params: {len(teacher_paths)}", file=sys.stderr) + + matches = [] # (path, shape) + missing_in_student = [] + shape_mismatch = [] # (path, teacher_shape, student_shape) + for path, t_shape in teacher_paths.items(): + if path not in student_paths: + missing_in_student.append(path) + continue + s_shape = student_paths[path] + if s_shape == t_shape: + matches.append((path, t_shape)) + else: + shape_mismatch.append((path, t_shape, s_shape)) + + print(f"\nMatching paths (will be copied): {len(matches)}", file=sys.stderr) + print(f"Shape-mismatch (skipped): {len(shape_mismatch)}", file=sys.stderr) + print(f"Only in teacher (skipped): {len(missing_in_student)}", file=sys.stderr) + + if shape_mismatch: + print("\n--- Skipped due to shape mismatch (expected for LTI: attn projections, q/k_norm) ---", file=sys.stderr) + for p, t, s in shape_mismatch[:20]: + print(f" {p} teacher={t} student={s}", file=sys.stderr) + if len(shape_mismatch) > 20: + print(f" ... and {len(shape_mismatch) - 20} more", file=sys.stderr) + + if missing_in_student: + print("\n--- Only in teacher (skipped: no student counterpart) ---", file=sys.stderr) + for p in missing_in_student[:20]: + print(f" {p}", file=sys.stderr) + if len(missing_in_student) > 20: + print(f" ... 
and {len(missing_in_student) - 20} more", file=sys.stderr) + + # Group matches by parent module so the YAML stays human-readable while still + # mapping 1:1 (we escape regex metachars and use exact-match patterns). + print("\n# --- Paste into the YAML under `distill_weights_copy_map:` ---") + print("distill_weights_copy_map:") + groups = defaultdict(list) + for path, _ in sorted(matches): + parent = path.rsplit("/", 1)[0] if "/" in path else "" + groups[parent].append(path) + for parent, paths in groups.items(): + if parent: + print(f" # {parent}") + for p in paths: + # exact-match regex; copy teacher->student same path + escaped = p.replace(".", r"\.").replace("[", r"\[").replace("]", r"\]") + print(f' "{escaped}": "{p}"') + + +if __name__ == "__main__": + main(sys.argv) diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 27b82b1f6b..9bcc49535b 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -51,7 +51,7 @@ from maxtext.configs import pyconfig from maxtext.input_pipeline import tokenizer from maxtext.input_pipeline import input_pipeline_interface -from maxtext.layers.learn_to_init_layer import apply_lti_model_update +from maxtext.layers.learn_to_init_layer import apply_lti_model_update, set_teacher_config_for_lti from maxtext.optimizers import optimizers from maxtext.trainers.post_train.distillation import distillation_utils, lti_utils from maxtext.utils import max_logging @@ -695,8 +695,11 @@ def train_distill( teacher_model = get_maxtext_model(teacher_config, mesh) teacher_model.eval() - # LTI phase needs the student initialization step to know about the teacher configuration + # LTI needs the teacher config at student lazy_init time. The dict + # injection is lost when HyperParameters is deep-copied, so also stash it + # on the LTI module-level fallback. 
student_config.get_keys()["teacher_config"] = teacher_config + set_teacher_config_for_lti(teacher_config) max_logging.log(f"Loading Student from {student_config.load_parameters_path}...") _log_config_details(student_config, "Student")