Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 10 additions & 36 deletions src/models/qwen3_5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2398,11 +2398,10 @@ pub fn sanitize_weights(mut weights: WeightMap, config: &Qwen35Config) -> Weight

// Handle per-expert gate_proj/up_proj/down_proj naming variant.
// Checkpoints that store experts under `experts.{e}.gate_proj.weight`
// instead of `experts.{e}.w1.weight` use this layout. The mapping
// mirrors the fused gate_up_proj path: gate->w1, up->w3, down->w2.
for (src_proj, dst_proj) in [("gate_proj", "w1"), ("up_proj", "w3"), ("down_proj", "w2")] {
// Skip if this target slot is already populated by the w1/w2/w3 pass above.
if weights.contains_key(format!("{}.{}.weight", base, dst_proj).as_str()) {
// instead of per-expert w1 weights use this layout. Keep the
// stacked names aligned with SwitchGLU::from_weights.
for proj in ["gate_proj", "up_proj", "down_proj"] {
if weights.contains_key(format!("{}.{}.weight", base, proj).as_str()) {
continue;
}

Expand All @@ -2413,18 +2412,18 @@ pub fn sanitize_weights(mut weights: WeightMap, config: &Qwen35Config) -> Weight
let mut e = 0;
while let Some(w) = weights.remove(&format!(
"model.layers.{}.mlp.experts.{}.{}.weight",
l, e, src_proj
l, e, proj
)) {
expert_weights.push(w);
if let Some(s) = weights.remove(&format!(
"model.layers.{}.mlp.experts.{}.{}.scales",
l, e, src_proj
l, e, proj
)) {
expert_scales.push(s);
}
if let Some(b) = weights.remove(&format!(
"model.layers.{}.mlp.experts.{}.{}.biases",
l, e, src_proj
l, e, proj
)) {
expert_biases.push(b);
}
Expand All @@ -2433,16 +2432,16 @@ pub fn sanitize_weights(mut weights: WeightMap, config: &Qwen35Config) -> Weight

if !expert_weights.is_empty() {
let stacked = stack_arrays(&expert_weights, 0);
weights.insert(format!("{}.{}.weight", base, dst_proj), stacked);
weights.insert(format!("{}.{}.weight", base, proj), stacked);

if !expert_scales.is_empty() {
let stacked = stack_arrays(&expert_scales, 0);
weights.insert(format!("{}.{}.scales", base, dst_proj), stacked);
weights.insert(format!("{}.{}.scales", base, proj), stacked);
}

if !expert_biases.is_empty() {
let stacked = stack_arrays(&expert_biases, 0);
weights.insert(format!("{}.{}.biases", base, dst_proj), stacked);
weights.insert(format!("{}.{}.biases", base, proj), stacked);
}
}
}
Expand Down Expand Up @@ -2481,31 +2480,6 @@ pub fn sanitize_weights(mut weights: WeightMap, config: &Qwen35Config) -> Weight
}
}

// 8. Rename switch_mlp.{gate_proj,up_proj,down_proj} → switch_mlp.{w1,w3,w2}
// Pre-quantized MoE models use gate_proj/up_proj/down_proj naming,
// but SparseMoeBlock expects w1/w2/w3 naming.
let rename_map = [
("switch_mlp.gate_proj.", "switch_mlp.w1."),
("switch_mlp.up_proj.", "switch_mlp.w3."),
("switch_mlp.down_proj.", "switch_mlp.w2."),
];
let keys_to_rename: Vec<String> = weights
.keys()
.filter(|k| rename_map.iter().any(|(from, _)| k.contains(from)))
.cloned()
.collect();
for key in keys_to_rename {
for (from, to) in &rename_map {
if key.contains(from) {
let new_key = key.replace(from, to);
if let Some(v) = weights.remove(&key) {
weights.insert(new_key, v);
}
break;
}
}
}

weights
}

Expand Down
34 changes: 33 additions & 1 deletion src/models/qwen3_5_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
//! that need a real Qwen 3.5 model and are gated behind hardware availability.

use super::qwen3_5::{
Qwen35Config, rebuild_with_zero_tail, sanitize_weights, zero_per_row_kv_tail,
Qwen35Config, rebuild_with_zero_tail, sanitize_weights,
zero_per_row_kv_tail,
};
use mlxcel_core::dtype;
use mlxcel_core::layers::KVCache;
Expand Down Expand Up @@ -309,3 +310,34 @@ fn sanitize_weights_drops_lm_head_when_tied_embeddings() {
);
assert!(sanitized.contains_key("model.embed_tokens.weight"));
}

/// Per-expert `gate_proj`/`up_proj`/`down_proj` weights must come out of
/// `sanitize_weights` under `switch_mlp` with their projection names kept,
/// and no `w1`/`w2`/`w3` keys may be produced for them.
#[test]
#[ignore = "requires serial MLX execution"]
fn sanitize_weights_stacks_per_expert_switch_proj_names_for_loader() {
    // Checkpoint with two experts on layer 0, each carrying all three projections.
    let experts_prefix = "model.layers.0.mlp.experts";
    let mut checkpoint = WeightMap::new();
    for expert_idx in 0..2 {
        // Every element of an expert's tensors equals that expert's index.
        let fill = [expert_idx as f32; 8];
        for proj_name in ["gate_proj", "up_proj", "down_proj"] {
            let key = format!("{experts_prefix}.{expert_idx}.{proj_name}.weight");
            checkpoint.insert(key, mlxcel_core::from_slice_f32(&fill, &[2, 4]));
        }
    }

    // Tiny config overridden to describe a two-expert MoE, matching the checkpoint above.
    let mut moe_config = make_tiny_config();
    moe_config.shared_expert_intermediate_size = 2;
    moe_config.moe_intermediate_size = 2;
    moe_config.decoder_sparse_step = 1;
    moe_config.num_experts_per_tok = 1;
    moe_config.num_experts = 2;

    let loaded = sanitize_weights(checkpoint, &moe_config);

    // The projection-named keys are present under `switch_mlp` ...
    assert!(loaded.contains_key("model.layers.0.mlp.switch_mlp.gate_proj.weight"));
    assert!(loaded.contains_key("model.layers.0.mlp.switch_mlp.up_proj.weight"));
    assert!(loaded.contains_key("model.layers.0.mlp.switch_mlp.down_proj.weight"));
    // ... and the `w1`/`w2`/`w3` spellings are absent.
    assert!(!loaded.contains_key("model.layers.0.mlp.switch_mlp.w1.weight"));
    assert!(!loaded.contains_key("model.layers.0.mlp.switch_mlp.w2.weight"));
    assert!(!loaded.contains_key("model.layers.0.mlp.switch_mlp.w3.weight"));
}