diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py index 101d31693..90d76a9c6 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py @@ -22,7 +22,16 @@ def _fwd_kernel_ep_scatter_1( cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts) - cur_expert_start = tl.load(expert_start_loc + cur_expert) + # Read this program's prefix-sum offset from registers instead of the + # just-written global buffer: the vectorized store above is split across + # warps, so a scalar read-back of expert_start_loc[cur_expert] can observe a + # stale value written by another warp, producing a garbage offset and an + # illegal memory access. tl.sum reduces via shared memory with proper + # synchronization, so the extracted value is always correct. + expert_mask = offset_cumsum == cur_expert + cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum))) + # num_recv_tokens_per_expert is a read-only input (never written in this + # kernel), so a direct load carries no stale-read risk. cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) m_indices_start_ptr = m_indices + cur_expert_start