Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,46 @@ if [[ -n "$SLURM_JOB_ID" ]]; then
echo "JOB $SLURM_JOB_ID running on $SLURMD_NODENAME"
fi

# Resolve the directory that contains the installed `vllm` package (the parent
# of the package directory) so the scorer patch can later be applied with -p1.
#
# The package is located with importlib's find_spec rather than `import vllm`:
# find_spec does not execute the package, so nothing vLLM prints at import time
# can end up in the captured path, and the lookup still works when importing
# vLLM itself would fail.
if ! VLLM_PACKAGE_ROOT="$(
python3 - <<'PY'
import sys
from importlib.util import find_spec
from pathlib import Path

spec = find_spec("vllm")
if spec is None or not spec.origin:
    # sys.exit(str) writes the message to stderr and exits with status 1.
    sys.exit("vllm is not installed in this Python environment")

# spec.origin is the package's __init__.py; two levels up is the install root.
print(Path(spec.origin).resolve().parent.parent)
PY
)"; then
    echo "Failed to locate the installed vLLM package" >&2
    exit 1
fi
# Guard against an empty or unexpected result before handing it to `patch -d`.
if [[ -z "$VLLM_PACKAGE_ROOT" || ! -d "$VLLM_PACKAGE_ROOT/vllm" ]]; then
    echo "Invalid installed vLLM package root: $VLLM_PACKAGE_ROOT" >&2
    exit 1
fi

# The scorer patch is expected in the same directory as this script (derived
# from $0); stop immediately when it is not there.
INDEX_SCORE_PATCH="$(dirname "$0")/minimaxm3_mi300x_index_score.patch"
[[ -f "$INDEX_SCORE_PATCH" ]] || {
    echo "MI300X sparse-index scorer patch is missing: $INDEX_SCORE_PATCH" >&2
    exit 1
}

# Options shared by every non-mutating probe of the patch state.
PATCH_CHECK_ARGS=(--batch --silent -d "$VLLM_PACKAGE_ROOT" -p1 --dry-run)

# A reverse dry run that succeeds means every hunk is already present.
if patch "${PATCH_CHECK_ARGS[@]}" --reverse --forward < "$INDEX_SCORE_PATCH"; then
    echo "MI300X sparse-index scorer patch is already fully applied"
else
    # Not applied yet: require a clean forward dry run before touching files.
    patch "${PATCH_CHECK_ARGS[@]}" --forward < "$INDEX_SCORE_PATCH" || {
        echo "Installed vLLM cannot cleanly apply the MI300X sparse-index scorer patch" >&2
        exit 1
    }
    # The dry run passed, so apply the patch for real.
    patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$INDEX_SCORE_PATCH" || {
        echo "Failed to apply the MI300X sparse-index scorer patch" >&2
        exit 1
    }
fi

# Final check on both paths: the patch must now be reversible, i.e. applied.
patch "${PATCH_CHECK_ARGS[@]}" --reverse --forward < "$INDEX_SCORE_PATCH" || {
    echo "MI300X sparse-index scorer patch verification failed" >&2
    exit 1
}

