60 changes: 54 additions & 6 deletions fastdeploy/model_executor/layers/linear.py
Expand Up @@ -1171,15 +1171,22 @@ def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
return shard_size_mapping.get(loaded_shard_id)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        assert loaded_shard_id in [
-            "qkv",
-            "gate",
-        ], f"loaded_shard_id must be one of ['qkv', 'gate'], but got {loaded_shard_id}"
+        # Support loading individual shards: "q", "k", "v", "gate"
+        # Also support "qkv" for fused qkv weights and "gate" for gate weights
+        # "split_q_gate" for Qwen3.5: split q_proj into query and gate parts
+        valid_shard_ids = ["qkv", "gate", "q", "k", "v", "split_q_gate"]
+        assert (
+            loaded_shard_id in valid_shard_ids
+        ), f"loaded_shard_id must be one of {valid_shard_ids}, but got {loaded_shard_id}"
        if loaded_shard_id == "qkv":
            self.qkv_weight_loader(param, loaded_weight, None)
-        else:
+        elif loaded_shard_id == "gate":
            self.gate_weight_loader(param, loaded_weight)
+        elif loaded_shard_id == "split_q_gate":
+            self.split_q_gate_weight_loader(param, loaded_weight)
+        else:
+            # "q", "k", "v" - load individual shard
+            self.qkv_weight_loader(param, loaded_weight, loaded_shard_id)
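
For orientation (outside the diff), here is a hedged sketch of how a checkpoint-loading loop might route shard ids into this dispatcher; the helper, the `shard_id_map`, and the key suffixes are hypothetical illustrations, not taken from this PR:

def load_attention_shards(layer, state_dict: dict) -> None:
    """Route raw checkpoint tensors to QKVGateParallelLinear.weight_loader.

    Hypothetical helper for illustration; the key suffixes are assumptions.
    """
    shard_id_map = {
        "q_proj": "split_q_gate",  # Qwen3.5 q_proj packs query + gate per head
        "k_proj": "k",
        "v_proj": "v",
    }
    for key, tensor in state_dict.items():
        for suffix, shard_id in shard_id_map.items():
            if key.endswith(f"{suffix}.weight"):
                layer.weight_loader(layer.weight, tensor, loaded_shard_id=shard_id)
                break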

def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id):
output_dim = getattr(param, "output_dim", None)
@@ -1289,6 +1296,47 @@ def gate_weight_loader(self, param, loaded_weight):
loaded_weight = loaded_weight.cast(param.dtype)
h2d_copy(param, loaded_weight)

def split_q_gate_weight_loader(self, param, loaded_weight):
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
dim = -1 if output_dim else 0

weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Avoid redundant transpose of fused weights when weight_loader is called iteratively
param.weight_need_transpose = False

is_torch_format = self.fd_config.model_config.model_format == "torch"

# Qwen3.5: q_proj contains query and gate in PACKED layout per head
assert loaded_weight.shape[dim] == self.num_heads * self.head_dim * 2, (
f"split_q_gate_weight_loader: expected output dim {self.num_heads * self.head_dim * 2}, "
f"got {loaded_weight.shape[dim]}. Check head_dim ({self.head_dim}) and num_heads ({self.num_heads})."
)

# Weight layout: [q0_0,...,q0_{hd-1}, g0_0,...,g0_{hd-1}, q1_0,...] where qi/gi each have head_dim elements
if is_torch_format:
Review comment on this line:
🔴 Bug: `weight_need_transpose` / `is_torch_format` logic conflict

When loading a torch-format Qwen3.5 model, we typically have `model_format="torch"` and `weight_need_transpose=True`. In that case:

  1. The weight is first transposed from [out, in] to [in, out] (Paddle format)
  2. `is_torch_format=True`, so the code takes the torch branch, which assumes the weight is still in [out, in] format
  3. The reshape then fails (element count mismatch)

Compare the implementations of `qkv_weight_loader` and `gate_weight_loader`: after the transpose they handle everything uniformly in Paddle format, with no `is_torch_format` branch.

Suggested fix: decide the subsequent handling based on whether the transpose actually happened, not on the `model_format` config:

# Option 1: simplify to uniform Paddle-format handling (recommended; consistent with the other loaders)
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
    loaded_weight = get_tensor(loaded_weight)
    loaded_weight = loaded_weight.transpose([1, 0])
    param.weight_need_transpose = False

# After the transpose the weight is uniformly Paddle format [in_size, out_size]; take the Paddle branch directly
input_shape = loaded_weight.shape[:-1]
query_weight, gate_weight = paddle.chunk(
    loaded_weight.reshape([*input_shape, -1, self.head_dim * 2]), 2, axis=-1
)
...

Also, test Case 2 sets `weight_need_transpose=True`, but the `fd_config`'s `model_format` is not "torch", so it does not cover this bug scenario. A test case should be added for it.
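
A minimal sketch of such a regression test (not from the PR; mutating `fd_config.model_config.model_format` after `create_fd_config()` is an assumption about the test fixture):

def test_weight_loader_torch_format_with_transpose(self):
    # Hypothetical regression test: model_format == "torch" AND
    # weight_need_transpose=True, the combination described above.
    fd_config = self.create_fd_config()
    fd_config.model_config.model_format = "torch"  # force the torch branch (assumed mutable)
    layer = QKVGateParallelLinear(fd_config=fd_config, prefix="test.qkvg_proj")
    param = layer.weight
    setattr(param, "weight_need_transpose", True)

    nh, hd, in_size = layer.num_heads, layer.head_dim, layer.input_size
    # Torch-layout packed weight [nh * hd * 2, in_size]: query all 1.0, gate all 2.0
    q = paddle.ones([in_size, nh, hd], dtype=param.dtype)
    g = paddle.full([in_size, nh, hd], 2.0, dtype=param.dtype)
    packed_torch = paddle.concat([q, g], axis=-1).reshape([in_size, nh * hd * 2]).transpose([1, 0])

    # With the current code the torch branch reshapes a transposed [in, out]
    # tensor as if it were [out, in] and should raise; after the fix this
    # must succeed and place the 1.0s in the q region.
    layer.weight_loader(param, packed_torch, "split_q_gate")
    q_region = layer.weight[:, : layer.num_heads_per_rank * hd].cast("float32")
    self.assertTrue(paddle.allclose(q_region, paddle.ones_like(q_region)).item())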

# Torch format [out_size, in_size]: split along axis 0
in_size = loaded_weight.shape[-1]
w = loaded_weight.reshape([self.num_heads, self.head_dim * 2, in_size])
query_weight, gate_weight = paddle.chunk(w, 2, axis=1)
query_weight = query_weight.reshape([-1, in_size])
gate_weight = gate_weight.reshape([-1, in_size])
else:
# Paddle format [in_size, out_size]: split along last axis
input_shape = loaded_weight.shape[:-1]
query_weight, gate_weight = paddle.chunk(
loaded_weight.reshape([*input_shape, -1, self.head_dim * 2]), 2, axis=-1
)
query_weight = query_weight.reshape([*input_shape, -1])
gate_weight = gate_weight.reshape([*input_shape, -1])

# Load query and gate weights
self.qkv_weight_loader(param, query_weight, "q")
self.gate_weight_loader(param, gate_weight)

def load_weight(self, state_dict: dict):
"""
Load the weight from the state dictionary.
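Aside (not part of the diff): the per-head packed layout described in the comments above can be checked with a small, self-contained paddle snippet; the tiny shapes are illustrative only.

import paddle

# Illustrative shapes only: 2 heads, head_dim = 3, input size = 4.
num_heads, head_dim, in_size = 2, 3, 4

# Paddle-layout fused weight [in_size, num_heads * head_dim * 2]: per head,
# the first head_dim columns are query, the next head_dim are gate.
q = paddle.zeros([in_size, num_heads, head_dim])  # query slots, all 0.0
g = paddle.ones([in_size, num_heads, head_dim])   # gate slots, all 1.0
packed = paddle.concat([q, g], axis=-1).reshape([in_size, num_heads * head_dim * 2])

# The split used above: view as [..., heads, head_dim * 2], chunk the last axis.
query, gate = paddle.chunk(packed.reshape([in_size, -1, head_dim * 2]), 2, axis=-1)
query = query.reshape([in_size, -1])  # [4, 6], all 0.0 -> goes to the q region
gate = gate.reshape([in_size, -1])    # [4, 6], all 1.0 -> goes to the gate region

assert float(query.abs().sum()) == 0.0
assert float(gate.min()) == 1.0

This is the same reshape-and-chunk pattern used by both the Paddle and torch branches above.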
126 changes: 126 additions & 0 deletions tests/layers/test_qkvg_parallel_linear.py
@@ -290,6 +290,18 @@ def test_weight_loader_valid_shard_id(self):
if "loaded_shard_id must be one of" in str(e):
self.fail("weight_loader should accept 'gate' as a valid shard_id")

try:
layer.weight_loader(param_mock, weight_mock, "split_q_gate")
except AssertionError as e:
if "loaded_shard_id must be one of" in str(e):
self.fail("weight_loader should accept 'split_q_gate' as a valid shard_id")

try:
layer.weight_loader(param_mock, weight_mock, "k")
except AssertionError as e:
if "loaded_shard_id must be one of" in str(e):
self.fail("weight_loader should accept 'k' as a valid shard_id")

def test_weight_loader_invalid_shard_id(self):
"""Test weight_loader with invalid shard IDs."""
fd_config = self.create_fd_config()
@@ -307,6 +319,120 @@ def test_weight_loader_invalid_shard_id(self):

self.assertIn("loaded_shard_id must be one of", str(context.exception))

def test_weight_loader_success(self):
"""Test split_q_gate weight_loader correctly splits and places query/gate weights."""
# --- Config constants (tp=1) ---
# num_heads=16, kv_num_heads=4, head_dim=64, input_size=1024
# output_size = (2*16 + 2*4)*64 = 2560
# param layout (dim=-1): [q: 0..1024) | [k: 1024..1280) | [v: 1280..1536) | [gate: 1536..2560)
fd_config = self.create_fd_config()
layer = QKVGateParallelLinear(fd_config=fd_config, prefix="test.qkvg_proj")

num_heads = layer.num_heads # 16
head_dim = layer.head_dim # 64
input_size = layer.input_size # 1024
kv_heads_per_rank = layer.kv_num_heads_per_rank # 4 (tp=1)
num_heads_per_rank = layer.num_heads_per_rank # 16 (tp=1)

q_size = num_heads_per_rank * head_dim # 1024
gate_size = num_heads_per_rank * head_dim # 1024
gate_offset = (num_heads_per_rank + 2 * kv_heads_per_rank) * head_dim # 1536

# --- Case 1: paddle (non-torch) format, loaded_weight shape [input_size, num_heads*head_dim*2] ---
param = layer.weight
# Build a packed weight: per head, first head_dim = query, next head_dim = gate
# Shape: [input_size, num_heads * head_dim * 2]
q_values = paddle.ones([input_size, num_heads * head_dim], dtype=param.dtype) # all 1.0
g_values = paddle.full([input_size, num_heads * head_dim], 2.0, dtype=param.dtype) # all 2.0
# interleave per-head: reshape to [input_size, num_heads, head_dim*2] then cat
packed = paddle.concat(
[q_values.reshape([input_size, num_heads, head_dim]), g_values.reshape([input_size, num_heads, head_dim])],
axis=-1,
).reshape(
[input_size, num_heads * head_dim * 2]
) # [1024, 2048]

layer.weight_loader(param, packed, "split_q_gate")

# q region in param should be all 1.0
q_region = layer.weight[:, :q_size].cast("float32")
self.assertTrue(
paddle.allclose(q_region, paddle.ones_like(q_region)).item(),
"Query weights not correctly written to param q-region",
)
# gate region in param should be all 2.0
gate_region = layer.weight[:, gate_offset : gate_offset + gate_size].cast("float32")
self.assertTrue(
paddle.allclose(gate_region, paddle.full_like(gate_region, 2.0)).item(),
"Gate weights not correctly written to param gate-region",
)

# --- Case 2: torch format (weight_need_transpose=True), loaded_weight shape [num_heads*head_dim*2, input_size] ---
layer2 = QKVGateParallelLinear(fd_config=fd_config, prefix="test.qkvg_proj")
param2 = layer2.weight
# Simulate torch format: transpose of packed
packed_torch = packed.transpose([1, 0]) # [2048, 1024]
setattr(param2, "weight_need_transpose", True)

layer2.weight_loader(param2, packed_torch, "split_q_gate")

q_region2 = layer2.weight[:, :q_size].cast("float32")
self.assertTrue(
paddle.allclose(q_region2, paddle.ones_like(q_region2)).item(),
"Query weights not correctly written after transpose (torch format)",
)
gate_region2 = layer2.weight[:, gate_offset : gate_offset + gate_size].cast("float32")
self.assertTrue(
paddle.allclose(gate_region2, paddle.full_like(gate_region2, 2.0)).item(),
"Gate weights not correctly written after transpose (torch format)",
)
# weight_need_transpose should be reset to False after loading
self.assertFalse(
getattr(param2, "weight_need_transpose", False),
"weight_need_transpose should be reset to False after split_q_gate loading",
)

# --- Case 3: TP=2, rank=0 — only first half of heads loaded ---
fd_config_tp2 = self.create_fd_config(tp_size=2, tp_rank=0)
layer_tp = QKVGateParallelLinear(fd_config=fd_config_tp2, prefix="test.qkvg_proj")
# tp=2: num_heads_per_rank=8, kv_num_heads_per_rank=2
# output_size per rank = (2*8 + 2*2)*64 = 1280
tp_num_heads_per_rank = layer_tp.num_heads_per_rank # 8
tp_kv_per_rank = layer_tp.kv_num_heads_per_rank # 2
tp_q_size = tp_num_heads_per_rank * head_dim # 512
tp_gate_offset = (tp_num_heads_per_rank + 2 * tp_kv_per_rank) * head_dim # 768

param_tp = layer_tp.weight
q_tp = paddle.ones([input_size, num_heads * head_dim], dtype=param_tp.dtype)
g_tp = paddle.full([input_size, num_heads * head_dim], 3.0, dtype=param_tp.dtype)
packed_tp = paddle.concat(
[q_tp.reshape([input_size, num_heads, head_dim]), g_tp.reshape([input_size, num_heads, head_dim])],
axis=-1,
).reshape([input_size, num_heads * head_dim * 2])

layer_tp.weight_loader(param_tp, packed_tp, "split_q_gate")

# rank=0 takes first 8 heads → q values should be 1.0
q_region_tp = layer_tp.weight[:, :tp_q_size].cast("float32")
self.assertTrue(
paddle.allclose(q_region_tp, paddle.ones_like(q_region_tp)).item(),
"TP rank-0 query weights not correctly written",
)
gate_region_tp = layer_tp.weight[:, tp_gate_offset : tp_gate_offset + tp_num_heads_per_rank * head_dim].cast(
"float32"
)
self.assertTrue(
paddle.allclose(gate_region_tp, paddle.full_like(gate_region_tp, 3.0)).item(),
"TP rank-0 gate weights not correctly written",
)

# --- Case 4: wrong shape triggers assert ---
layer4 = QKVGateParallelLinear(fd_config=fd_config, prefix="test.qkvg_proj")
bad_weight = paddle.zeros([input_size, num_heads * head_dim], dtype=layer4.weight.dtype) # missing gate half
with self.assertRaises(AssertionError) as ctx:
layer4.weight_loader(layer4.weight, bad_weight, "split_q_gate")
self.assertIn("split_q_gate_weight_loader", str(ctx.exception))

def test_load_state_dict_success(self):
"""Test loading state_dict with valid qkv and gate weights."""
fd_config = self.create_fd_config()