Title: [BUG] _fwd_kernel_ep_scatter_1: cross-warp stale read of expert_start_loc causes illegal memory access
Label: bug
Issue description:
_fwd_kernel_ep_scatter_1 in deepep_scatter_gather.py has a cross-warp memory visibility bug. After writing the prefix sum to expert_start_loc via a vector tl.store, the kernel immediately reads back a single element via scalar tl.load (lines 23–25):
```python
cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
cur_expert_start = tl.load(expert_start_loc + cur_expert)  # ← may read stale value
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
```
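For reference, `tl.cumsum(x) - x` turns Triton's inclusive scan into an exclusive prefix sum. Each expert's entry is therefore the total token count of all lower-numbered experts. A minimal host-side sketch of the same computation follows. The helper name `exclusive_prefix_sum` is hypothetical and not part of LightLLM.

```python
from itertools import accumulate


def exclusive_prefix_sum(tokens_per_expert):
    """Host-side mirror of `tl.cumsum(x) - x` (inclusive scan minus each element)."""
    inclusive = list(accumulate(tokens_per_expert))
    return [inc - x for inc, x in zip(inclusive, tokens_per_expert)]


# 256 experts with 128 tokens each, the same input as the instrumented repro.
starts = exclusive_prefix_sum([128] * 256)
print(starts[:4], starts[52], starts[59])  # [0, 128, 256, 384] 6656 7552
```

Experts 52 and 59 map to 6656 and 7552. These are the `register_read` values reported in the instrumented kernel's output.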
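With num_experts=256 and num_warps=8, the vector store is spread across 8 warps (32 elements per warp, one per thread). Threads in all 8 warps then execute the scalar load. The slot it reads, however, was written by a single thread in one warp, so the other 7 warps depend on another warp's store being visible to them. Under CUDA's weak memory model, that visibility is only guaranteed if a CTA-wide barrier such as bar.sync separates the store from the load. Without one, other warps may read stale or uninitialized data from the torch.empty-allocated expert_start_loc buffer.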
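The sketch below makes the warp split concrete. It assumes Triton's default blocked layout for a 1-D tensor of 256 elements with num_warps=8: one element per thread and 32 threads per warp. This layout is an assumption, not something read from the compiled kernel. Under it, element i is stored by warp i // 32. For any program, the slot read by the scalar load therefore has exactly one writer warp, and the remaining seven warps rely on cross-warp visibility. Both helper names are hypothetical.

```python
WARP_SIZE = 32  # threads per warp on NVIDIA GPUs


def writer_warp(element_idx: int, elems_per_thread: int = 1) -> int:
    """Warp that stores `element_idx` under the assumed blocked layout."""
    return element_idx // (WARP_SIZE * elems_per_thread)


def foreign_reader_warps(cur_expert: int, num_warps: int = 8) -> list[int]:
    """Warps that execute the scalar load but did not write the slot they read."""
    owner = writer_warp(cur_expert)
    return [w for w in range(num_warps) if w != owner]


# Expert 59 (from the observed stale read) is written by warp 1;
# warps 0 and 2-7 all read it back without a barrier in between.
print(writer_warp(59), foreign_reader_warps(59))  # 1 [0, 2, 3, 4, 5, 6, 7]
```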
When the load of cur_expert_start returns a garbage value (e.g. 0x7FFFFFFF), the subsequent unmasked write tl.store(m_indices + cur_expert_start + ...) lands roughly 8 GB past the start of m_indices, in unmapped GPU memory. This triggers cudaErrorIllegalAddress → SIGABRT.
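The ~8 GB figure follows from the element size. This assumes m_indices holds 4-byte int32 values, as it does in the reproduction scripts. An element offset of 0x7FFFFFFF then becomes a byte offset of about 8 GiB. The m_indices buffer in the PTX inspection script (256 × 128 elements) is only 128 KiB. A quick check:

```python
INT32_BYTES = 4  # assumption: m_indices is int32, as in the repro scripts


def byte_offset(elem_offset: int, elem_bytes: int = INT32_BYTES) -> int:
    """Byte distance from the m_indices base pointer for a given element offset."""
    return elem_offset * elem_bytes


stale_start = 0x7FFFFFFF               # garbage value read from expert_start_loc
buffer_bytes = byte_offset(256 * 128)  # m_indices size in the PTX inspection script
offset = byte_offset(stale_start)
print(f"{offset} bytes = {offset / 2**30:.2f} GiB; buffer is {buffer_bytes // 1024} KiB")
# 8589934588 bytes = 8.00 GiB; buffer is 128 KiB
```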
Steps to reproduce:
1. Instrumented kernel — directly catch stale reads at runtime
The following script adds an instrumented kernel that reads cur_expert_start via both global memory (the buggy path) and register-local extraction (the correct path), then outputs both values for comparison. Any mismatch proves the global memory load read stale data:
```python
import torch
import triton
import triton.language as tl


@triton.jit
def _instrumented_kernel(
    num_recv_tokens_per_expert,
    expert_start_loc,
    global_read_buf,
    register_read_buf,
    num_experts: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    """Read cur_expert_start via BOTH global memory and register, output both."""
    cur_expert = tl.program_id(0)
    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts, other=0,
    )
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
    # Method 1: global memory load (the buggy pattern)
    from_global = tl.load(expert_start_loc + cur_expert)
    # Method 2: register-local extraction (the correct pattern)
    expert_mask = offset_cumsum == cur_expert
    from_register = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum)))
    tl.store(global_read_buf + cur_expert, from_global)
    tl.store(register_read_buf + cur_expert, from_register)


device = torch.device("cuda")
num_experts = 256
counts_gpu = torch.tensor([128] * num_experts, dtype=torch.int32, device=device)
total_mismatches = 0
mismatch_examples = []
for round_i in range(1000):
    # Fresh torch.empty buffers each round, so stale contents differ between runs.
    expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=device)
    global_buf = torch.empty(num_experts, dtype=torch.int32, device=device)
    register_buf = torch.empty(num_experts, dtype=torch.int32, device=device)
    _instrumented_kernel[(num_experts,)](
        counts_gpu,
        expert_start_loc,
        global_buf,
        register_buf,
        num_experts=num_experts,
        num_warps=8,
        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
    )
    torch.cuda.synchronize()
    g = global_buf.cpu()
    r = register_buf.cpu()
    diff_mask = g != r
    n_diff = int(diff_mask.sum().item())
    if n_diff > 0:
        total_mismatches += n_diff
        if len(mismatch_examples) < 5:
            idx = diff_mask.nonzero(as_tuple=True)[0][0].item()
            mismatch_examples.append(
                f"iter={round_i} expert={idx} "
                f"global_read={g[idx].item()} register_read={r[idx].item()}"
            )
for ex in mismatch_examples:
    print(f"Stale read: {ex}")
print(f"\nResult: {total_mismatches} stale reads in 1000 iterations "
      f"({'BUG CONFIRMED' if total_mismatches > 0 else 'no race observed on this hardware'})")
```
Output on our setup (H20, Triton 3.4.0):
```
Stale read: iter=0 expert=59 global_read=0 register_read=7552
Stale read: iter=2 expert=52 global_read=0 register_read=6656

Result: 2 stale reads in 1000 iterations (BUG CONFIRMED)
```
The global memory path reads 0 (leftover contents of the torch.empty allocation), while the register path reads the correct prefix sum (expert × 128). In production, when the stale value is large enough, the resulting out-of-bounds write crashes the process.
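A stale value that happens to fall inside m_indices would not crash. Instead, the program would write its expert id into another expert's range, for example expert 0's range when the stale value is 0. Only values far outside the buffer produce the illegal-address fault. The hypothetical helper `classify_write` below models this on the host. It assumes int32 element offsets and the original kernel's unmasked store loop, which writes whole BLOCK_E chunks. It is an illustration, not LightLLM code.

```python
import math


def classify_write(start: int, token_num: int, buf_len: int, block_e: int = 128) -> str:
    """Classify the m_indices range one program writes for a given start offset.

    Models the unmasked loop over tl.range(0, token_num, BLOCK_E): it stores whole
    BLOCK_E chunks, so it touches ceil(token_num / block_e) * block_e elements.
    """
    end = start + math.ceil(token_num / block_e) * block_e
    if start >= 0 and end <= buf_len:
        return "in_bounds"  # no fault; a stale start lands in another expert's range
    return "out_of_bounds"  # may reach unmapped memory and fault


buf_len = 256 * 128  # m_indices elements for 256 experts x 128 tokens
print(classify_write(0, 128, buf_len))           # in_bounds (stale value from the repro)
print(classify_write(7552, 128, buf_len))        # in_bounds (correct start for expert 59)
print(classify_write(0x7FFFFFFF, 128, buf_len))  # out_of_bounds (garbage behind the crash)
```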
2. PTX inspection — confirm no bar.sync between store and load
The following script compiles the kernel and inspects Triton's PTX output:
```python
import os
import glob
import shutil

import torch
import triton
import triton.language as tl


@triton.jit
def _old_kernel(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    # Same store-then-reload pattern as _fwd_kernel_ep_scatter_1.
    cur_expert = tl.program_id(0)
    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts,
        other=0,
    )
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
    cur_expert_start = tl.load(expert_start_loc + cur_expert)
    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)
    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        tl.store(m_indices_start_ptr + start_m + off_expert, cur_expert)


def inspect_ptx():
    device = torch.device("cuda")
    num_experts = 256
    BLOCK_E = 128
    cache_dir = os.environ.get(
        "TRITON_CACHE_DIR",
        os.path.join(os.path.expanduser("~"), ".triton", "cache"),
    )
    # Wipe the Triton cache so the kernel is recompiled and a fresh .ptx is written.
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)
    counts = torch.full((num_experts,), BLOCK_E, dtype=torch.int32, device=device)
    start_loc = torch.empty(num_experts, dtype=torch.int32, device=device)
    m_indices = torch.empty(num_experts * BLOCK_E, dtype=torch.int32, device=device)
    _old_kernel[(num_experts,)](
        counts, start_loc, m_indices,
        num_experts=num_experts, num_warps=8,
        BLOCK_E=BLOCK_E, BLOCK_EXPERT_NUM=256,
    )
    torch.cuda.synchronize()
    # Pick the most recently written PTX file in the cache.
    ptx_files = sorted(
        glob.glob(os.path.join(cache_dir, "**", "*.ptx"), recursive=True),
        key=os.path.getmtime, reverse=True,
    )
    if not ptx_files:
        print("No PTX files found")
        return
    with open(ptx_files[0]) as f:
        code = f.read()
    bar_count = code.lower().count("bar.sync")
    st_global = code.lower().count("st.global")
    ld_global = code.lower().count("ld.global")
    print(f"bar.sync={bar_count} st.global={st_global} ld.global={ld_global}")
    # Print every barrier and global load/store together with its PTX line number.
    lines = code.split("\n")
    for i, line in enumerate(lines):
        lo = line.lower()
        if any(kw in lo for kw in ["bar.sync", "st.global", "ld.global"]):
            print(f" {i:4d} {line.rstrip()}")


if __name__ == "__main__":
    inspect_ptx()
```
Output:
```
bar.sync=1 st.global=2 ld.global=3
   43 @%p1 ld.global.b32 { %r8 }, [ %rd3 + 0 ];
   75 bar.sync 0;
  120 @%p1 st.global.b32 [ %rd4 + 0 ], { %r11 };
  128 ld.global.b32 { %r12 }, [ %rd5 + 0 ];
  135 ld.global.b32 { %r13 }, [ %rd6 + 0 ];
  160 @%p17 st.global.b32 [ %rd14 + 0 ], { %r57 };
```
The only bar.sync is at PTX line 75, emitted for tl.cumsum's shared-memory scan. There is no barrier between the st.global at line 120 (writing expert_start_loc) and the ld.global at line 128 (reading it back).
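The barrier check above can also be scripted. The hypothetical helpers below scan a PTX listing for a CTA barrier (bar.sync, or the newer barrier.sync spelling) strictly between two line numbers. Here they are fed the listing printed by the inspection script rather than a live compilation, so they run without a GPU. The scan only looks at textual order, not control flow, so treat it as a quick sanity check rather than a proof.

```python
import re

# Rows from the inspection script output: "<PTX line number> <instruction>".
PTX_LISTING = """\
43 @%p1 ld.global.b32 { %r8 }, [ %rd3 + 0 ];
75 bar.sync 0;
120 @%p1 st.global.b32 [ %rd4 + 0 ], { %r11 };
128 ld.global.b32 { %r12 }, [ %rd5 + 0 ];
135 ld.global.b32 { %r13 }, [ %rd6 + 0 ];
160 @%p17 st.global.b32 [ %rd14 + 0 ], { %r57 };
"""

# Matches bar.sync / barrier.sync but not membar.* (no word boundary inside "membar").
BARRIER = re.compile(r"\b(bar|barrier)\.sync\b")


def parse_listing(text: str) -> list[tuple[int, str]]:
    """Split each non-empty row into (line number, instruction text)."""
    rows = []
    for raw in text.splitlines():
        if raw.strip():
            num, instr = raw.strip().split(maxsplit=1)
            rows.append((int(num), instr))
    return rows


def barrier_between(rows, store_line: int, load_line: int) -> bool:
    """True if a CTA barrier appears strictly between the two PTX line numbers."""
    return any(
        store_line < num < load_line and BARRIER.search(instr) is not None
        for num, instr in rows
    )


rows = parse_listing(PTX_LISTING)
print("barrier between tokens load and expert_start_loc store:", barrier_between(rows, 43, 120))
print("barrier between expert_start_loc store and reload:", barrier_between(rows, 120, 128))
# True, then False
```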
Expected behavior:
cur_expert_start should always reflect the correct exclusive prefix sum computed in the same kernel invocation, regardless of warp scheduling.
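When turning this expectation into a unit test, a host-side reference can act as the oracle. Given per-expert counts, it returns the exclusive prefix sums and an m_indices array in which every slot holds its owning expert's id. The helper `reference_ep_scatter_1` is hypothetical, not part of LightLLM. It also assumes that counts are multiples of BLOCK_E, as in the reproduction scripts where every count is 128, so the kernel's unmasked BLOCK_E-sized stores cover each expert's range exactly.

```python
def reference_ep_scatter_1(tokens_per_expert, block_e=128):
    """Host-side model of the expected expert_start_loc and m_indices contents.

    Hypothetical test oracle; assumes BLOCK_E-aligned counts, matching the repro.
    """
    if any(n % block_e for n in tokens_per_expert):
        raise ValueError("sketch assumes every count is a multiple of block_e")
    expert_start_loc, running = [], 0
    for n in tokens_per_expert:
        expert_start_loc.append(running)  # exclusive prefix sum
        running += n
    m_indices = [-1] * running  # -1 marks a slot that was never written
    for expert, (start, n) in enumerate(zip(expert_start_loc, tokens_per_expert)):
        m_indices[start:start + n] = [expert] * n  # each slot tagged with its expert id
    return expert_start_loc, m_indices


starts, m_idx = reference_ep_scatter_1([128, 0, 256])
print(starts, len(m_idx), m_idx[127], m_idx[128], -1 in m_idx)  # [0, 128, 128] 384 0 2 False
```

A GPU test could then compare `expert_start_loc.cpu().tolist()` and `m_indices.cpu().tolist()` after the kernel runs against this reference.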
Error logging:
```
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
  File "deepep_scatter_gather.py", line 25, in _fwd_kernel_ep_scatter_1
    cur_expert_start = tl.load(expert_start_loc + cur_expert)
```
Full crash stack:
```
model forward
  → layer forward
    → MoE forward
      → fused_moe forward
        → ep_scatter
          → _fwd_kernel_ep_scatter_1 (Triton kernel)
            → CUDA illegal memory access → SIGABRT
```
Proposed fix:
Replace the global memory round-trip with register-local extraction using tl.sum + tl.where:
```python
# Before (buggy):
cur_expert_start = tl.load(expert_start_loc + cur_expert)
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

# After (fixed):
expert_mask = offset_cumsum == cur_expert
cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum)))
cur_expert_token_num = tl.sum(tl.where(expert_mask, tokens_per_expert, tl.zeros_like(tokens_per_expert)))
```
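This keeps the tl.store to expert_start_loc, which _fwd_kernel_ep_scatter_2 still needs. That kernel is a separate launch, so stream ordering makes these stores visible to it. Both values are now read directly from registers. tl.sum is a block-level reduction that Triton lowers to a cross-warp exchange through shared memory with the required bar.sync, so every warp receives the fully reduced value. The second original load (num_recv_tokens_per_expert) was never racy, because this kernel does not write that buffer. Extracting it from registers simply drops a redundant global load.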
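The extraction trick can be checked on the host. Summing cumsum under a one-hot mask returns exactly cumsum[cur_expert]. Lanes padded up to BLOCK_EXPERT_NUM never match, because cur_expert < num_experts. The pure-Python emulation of the lane-wise `tl.where` plus `tl.sum` below uses a hypothetical helper name. On the GPU, it is the reduction itself that moves the value across warps.

```python
from itertools import accumulate


def extract_lane(values, lane_ids, target):
    """Emulate tl.sum(tl.where(lane_ids == target, values, 0)) on the host."""
    return sum(v if lane == target else 0 for v, lane in zip(values, lane_ids))


num_experts, block = 5, 8  # block plays the role of BLOCK_EXPERT_NUM (next power of 2)
tokens = [3, 0, 2, 7, 1] + [0] * (block - num_experts)  # padded lanes load other=0
inclusive = list(accumulate(tokens))
cumsum = [inc - t for inc, t in zip(inclusive, tokens)]  # [0, 3, 3, 5, 12, 13, 13, 13]
lanes = list(range(block))

for cur_expert in range(num_experts):
    start = extract_lane(cumsum, lanes, cur_expert)
    count = extract_lane(tokens, lanes, cur_expert)
    print(cur_expert, start, count)
```

The padded lanes carry a nonzero cumsum (13 here). The mask, not the zero padding, is what keeps them out of the sum.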
We can submit a PR with the fix and unit tests if that would be helpful.
Additional context:
We discovered this bug while stress-testing RTP-LLM, which uses LightLLM's _fwd_kernel_ep_scatter_1 implementation. The fix PR is here: alibaba/rtp-llm#1098
Environment:
Using container
GPU info: NVIDIA-SMI 580.105.08, Driver Version: 580.105.08, CUDA Version: 12.9
Python: CPython 3.10.9
PyTorch: 2.8.0+cu129
openai-triton: triton 3.4.0
LightLLM: current main (deepep_scatter_gather.py lines 23–26)