# A MODEL value that does not start with "/" is fetched with `hf download`;
# an absolute path skips the download (presumably an already-local checkpoint
# directory -- confirm against the launcher).
if [[ "$MODEL" != /* ]]; then hf download "$MODEL"; fi

if [ -n "$ROCR_VISIBLE_DEVICES" ]; then
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
diff --git a/vllm/models/minimax_m3/common/ops/index_topk.py b/vllm/models/minimax_m3/common/ops/index_topk.py
index c32ff38d9..0539a4113 100644
--- a/vllm/models/minimax_m3/common/ops/index_topk.py
+++ b/vllm/models/minimax_m3/common/ops/index_topk.py
@@ -14,6 +14,8 @@ disabled (score-only indexer), single shared index head. The selected block ids
feed the block-sparse attention kernels in ``sparse_attn``.
"""

+from functools import cache
+
import torch

from vllm.platforms import current_platform
@@ -24,6 +26,47 @@ from vllm.utils.math_utils import round_up
SPARSE_BLOCK_SIZE = 128


+@cache
+def _use_gfx942_high_batch_score_splits() -> bool:
+    """Return True only when running on ROCm and ``on_gfx942()`` holds.
+
+    The result is memoized with ``functools.cache``, so the platform probe
+    runs at most once per process.
+    """
+    if not current_platform.is_rocm():
+        return False
+
+    # Imported lazily: the ROCm platform module is only touched on ROCm.
+    from vllm.platforms.rocm import on_gfx942
+
+    return on_gfx942()
+
+
+def _decode_score_num_kv_chunks(batch: int, max_block: int) -> int:
+    """Pick the split-K chunk count used to launch the decode index scorer.
+
+    The scorer is dominated by streaming index-K reads. For ``batch >= 128``
+    on gfx942 the count is sized so that each workgroup keeps about eight
+    32-KiB KV blocks: enough independent HBM work without the launch overhead
+    and short loops of giving every block its own workgroup. In every other
+    case the count shrinks as the batch grows (``512 // batch``).
+
+    The result is clamped to ``[1, 256]``, never exceeds the number of KV
+    blocks (surplus workgroups would exit immediately), and is rounded down
+    to a power of two.
+    """
+    if batch >= 128 and _use_gfx942_high_batch_score_splits():
+        wanted = max_block // 8
+    else:
+        wanted = 512 // max(1, batch)
+    # Clamp to [1, 256] first, then cap at the (at least one) available block.
+    wanted = min(max(1, min(256, wanted)), max(1, max_block))
+    # Highest power of two that does not exceed ``wanted``.
+    return 1 << (wanted.bit_length() - 1)
+
+
+def _use_gfx942_decode_score_vector(
+    batch: int,
+    num_idx_heads: int,
+    head_dim: int,
+    decode_query_len: int,
+) -> bool:
+    """Decide whether decode scoring should take the gfx942 vector kernel.
+
+    True only for the exact shape that kernel is written for -- batch of at
+    least 128, one index head, head dimension 128, one query token per
+    request -- and only when the gfx942 platform check also passes.
+    """
+    shape_matches = (
+        batch >= 128
+        and num_idx_heads == 1
+        and head_dim == 128
+        and decode_query_len == 1
+    )
+    # Platform probe goes last so it is skipped whenever the shape is wrong.
+    return shape_matches and _use_gfx942_high_batch_score_splits()
+
+
# ---------------------------------------------------------------------------
# Bitonic top-k helpers (layout-agnostic).
# ---------------------------------------------------------------------------
@@ -383,6 +426,76 @@ def _decode_index_score_kernel(
)


+# gfx942 is substantially faster on M3's single-query decode shape when each
+# workgroup streams one KV block and reduces its 128 independent token scores
+# directly. The generic dot kernel remains preferable for multi-head/spec
+# decode, where its matrix tile reuses each K block across multiple queries.
+#
+# Launched on a (batch, num_kv_chunks) grid: axis 0 selects the request and
+# axis 1 the chunk of that request's KV blocks. For every block in its chunk
+# the program writes one score: the maximum query-key dot product over the
+# block's valid token positions, overridden by a fixed large value for forced
+# init/local blocks.
+# NOTE(review): this kernel takes no sm_scale, so the stored scores are raw
+# dot products -- confirm the downstream top-k only depends on their ordering.
+@triton.jit(do_not_specialize=["num_kv_chunks"])
+def _decode_index_score_vector_kernel(
+ q_ptr, # idx_q: [batch, 1, 128]
+ ik_cache_ptr, # index-K cache: [num_blocks, 128, 128]
+ score_ptr, # [1, batch, max_block]
+ block_table_ptr, # [batch, max_blocks]
+ seq_lens, # [batch]
+ init_blocks, # blocks with index < init_blocks get the forced score 1e30
+ local_blocks, # the last `local_blocks` blocks get the forced score 1e29
+ stride_q_n,
+ stride_q_d,
+ stride_ik_blk,
+ stride_ik_pos,
+ stride_ik_d,
+ stride_s_n,
+ stride_s_k,
+ stride_bt_b,
+ BLOCK_SIZE_K: tl.constexpr, # token positions per KV block
+ BLOCK_SIZE_D: tl.constexpr, # head-dim slice processed per inner step
+ num_kv_chunks, # size of grid axis 1; excluded from specialization above
+):
+ request_id = tl.program_id(0)
+ chunk_id = tl.program_id(1)
+ seq_len = tl.load(seq_lens + request_id)
+ # Ceil-divide: KV blocks covering this request, then blocks per chunk.
+ num_blocks = (seq_len + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K
+ chunk_size = (num_blocks + num_kv_chunks - 1) // num_kv_chunks
+ chunk_start = chunk_id * chunk_size
+ chunk_end = tl.minimum(chunk_start + chunk_size, num_blocks)
+ # Chunks that start past the request's last block have nothing to score.
+ if chunk_start >= chunk_end:
+ return
+
+ offsets_k = tl.arange(0, BLOCK_SIZE_K)
+ block_table_row = block_table_ptr + request_id * stride_bt_b
+ # Index of the first block inside the trailing `local_blocks` window.
+ local_start = tl.maximum(0, num_blocks - local_blocks)
+ for block_id in tl.range(chunk_start, chunk_end):
+ # Block-table entry for this logical block, widened to int64 before it is
+ # multiplied by the cache stride (presumably to avoid 32-bit overflow).
+ page = tl.load(block_table_row + block_id).to(tl.int64)
+ accumulator = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
+ # The head dimension is hard-coded to 128 and walked in BLOCK_SIZE_D
+ # slices; per-position dot products are accumulated in float32.
+ for d_start in tl.static_range(0, 128, BLOCK_SIZE_D):
+ offsets_d = d_start + tl.arange(0, BLOCK_SIZE_D)
+ query = tl.load(
+ q_ptr + request_id * stride_q_n + offsets_d * stride_q_d
+ ).to(tl.float32)
+ keys = tl.load(
+ ik_cache_ptr
+ + page * stride_ik_blk
+ + offsets_k[:, None] * stride_ik_pos
+ + offsets_d[None, :] * stride_ik_d
+ ).to(tl.float32)
+ accumulator += tl.sum(keys * query[None, :], axis=1)
+
+ # Positions at or beyond seq_len are set to -inf so they cannot win the
+ # block maximum.
+ positions = block_id * BLOCK_SIZE_K + offsets_k
+ accumulator = tl.where(
+ positions < seq_len,
+ accumulator,
+ float("-inf"),
+ )
+ score = tl.max(accumulator, axis=0)
+ is_init = block_id < init_blocks
+ # `block_id < num_blocks` already holds for every iteration of this loop,
+ # because chunk_end is capped at num_blocks.
+ is_local = (block_id >= local_start) & (block_id < num_blocks)
+ # Forced blocks replace the computed score; the local test is outermost,
+ # so a block that is both init and local is stored as 1e29.
+ score = tl.where(is_local, 1e29, tl.where(is_init, 1e30, score))
+ tl.store(
+ score_ptr + request_id * stride_s_n + block_id * stride_s_k,
+ score,
+ )
+
+
# ---------------------------------------------------------------------------
# Decode top-k (split-K): per-chunk partial top-k + merge. Forced init/local
# blocks are already encoded in the scores.
@@ -783,40 +896,64 @@ def minimax_m3_index_decode(
)
# split-K over seq blocks; chunk count depends only on shape constants so
# the grid is fixed within a cuda graph.
- TARGET_GRID = 4096
- MAX_NUM_KV_CHUNKS = 256
- target = max(
- 1, min(MAX_NUM_KV_CHUNKS, TARGET_GRID // max(1, batch * num_idx_heads))
- )
- num_kv_chunks = 1 << (target.bit_length() - 1)
+ num_kv_chunks = _decode_score_num_kv_chunks(batch, max_block)
grid_score = (batch, num_kv_chunks)
- _decode_index_score_kernel[grid_score](
- idx_q,
- index_kv_cache,
- score,
- block_table,
- seq_lens,
+ if _use_gfx942_decode_score_vector(
+ batch,
num_idx_heads,
head_dim,
- init_blocks,
- local_blocks,
- sm_scale,
decode_query_len,
- idx_q.stride(0),
- idx_q.stride(1),
- idx_q.stride(2),
- index_kv_cache.stride(0),
- index_kv_cache.stride(1),
- index_kv_cache.stride(2),
- score.stride(0),
- score.stride(1),
- score.stride(2),
- block_table.stride(0),
- BLOCK_SIZE_K=SPARSE_BLOCK_SIZE,
- num_kv_chunks=num_kv_chunks,
- USE_PDL=use_pdl,
- **pdl_launch,
- )
+ ):
+ _decode_index_score_vector_kernel[grid_score](
+ idx_q,
+ index_kv_cache,
+ score,
+ block_table,
+ seq_lens,
+ init_blocks,
+ local_blocks,
+ idx_q.stride(0),
+ idx_q.stride(2),
+ index_kv_cache.stride(0),
+ index_kv_cache.stride(1),
+ index_kv_cache.stride(2),
+ score.stride(1),
+ score.stride(2),
+ block_table.stride(0),
+ BLOCK_SIZE_K=SPARSE_BLOCK_SIZE,
+ BLOCK_SIZE_D=64,
+ num_kv_chunks=num_kv_chunks,
+ num_warps=8,
+ num_stages=1,
+ )
+ else:
+ _decode_index_score_kernel[grid_score](
+ idx_q,
+ index_kv_cache,
+ score,
+ block_table,
+ seq_lens,
+ num_idx_heads,
+ head_dim,
+ init_blocks,
+ local_blocks,
+ sm_scale,
+ decode_query_len,
+ idx_q.stride(0),
+ idx_q.stride(1),
+ idx_q.stride(2),
+ index_kv_cache.stride(0),
+ index_kv_cache.stride(1),
+ index_kv_cache.stride(2),
+ score.stride(0),
+ score.stride(1),
+ score.stride(2),
+ block_table.stride(0),
+ BLOCK_SIZE_K=SPARSE_BLOCK_SIZE,
+ num_kv_chunks=num_kv_chunks,
+ USE_PDL=use_pdl,
+ **pdl_launch,
+ )

topk_idx = torch.empty(
(num_idx_heads, total_q, topk),
6 changes: 6 additions & 0 deletions perf-changelog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3950,3 +3950,9 @@
- "Update ISL=8192 search-space: TP8-only from conc=4-64, DPA from conc=128-1024 (previously conc=1-64 and DPA conc=64-512)"
- "Update Applied TBO on high concurrencies"
pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1717

- config-keys:
- minimaxm3-fp8-mi300x-vllm
description:
- "Apply a gfx942-specific vectorized MiniMax M3 sparse index scorer at high batch, reducing isolated c256 scorer latency by 47-56% at 1k, 8k, and 32k context."
pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1840
Loading