Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions fastdeploy/model_executor/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from paddleformers.utils.log import logger

from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization,
Expand Down Expand Up @@ -160,8 +161,16 @@ def __init__(
default_initializer=paddle.nn.initializer.Constant(0),
)

# In pure-TP mode (tp>1, ep=1) both branches return partial sums, so we
# defer the all-reduce to after combining them — saving one collective.
# In all other modes (EP, EP+attn-TP, no parallelism) each branch handles
# its own reduction internally (reduce_results default=True), so we must
# NOT add an extra all-reduce here.
self._pure_tp = self.use_tp and not self.use_ep
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个变量应该叫 self.merge_ffn_tp 是不是更好点

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个变量应该叫 self.merge_ffn_tp 是不是更好点

是的,更容易理解一些


self.experts = FusedMoE(
fd_config,
reduce_results=not self._pure_tp,
renormalize=self.norm_topk_prob,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.n_routed_experts,
Expand All @@ -182,14 +191,16 @@ def __init__(
intermediate_size=shared_experts_intermediate_size,
layer_id=layer_id,
prefix=f"{prefix}.shared_experts",
reduce_results=not self._pure_tp,
)

def forward(self, x, forward_meta: ForwardMeta = None):
out = self.experts(x, self.gate, forward_meta)
if self.n_shared_experts > 0:
shared_experts_out = self.shared_experts(x)
out = out + shared_experts_out

out = out + self.shared_experts(x)
if self._pure_tp:
# Both branches produced partial sums; combine first, then single all-reduce.
out = tensor_model_parallel_all_reduce(out, self.tp_group)
return out


Expand Down
Loading