From 4297eedbc7ecf37eaf412d954aaad9ca21dc6d7d Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Thu, 12 Mar 2026 08:44:39 +0000 Subject: [PATCH] Update Qwen3 vLLM layer names to match tpu-inference mappings. --- .../integration/tunix/weight_mapping/qwen3.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/maxtext/integration/tunix/weight_mapping/qwen3.py b/src/maxtext/integration/tunix/weight_mapping/qwen3.py index fac316db29..e7a2a16072 100644 --- a/src/maxtext/integration/tunix/weight_mapping/qwen3.py +++ b/src/maxtext/integration/tunix/weight_mapping/qwen3.py @@ -67,12 +67,12 @@ def to_hf_mapping(): return { # Token embeddings - shard vocab dimension "base.token_embedder.embedding": ( - "model.embed.embedding", + "model.embed_tokens.weight", ("model", None), ), # Final layer norm - no sharding needed "base.decoder.decoder_norm.scale": ( - "model.norm.scale", + "model.norm.weight", (None,), ), # LM head (logits projection) - shard vocab dimension @@ -83,49 +83,49 @@ def to_hf_mapping(): # Layer-specific mappings (scanned -> unscanned) # MLP components - shard hidden dimensions "base.decoder.layers.mlp.wi_0.kernel": ( - "model.layers.*.mlp.gate_proj.kernel", + "model.layers.*.mlp.gate_proj.weight", (None, "layer", "model"), ), "base.decoder.layers.mlp.wi_1.kernel": ( - "model.layers.*.mlp.up_proj.kernel", + "model.layers.*.mlp.up_proj.weight", (None, "layer", "model"), ), "base.decoder.layers.mlp.wo.kernel": ( - "model.layers.*.mlp.down_proj.kernel", + "model.layers.*.mlp.down_proj.weight", ("model", "layer", None), ), # Layer norms - no sharding needed "base.decoder.layers.pre_self_attention_layer_norm.scale": ( - "model.layers.*.input_layernorm.scale", + "model.layers.*.input_layernorm.weight", (None, "layer"), ), "base.decoder.layers.post_self_attention_layer_norm.scale": ( - "model.layers.*.post_attention_layernorm.scale", + "model.layers.*.post_attention_layernorm.weight", (None, "layer"), ), # Attention components - shard head dimensions "base.decoder.layers.self_attention.query.kernel": ( - "model.layers.*.self_attn.q_proj.kernel", + "model.layers.*.self_attn.q_proj.weight", (None, "layer", "model", None), ), "base.decoder.layers.self_attention.key.kernel": ( - "model.layers.*.self_attn.k_proj.kernel", + "model.layers.*.self_attn.k_proj.weight", (None, "layer", "model", None), ), "base.decoder.layers.self_attention.value.kernel": ( - "model.layers.*.self_attn.v_proj.kernel", + "model.layers.*.self_attn.v_proj.weight", (None, "layer", "model", None), ), "base.decoder.layers.self_attention.out.kernel": ( - "model.layers.*.self_attn.o_proj.kernel", + "model.layers.*.self_attn.o_proj.weight", ("model", "layer", None, None), ), "base.decoder.layers.self_attention.query_norm.scale": ( - "model.layers.*.self_attn.q_norm.scale", + "model.layers.*.self_attn.q_norm.weight", (None, "layer"), ), "base.decoder.layers.self_attention.key_norm.scale": ( - "model.layers.*.self_attn.k_norm.scale", + "model.layers.*.self_attn.k_norm.weight", (None, "layer"), ), }