Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,16 @@ def _fwd_kernel_ep_scatter_1(
cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)

cur_expert_start = tl.load(expert_start_loc + cur_expert)
# Read this program's prefix-sum offset from registers instead of the
# just-written global buffer: the vectorized store above is split across
# warps, so a scalar read-back of expert_start_loc[cur_expert] can execute
# before the warp that owns that element has written it and therefore observe
# a stale value, producing a garbage offset and an illegal memory access.
# tl.sum reduces via shared memory with proper synchronization, so the
# extracted value is always correct.
expert_mask = offset_cumsum == cur_expert
cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum)))
# num_recv_tokens_per_expert is a read-only input (never written in this
# kernel), so a direct load carries no stale-read risk.
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

m_indices_start_ptr = m_indices + cur_expert_start
Expand Down