diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py
index 3f45e9df614..a4346f51001 100644
--- a/fastdeploy/model_executor/models/glm4_moe.py
+++ b/fastdeploy/model_executor/models/glm4_moe.py
@@ -26,6 +26,7 @@
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.decorator import (
     support_graph_optimization,
@@ -160,8 +161,16 @@ def __init__(
             default_initializer=paddle.nn.initializer.Constant(0),
         )
 
+        # In pure-TP mode (tp>1, ep=1) both branches return partial sums, so we
+        # defer the all-reduce to after combining them — saving one collective.
+        # In all other modes (EP, EP+attn-TP, no parallelism) each branch handles
+        # its own reduction internally (reduce_results default=True), so we must
+        # NOT add an extra all-reduce here.
+        self._pure_tp = self.use_tp and not self.use_ep
+
         self.experts = FusedMoE(
             fd_config,
+            reduce_results=not self._pure_tp,
             renormalize=self.norm_topk_prob,
             moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
             num_experts=fd_config.model_config.n_routed_experts,
@@ -182,14 +191,16 @@ def __init__(
                 intermediate_size=shared_experts_intermediate_size,
                 layer_id=layer_id,
                 prefix=f"{prefix}.shared_experts",
+                reduce_results=not self._pure_tp,
             )
 
     def forward(self, x, forward_meta: ForwardMeta = None):
         out = self.experts(x, self.gate, forward_meta)
 
         if self.n_shared_experts > 0:
-            shared_experts_out = self.shared_experts(x)
-            out = out + shared_experts_out
-
+            out = out + self.shared_experts(x)
+        if self._pure_tp:
+            # Both branches produced partial sums; combine first, then single all-reduce.
+            out = tensor_model_parallel_all_reduce(out, self.tp_group)
         return out
 
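
For context on why the deferred reduction is safe, here is a standalone NumPy sketch (not part of this PR; `tp_size`, the tensor shapes, and the `all_reduce` stub are illustrative assumptions standing in for `tensor_model_parallel_all_reduce`). It shows that an all-reduce is an element-wise sum over ranks, and that sum distributes over the local addition of the two partial outputs, so one collective on the combined result matches two collectives on the separate branches.

```python
import numpy as np

tp_size = 4  # assumed tensor-parallel degree for the sketch
rng = np.random.default_rng(0)

# Per-rank partial outputs of the two branches (routed experts / shared experts),
# as produced when reduce_results=False on both layers.
routed_partials = [rng.standard_normal((2, 8)) for _ in range(tp_size)]
shared_partials = [rng.standard_normal((2, 8)) for _ in range(tp_size)]

def all_reduce(partials):
    # Stand-in for tensor_model_parallel_all_reduce: element-wise sum over ranks.
    return np.sum(partials, axis=0)

# Previous behavior: each branch reduces itself, then the results are added
# (two collectives per MoE layer).
two_collectives = all_reduce(routed_partials) + all_reduce(shared_partials)

# Pure-TP path in this diff: add the partial sums locally on each rank,
# then issue a single all-reduce on the combined output.
one_collective = all_reduce([r + s for r, s in zip(routed_partials, shared_partials)])

assert np.allclose(two_collectives, one_collective)
```

The equivalence holds only when both branches are partial sums over the same TP group, which is why the diff restricts the deferred reduce to the pure-TP case (`self._pure_tp`) and otherwise leaves each layer's internal reduction (`reduce_results=True`) in place with no extra collective.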