From 8ebbee81ce1c62c5aede5894700a5e9e552155e6 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Fri, 19 Jun 2026 07:52:14 +0800 Subject: [PATCH 1/2] perf(mi300x): optimize MiniMax M3 sparse index scoring --- .../fixed_seq_len/minimaxm3_fp8_mi300x.sh | 40 +++ .../minimaxm3_mi300x_index_score.patch | 233 ++++++++++++++++++ 2 files changed, 273 insertions(+) create mode 100644 benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_index_score.patch diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh index f2cdaf284..402dc2c6f 100755 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh @@ -24,6 +24,46 @@ if [[ -n "$SLURM_JOB_ID" ]]; then echo "JOB $SLURM_JOB_ID running on $SLURMD_NODENAME" fi +if ! VLLM_PACKAGE_ROOT="$( + python3 - <<'PY' +from pathlib import Path + +import vllm + +print(Path(vllm.__file__).resolve().parent.parent) +PY +)"; then + echo "Failed to locate the installed vLLM package" >&2 + exit 1 +fi +if [[ -z "$VLLM_PACKAGE_ROOT" || ! -d "$VLLM_PACKAGE_ROOT/vllm" ]]; then + echo "Invalid installed vLLM package root: $VLLM_PACKAGE_ROOT" >&2 + exit 1 +fi + +INDEX_SCORE_PATCH="$(dirname "$0")/minimaxm3_mi300x_index_score.patch" +if [[ ! -f "$INDEX_SCORE_PATCH" ]]; then + echo "MI300X sparse-index scorer patch is missing: $INDEX_SCORE_PATCH" >&2 + exit 1 +fi + +PATCH_CHECK_ARGS=(--batch --silent -d "$VLLM_PACKAGE_ROOT" -p1 --dry-run) +if patch "${PATCH_CHECK_ARGS[@]}" --reverse --forward < "$INDEX_SCORE_PATCH"; then + echo "MI300X sparse-index scorer patch is already fully applied" +elif patch "${PATCH_CHECK_ARGS[@]}" --forward < "$INDEX_SCORE_PATCH"; then + if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$INDEX_SCORE_PATCH"; then + echo "Failed to apply the MI300X sparse-index scorer patch" >&2 + exit 1 + fi +else + echo "Installed vLLM cannot cleanly apply the MI300X sparse-index scorer patch" >&2 + exit 1 +fi +if ! patch "${PATCH_CHECK_ARGS[@]}" --reverse --forward < "$INDEX_SCORE_PATCH"; then + echo "MI300X sparse-index scorer patch verification failed" >&2 + exit 1 +fi + if [[ "$MODEL" != /* ]]; then hf download "$MODEL"; fi if [ -n "$ROCR_VISIBLE_DEVICES" ]; then diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_index_score.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_index_score.patch new file mode 100644 index 000000000..9d8967e87 --- /dev/null +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_index_score.patch @@ -0,0 +1,233 @@ +diff --git a/vllm/models/minimax_m3/common/ops/index_topk.py b/vllm/models/minimax_m3/common/ops/index_topk.py +index c32ff38d9..0539a4113 100644 +--- a/vllm/models/minimax_m3/common/ops/index_topk.py ++++ b/vllm/models/minimax_m3/common/ops/index_topk.py +@@ -14,6 +14,8 @@ disabled (score-only indexer), single shared index head. The selected block ids + feed the block-sparse attention kernels in ``sparse_attn``. + """ + ++from functools import cache ++ + import torch + + from vllm.platforms import current_platform +@@ -24,6 +26,47 @@ from vllm.utils.math_utils import round_up + SPARSE_BLOCK_SIZE = 128 + + ++@cache ++def _use_gfx942_high_batch_score_splits() -> bool: ++ if current_platform.is_rocm(): ++ from vllm.platforms.rocm import on_gfx942 ++ ++ return on_gfx942() ++ return False ++ ++ ++def _decode_score_num_kv_chunks(batch: int, max_block: int) -> int: ++ """Choose split-K parallelism for the decode index scorer. ++ ++ The scorer is dominated by streaming index-K reads. At high batch on ++ gfx942, keep about eight 32-KiB KV blocks in each workgroup. This exposes ++ enough independent HBM work without the launch overhead and short loops ++ caused by splitting every block into its own workgroup. Never launch more ++ chunks than blocks: those workgroups would immediately exit. ++ """ ++ if batch >= 128 and _use_gfx942_high_batch_score_splits(): ++ target = max(1, min(256, max_block // 8)) ++ else: ++ target = max(1, min(256, 512 // max(1, batch))) ++ target = min(target, max(1, max_block)) ++ return 1 << (target.bit_length() - 1) ++ ++ ++def _use_gfx942_decode_score_vector( ++ batch: int, ++ num_idx_heads: int, ++ head_dim: int, ++ decode_query_len: int, ++) -> bool: ++ return ( ++ batch >= 128 ++ and num_idx_heads == 1 ++ and head_dim == 128 ++ and decode_query_len == 1 ++ and _use_gfx942_high_batch_score_splits() ++ ) ++ ++ + # --------------------------------------------------------------------------- + # Bitonic top-k helpers (layout-agnostic). + # --------------------------------------------------------------------------- +@@ -383,6 +426,76 @@ def _decode_index_score_kernel( + ) + + ++# gfx942 is substantially faster on M3's single-query decode shape when each ++# workgroup streams one KV block and reduces its 128 independent token scores ++# directly. The generic dot kernel remains preferable for multi-head/spec ++# decode, where its matrix tile reuses each K block across multiple queries. ++@triton.jit(do_not_specialize=["num_kv_chunks"]) ++def _decode_index_score_vector_kernel( ++ q_ptr, # idx_q: [batch, 1, 128] ++ ik_cache_ptr, # index-K cache: [num_blocks, 128, 128] ++ score_ptr, # [1, batch, max_block] ++ block_table_ptr, # [batch, max_blocks] ++ seq_lens, # [batch] ++ init_blocks, ++ local_blocks, ++ stride_q_n, ++ stride_q_d, ++ stride_ik_blk, ++ stride_ik_pos, ++ stride_ik_d, ++ stride_s_n, ++ stride_s_k, ++ stride_bt_b, ++ BLOCK_SIZE_K: tl.constexpr, ++ BLOCK_SIZE_D: tl.constexpr, ++ num_kv_chunks, ++): ++ request_id = tl.program_id(0) ++ chunk_id = tl.program_id(1) ++ seq_len = tl.load(seq_lens + request_id) ++ num_blocks = (seq_len + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K ++ chunk_size = (num_blocks + num_kv_chunks - 1) // num_kv_chunks ++ chunk_start = chunk_id * chunk_size ++ chunk_end = tl.minimum(chunk_start + chunk_size, num_blocks) ++ if chunk_start >= chunk_end: ++ return ++ ++ offsets_k = tl.arange(0, BLOCK_SIZE_K) ++ block_table_row = block_table_ptr + request_id * stride_bt_b ++ local_start = tl.maximum(0, num_blocks - local_blocks) ++ for block_id in tl.range(chunk_start, chunk_end): ++ page = tl.load(block_table_row + block_id).to(tl.int64) ++ accumulator = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) ++ for d_start in tl.static_range(0, 128, BLOCK_SIZE_D): ++ offsets_d = d_start + tl.arange(0, BLOCK_SIZE_D) ++ query = tl.load( ++ q_ptr + request_id * stride_q_n + offsets_d * stride_q_d ++ ).to(tl.float32) ++ keys = tl.load( ++ ik_cache_ptr ++ + page * stride_ik_blk ++ + offsets_k[:, None] * stride_ik_pos ++ + offsets_d[None, :] * stride_ik_d ++ ).to(tl.float32) ++ accumulator += tl.sum(keys * query[None, :], axis=1) ++ ++ positions = block_id * BLOCK_SIZE_K + offsets_k ++ accumulator = tl.where( ++ positions < seq_len, ++ accumulator, ++ float("-inf"), ++ ) ++ score = tl.max(accumulator, axis=0) ++ is_init = block_id < init_blocks ++ is_local = (block_id >= local_start) & (block_id < num_blocks) ++ score = tl.where(is_local, 1e29, tl.where(is_init, 1e30, score)) ++ tl.store( ++ score_ptr + request_id * stride_s_n + block_id * stride_s_k, ++ score, ++ ) ++ ++ + # --------------------------------------------------------------------------- + # Decode top-k (split-K): per-chunk partial top-k + merge. Forced init/local + # blocks are already encoded in the scores. +@@ -783,40 +896,64 @@ def minimax_m3_index_decode( + ) + # split-K over seq blocks; chunk count depends only on shape constants so + # the grid is fixed within a cuda graph. +- TARGET_GRID = 4096 +- MAX_NUM_KV_CHUNKS = 256 +- target = max( +- 1, min(MAX_NUM_KV_CHUNKS, TARGET_GRID // max(1, batch * num_idx_heads)) +- ) +- num_kv_chunks = 1 << (target.bit_length() - 1) ++ num_kv_chunks = _decode_score_num_kv_chunks(batch, max_block) + grid_score = (batch, num_kv_chunks) +- _decode_index_score_kernel[grid_score]( +- idx_q, +- index_kv_cache, +- score, +- block_table, +- seq_lens, ++ if _use_gfx942_decode_score_vector( ++ batch, + num_idx_heads, + head_dim, +- init_blocks, +- local_blocks, +- sm_scale, + decode_query_len, +- idx_q.stride(0), +- idx_q.stride(1), +- idx_q.stride(2), +- index_kv_cache.stride(0), +- index_kv_cache.stride(1), +- index_kv_cache.stride(2), +- score.stride(0), +- score.stride(1), +- score.stride(2), +- block_table.stride(0), +- BLOCK_SIZE_K=SPARSE_BLOCK_SIZE, +- num_kv_chunks=num_kv_chunks, +- USE_PDL=use_pdl, +- **pdl_launch, +- ) ++ ): ++ _decode_index_score_vector_kernel[grid_score]( ++ idx_q, ++ index_kv_cache, ++ score, ++ block_table, ++ seq_lens, ++ init_blocks, ++ local_blocks, ++ idx_q.stride(0), ++ idx_q.stride(2), ++ index_kv_cache.stride(0), ++ index_kv_cache.stride(1), ++ index_kv_cache.stride(2), ++ score.stride(1), ++ score.stride(2), ++ block_table.stride(0), ++ BLOCK_SIZE_K=SPARSE_BLOCK_SIZE, ++ BLOCK_SIZE_D=64, ++ num_kv_chunks=num_kv_chunks, ++ num_warps=8, ++ num_stages=1, ++ ) ++ else: ++ _decode_index_score_kernel[grid_score]( ++ idx_q, ++ index_kv_cache, ++ score, ++ block_table, ++ seq_lens, ++ num_idx_heads, ++ head_dim, ++ init_blocks, ++ local_blocks, ++ sm_scale, ++ decode_query_len, ++ idx_q.stride(0), ++ idx_q.stride(1), ++ idx_q.stride(2), ++ index_kv_cache.stride(0), ++ index_kv_cache.stride(1), ++ index_kv_cache.stride(2), ++ score.stride(0), ++ score.stride(1), ++ score.stride(2), ++ block_table.stride(0), ++ BLOCK_SIZE_K=SPARSE_BLOCK_SIZE, ++ num_kv_chunks=num_kv_chunks, ++ USE_PDL=use_pdl, ++ **pdl_launch, ++ ) + + topk_idx = torch.empty( + (num_idx_heads, total_q, topk), From 5d29425348691f42ae19a9f4b629048cc23d3935 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Fri, 19 Jun 2026 07:53:18 +0800 Subject: [PATCH 2/2] chore: benchmark MI300X MiniMax M3 index scorer --- perf-changelog.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/perf-changelog.yaml b/perf-changelog.yaml index 06a81eaf1..f395f7d8e 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3950,3 +3950,9 @@ - "Update ISL=8192 search-space: TP8-only from conc=4-64, DPA from conc=128-1024 (previously conc=1-64 and DPA conc=64-512)" - "Update Applied TBO on high concurrencies" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1717 + +- config-keys: + - minimaxm3-fp8-mi300x-vllm + description: + - "Apply a gfx942-specific vectorized MiniMax M3 sparse index scorer at high batch, reducing isolated c256 scorer latency by 47-56% at 1k, 8k, and 32k context." + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1840