From 5f4733125ee71749780ec252d2231487ea4e3deb Mon Sep 17 00:00:00 2001 From: Tai An Date: Tue, 16 Jun 2026 06:08:35 -0700 Subject: [PATCH 1/3] fix(moe): avoid cross-warp stale read in ep_scatter prefix sum _fwd_kernel_ep_scatter_1 stores the full exclusive prefix sum to expert_start_loc with a vectorized tl.store, then immediately reads back expert_start_loc[cur_expert] with a scalar tl.load. The vectorized store is split across the program warps, so under CUDA weak memory ordering the scalar read can observe a stale/uninitialized value written by another warp, producing a garbage offset and a cudaErrorIllegalAddress in the following m_indices write. Extract cur_expert_start (and cur_expert_token_num) directly from the in-register cumsum / tokens_per_expert vectors via tl.where + tl.sum, which reduces through shared memory with proper synchronization. The global store to expert_start_loc is kept since _fwd_kernel_ep_scatter_2 consumes it. Fixes #1361 --- .../fused_moe/deepep_scatter_gather.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py index 101d316937..6144554b51 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py @@ -22,8 +22,15 @@ def _fwd_kernel_ep_scatter_1( cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts) - cur_expert_start = tl.load(expert_start_loc + cur_expert) - cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) + # Read this program's prefix-sum offset from registers instead of the + # just-written global buffer: the vectorized store above is split across + # warps, so a scalar read-back of expert_start_loc[cur_expert] can observe a + # stale value written by another warp, producing a garbage offset and an + # illegal memory access. tl.sum reduces via shared memory with proper + # synchronization, so the extracted value is always correct. + expert_mask = offset_cumsum == cur_expert + cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum))) + cur_expert_token_num = tl.sum(tl.where(expert_mask, tokens_per_expert, tl.zeros_like(tokens_per_expert))) m_indices_start_ptr = m_indices + cur_expert_start off_expert = tl.arange(0, BLOCK_E) @@ -229,4 +236,4 @@ def ep_gather( num_warps=num_warps, BLOCK_D=BLOCK_D, ) - return + return \ No newline at end of file From c660552084b2215a0520d8c0bfb9249903593e49 Mon Sep 17 00:00:00 2001 From: Tai An Date: Tue, 16 Jun 2026 06:13:10 -0700 Subject: [PATCH 2/3] refactor(moe): keep cur_expert_token_num as direct load num_recv_tokens_per_expert is a read-only input and is never written in this kernel, so reading it directly carries no cross-warp stale-read risk. Only the expert_start_loc read-back (written just above) needs the register-based extraction. Reverts the unnecessary reduction for cur_expert_token_num per review feedback. --- .../triton_kernel/fused_moe/deepep_scatter_gather.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py index 6144554b51..8dd0c8532d 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py @@ -30,7 +30,9 @@ def _fwd_kernel_ep_scatter_1( # synchronization, so the extracted value is always correct. expert_mask = offset_cumsum == cur_expert cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum))) - cur_expert_token_num = tl.sum(tl.where(expert_mask, tokens_per_expert, tl.zeros_like(tokens_per_expert))) + # num_recv_tokens_per_expert is a read-only input (never written in this + # kernel), so a direct load carries no stale-read risk. + cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) m_indices_start_ptr = m_indices + cur_expert_start off_expert = tl.arange(0, BLOCK_E) From fe473a338b8c5789d9b53907b32317e36911de73 Mon Sep 17 00:00:00 2001 From: Tai An Date: Wed, 17 Jun 2026 06:02:52 -0700 Subject: [PATCH 3/3] style: add trailing newline to satisfy black pre-commit --- .../basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py index 8dd0c8532d..90d76a9c6f 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py @@ -238,4 +238,4 @@ def ep_gather( num_warps=num_warps, BLOCK_D=BLOCK_D, ) - return \ No newline at end of file + return