[BUG] #1361

@HoniiTro19

Title: [BUG] _fwd_kernel_ep_scatter_1: cross-warp stale read of expert_start_loc causes illegal memory access

Label: bug


Issue description:

_fwd_kernel_ep_scatter_1 in deepep_scatter_gather.py has a cross-warp memory visibility bug. After writing the prefix sum to expert_start_loc via a vector tl.store, the kernel immediately reads back a single element via scalar tl.load (lines 23–25):

cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
cur_expert_start = tl.load(expert_start_loc + cur_expert)       # ← may read stale value
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

With num_experts=256 and num_warps=8, the vector store is spread across 8 warps (32 elements each), so for most warps the scalar load reads a slot that a different warp wrote. Under CUDA's weak memory model, nothing orders that load after the other warp's store unless an explicit bar.sync + membar.cta sits between them, so a warp may read stale or uninitialized data from the torch.empty-allocated expert_start_loc buffer.
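To make the ownership argument concrete, here is a small plain-Python sketch of which warp writes a given slot. It assumes the simplest possible layout (element i handled by thread i, so warp i // 32); Triton's actual layout may differ, and writer_warp is a hypothetical helper used only for illustration.

```python
NUM_EXPERTS = 256
NUM_WARPS = 8
WARP_SIZE = 32  # threads per warp on NVIDIA GPUs


def writer_warp(slot: int) -> int:
    """Warp that stores `slot`, ASSUMING element i is handled by thread i."""
    return slot // WARP_SIZE


elements_per_warp = NUM_EXPERTS // NUM_WARPS

# The program with cur_expert == 59 reloads slot 59 after the vector store.
slot = 59
owner = writer_warp(slot)

# Every other warp in the block did not write this slot itself, so its
# load is only correct if the owner's store is already visible to it.
readers_at_risk = [w for w in range(NUM_WARPS) if w != owner]

print(f"elements per warp: {elements_per_warp}")
print(f"slot {slot} written by warp {owner}; warps at risk: {readers_at_risk}")
```

Under this assumed layout, only one of the eight warps reads back its own store; the other seven depend on cross-warp visibility.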

When cur_expert_start picks up a garbage value (e.g. 0x7FFFFFFF), the subsequent unmasked write tl.store(m_indices + cur_expert_start + ...) lands at an offset of roughly 8 GB (with 4-byte elements) in unmapped GPU memory, triggering cudaErrorIllegalAddress followed by SIGABRT.
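The size of that offset follows from simple arithmetic. A quick sketch, assuming m_indices holds int32 elements as in the repro scripts in this issue:

```python
ELEMENT_SIZE_BYTES = 4  # int32, assumed to match m_indices

garbage_start = 0x7FFFFFFF  # example stale value for cur_expert_start

# Pointer arithmetic scales the element index by the element size.
byte_offset = garbage_start * ELEMENT_SIZE_BYTES
offset_gib = byte_offset / 2**30

print(f"byte offset: {byte_offset} (~{offset_gib:.2f} GiB past m_indices)")
```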

Steps to reproduce:

1. Instrumented kernel — directly catch stale reads at runtime

The following script adds an instrumented kernel that reads cur_expert_start via both global memory (the buggy path) and register-local extraction (the correct path), then outputs both values for comparison. Any mismatch proves the global memory load read stale data:

import torch
import triton
import triton.language as tl


@triton.jit
def _instrumented_kernel(
    num_recv_tokens_per_expert,
    expert_start_loc,
    global_read_buf,
    register_read_buf,
    num_experts: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    """Read cur_expert_start via BOTH global memory and register, output both."""
    cur_expert = tl.program_id(0)
    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts, other=0,
    )
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)

    # Method 1: global memory load (the buggy pattern)
    from_global = tl.load(expert_start_loc + cur_expert)

    # Method 2: register-local extraction (the correct pattern)
    expert_mask = offset_cumsum == cur_expert
    from_register = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum)))

    tl.store(global_read_buf + cur_expert, from_global)
    tl.store(register_read_buf + cur_expert, from_register)


device = torch.device("cuda")
num_experts = 256

counts_gpu = torch.tensor([128] * num_experts, dtype=torch.int32, device=device)

total_mismatches = 0
mismatch_examples = []

for round_i in range(1000):
    expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=device)
    global_buf = torch.empty(num_experts, dtype=torch.int32, device=device)
    register_buf = torch.empty(num_experts, dtype=torch.int32, device=device)

    _instrumented_kernel[(num_experts,)](
        counts_gpu,
        expert_start_loc,
        global_buf,
        register_buf,
        num_experts=num_experts,
        num_warps=8,
        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
    )
    torch.cuda.synchronize()

    g = global_buf.cpu()
    r = register_buf.cpu()
    diff_mask = g != r
    n_diff = int(diff_mask.sum().item())
    if n_diff > 0:
        total_mismatches += n_diff
        if len(mismatch_examples) < 5:
            idx = diff_mask.nonzero(as_tuple=True)[0][0].item()
            mismatch_examples.append(
                f"iter={round_i} expert={idx} "
                f"global_read={g[idx].item()} register_read={r[idx].item()}"
            )

for ex in mismatch_examples:
    print(f"Stale read: {ex}")

print(f"\nResult: {total_mismatches} stale reads in 1000 iterations "
      f"({'BUG CONFIRMED' if total_mismatches > 0 else 'no race observed on this hardware'})")

Output on our setup (H20, Triton 3.4.0):

Stale read: iter=0 expert=59 global_read=0 register_read=7552
Stale read: iter=2 expert=52 global_read=0 register_read=6656

Result: 2 stale reads in 1000 iterations (BUG CONFIRMED)

The global memory path reads 0 (uninitialized torch.empty residual) while the register path reads the correct prefix sum. In production, when the stale value is large enough, the resulting out-of-bounds write crashes the process.
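As a sanity check on the register_read values, the expected exclusive prefix sum can be recomputed on the host. This is a plain-Python sketch, not part of the kernel; it uses the same input as the script (128 tokens for each of 256 experts).

```python
from itertools import accumulate

num_experts = 256
counts = [128] * num_experts  # same input as counts_gpu in the repro

# Mirror `tl.cumsum(x) - x`: inclusive scan minus the element itself.
inclusive = list(accumulate(counts))
exclusive = [inc - c for inc, c in zip(inclusive, counts)]

for expert in (59, 52):
    print(f"expert={expert} expected start={exclusive[expert]}")
```

Both values agree with the register_read column in the output above, so the global_read=0 results are the incorrect ones.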

2. PTX inspection — confirm no bar.sync between store and load

The following script compiles the kernel and inspects Triton's PTX output:

import os
import glob
import shutil
import torch
import triton
import triton.language as tl


@triton.jit
def _old_kernel(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    cur_expert = tl.program_id(0)
    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts,
        other=0,
    )
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
    cur_expert_start = tl.load(expert_start_loc + cur_expert)
    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)
    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        tl.store(m_indices_start_ptr + start_m + off_expert, cur_expert)


def inspect_ptx():
    device = torch.device("cuda")
    num_experts = 256
    BLOCK_E = 128

    cache_dir = os.environ.get(
        "TRITON_CACHE_DIR",
        os.path.join(os.path.expanduser("~"), ".triton", "cache"),
    )
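    # Caution: this wipes the entire Triton cache directory so that the
    # newest PTX file found below belongs to this kernel.
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)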

    counts = torch.full((num_experts,), BLOCK_E, dtype=torch.int32, device=device)
    start_loc = torch.empty(num_experts, dtype=torch.int32, device=device)
    m_indices = torch.empty(num_experts * BLOCK_E, dtype=torch.int32, device=device)

    _old_kernel[(num_experts,)](
        counts, start_loc, m_indices,
        num_experts=num_experts, num_warps=8,
        BLOCK_E=BLOCK_E, BLOCK_EXPERT_NUM=256,
    )
    torch.cuda.synchronize()

    ptx_files = sorted(
        glob.glob(os.path.join(cache_dir, "**", "*.ptx"), recursive=True),
        key=os.path.getmtime, reverse=True,
    )
    if not ptx_files:
        print("No PTX files found")
        return

    with open(ptx_files[0]) as f:
        code = f.read()

    bar_count = code.lower().count("bar.sync")
    st_global = code.lower().count("st.global")
    ld_global = code.lower().count("ld.global")
    print(f"bar.sync={bar_count}  st.global={st_global}  ld.global={ld_global}")

    lines = code.split("\n")
    for i, line in enumerate(lines):
        lo = line.lower()
        if any(kw in lo for kw in ["bar.sync", "st.global", "ld.global"]):
            print(f"  {i:4d}  {line.rstrip()}")


if __name__ == "__main__":
    inspect_ptx()

Output:

bar.sync=1  st.global=2  ld.global=3
    43  	@%p1 ld.global.b32 { %r8 }, [ %rd3 + 0 ];
    75  	bar.sync 	0;
   120  	@%p1 st.global.b32 [ %rd4 + 0 ], { %r11 };
   128  	ld.global.b32 { %r12 }, [ %rd5 + 0 ];
   135  	ld.global.b32 { %r13 }, [ %rd6 + 0 ];
   160  	@%p17 st.global.b32 [ %rd14 + 0 ], { %r57 };

The only bar.sync is at line 75 (for tl.cumsum's shared memory reduction). No barrier between the st.global at line 120 (writing expert_start_loc) and the ld.global at line 128 (reading it back).
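The same check can be done mechanically instead of by eye. A minimal sketch that takes the (line, opcode) pairs transcribed from the listing above and asks whether any bar.sync falls strictly between a given store and load; barrier_between is a hypothetical helper, not part of Triton.

```python
# (PTX line number, opcode) pairs transcribed from the output above.
events = [
    (43, "ld.global"),
    (75, "bar.sync"),
    (120, "st.global"),
    (128, "ld.global"),
    (135, "ld.global"),
    (160, "st.global"),
]


def barrier_between(events, store_line, load_line):
    """Return True if a bar.sync sits strictly between the two PTX lines."""
    return any(
        op == "bar.sync" and store_line < line < load_line
        for line, op in events
    )


# Store to expert_start_loc at line 120, read-back at line 128.
print("barrier between store and load:", barrier_between(events, 120, 128))
```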

Expected behavior:

cur_expert_start should always reflect the correct exclusive prefix sum computed in the same kernel invocation, regardless of warp scheduling.

Error logging:

RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
  File "deepep_scatter_gather.py", line 25, in _fwd_kernel_ep_scatter_1
    cur_expert_start = tl.load(expert_start_loc + cur_expert)

Full crash stack:

model forward
  → layer forward
    → MoE forward
      → fused_moe forward
        → ep_scatter
          → _fwd_kernel_ep_scatter_1 (Triton kernel)
            → CUDA illegal memory access → SIGABRT

Environment:

  • Using container

  • GPU info:

    • NVIDIA-SMI 580.105.08 Driver Version: 580.105.08 CUDA Version: 12.9
    • Graphics cards: NVIDIA H20
  • Python: CPython 3.10.9

  • PyTorch: 2.8.0+cu129

  • openai-triton: triton 3.4.0

  • LightLLM: current main (deepep_scatter_gather.py lines 23–26)

Proposed fix:

Replace the global memory round-trip with register-local extraction using tl.sum + tl.where:

# Before (buggy):
cur_expert_start = tl.load(expert_start_loc + cur_expert)
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

# After (fixed):
expert_mask = offset_cumsum == cur_expert
cur_expert_start = tl.sum(tl.where(expert_mask, cumsum, tl.zeros_like(cumsum)))
cur_expert_token_num = tl.sum(tl.where(expert_mask, tokens_per_expert, tl.zeros_like(tokens_per_expert)))

This keeps the tl.store to expert_start_loc (needed by _fwd_kernel_ep_scatter_2) but takes both values directly from registers instead of re-reading global memory. tl.sum performs its cross-warp reduction through shared memory with the proper bar.sync, so cross-warp visibility is guaranteed.
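For readers unfamiliar with the pattern, the tl.sum + tl.where combination is a one-hot select: every lane except the one matching cur_expert contributes zero, so the sum equals the single selected element. A plain-Python emulation of that logic (a sketch only; select_by_mask is a hypothetical name and no Triton is involved):

```python
def select_by_mask(values, index):
    """Emulate tl.sum(tl.where(offsets == index, values, 0))."""
    return sum(v if i == index else 0 for i, v in enumerate(values))


# Exclusive prefix sum for 256 experts with 128 tokens each.
starts = [128 * i for i in range(256)]

# The masked sum returns exactly the element at `index`.
for cur_expert in (0, 52, 59, 255):
    assert select_by_mask(starts, cur_expert) == starts[cur_expert]

print(select_by_mask(starts, 59))
```

The select costs one extra reduction per program, but it never touches expert_start_loc, so it cannot observe another warp's pending store.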

We can submit a PR with the fix and unit tests if that would be helpful.

Additional context:

We discovered this bug while stress-testing RTP-LLM, which uses LightLLM's _fwd_kernel_ep_scatter_1 implementation. The fix PR is here: alibaba/rtp-llm#1098

Language:

English
