From a2b4cb0ccdf32a99497671e54e0c17f51f477342 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cliuruian=E2=80=9D?=
Date: Sat, 28 Mar 2026 23:01:15 +0800
Subject: [PATCH 1/5] cpmmot

---
 .../gpu_ops/append_attn/multiquery_attention_c16_impl.cuh | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
index d8c94dc5446..8f7837a69a5 100644
--- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
@@ -157,8 +157,8 @@ __global__ void multi_query_append_attention_kernel(
   uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
       wid * num_frags_x * 16 + tid % 16, tid / 16);  // 16 * 16
 
-  const uint32_t q_end =
-      min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
+  const uint32_t q_end = q_len;
+
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   cudaGridDependencySynchronize();
 #endif
@@ -569,8 +569,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
   uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
       tid % 16, tid / 16);  // 16 * 16
 
-  const uint32_t q_end =
-      min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
+  const uint32_t q_end = q_len;
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   cudaGridDependencySynchronize();

From d18d23a48bedccc30781032cee27240f1e44b79a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cliuruian=E2=80=9D?=
Date: Sun, 29 Mar 2026 09:04:47 +0800
Subject: [PATCH 2/5] cpmmot

---
 .../gpu_ops/append_attn/multiquery_attention_c16_impl.cuh | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
index 8f7837a69a5..623f54fe21d 100644
--- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
@@ -157,8 +157,6 @@ __global__ void multi_query_append_attention_kernel(
   uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
       wid * num_frags_x * 16 + tid % 16, tid / 16);  // 16 * 16
 
-  const uint32_t q_end = q_len;
-
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   cudaGridDependencySynchronize();
 #endif
@@ -166,7 +164,7 @@ __global__ void multi_query_append_attention_kernel(
       q_base_ptr,
       &qo_smem,
       q_base_seq_id_this_block,
-      q_end,
+      q_len,
       q_ori_n_stride,
       HEAD_DIM);
   commit_group();
@@ -569,8 +567,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
   uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
       tid % 16, tid / 16);  // 16 * 16
 
-  const uint32_t q_end = q_len;
-
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   cudaGridDependencySynchronize();
 #endif
@@ -582,7 +578,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
         T>(q_base_ptr,
            &qo_smem,
            q_base_seq_id_this_block,
-           q_end,
+           q_len,
            q_ori_n_stride,
            HEAD_DIM);
   commit_group();

From 3994595963dd8ecea307969bc7d92b17722a710a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cliuruian=E2=80=9D?=
Date: Sun, 29 Mar 2026 09:32:47 +0800
Subject: [PATCH 3/5] cpmmot

---
 .../gpu_ops/append_attn/multiquery_attention_c16_impl.cuh | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
index 623f54fe21d..5eaa091d1e7 100644
--- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
@@ -588,9 +588,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
   q_smem_inplace_multiply_sm_scale_multi_warps(
       &qo_smem, scale);
 
-  smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
-      v_smem(smem + (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 * HEAD_DIM *
-                 sizeof(T));
+  static_assert(num_rows_per_block == num_frags_x * 16);
+  static_assert(BLOCK_SIZE == NUM_WARP_KV * num_frags_z * 16);
+  smem_t k_smem(smem + num_rows_per_block * HEAD_DIM * sizeof(T)),
+      v_smem(smem + (num_rows_per_block + BLOCK_SIZE) * HEAD_DIM * sizeof(T));
 
   const uint32_t num_iterations = div_up(
       CAUSAL

From 430759339d1775befac2fa1f1c18ffd429013537 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cliuruian=E2=80=9D?=
Date: Sun, 29 Mar 2026 12:04:00 +0800
Subject: [PATCH 4/5] cpmmot

---
 .../gpu_ops/append_attn/multiquery_attention_c16_impl.cuh | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
index 5eaa091d1e7..b535b69a0d4 100644
--- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
@@ -601,13 +601,13 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
                      div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE),
                      chunk_start)))
           : chunk_len,
-      NUM_WARP_KV * num_frags_z * 16);
+      BLOCK_SIZE);
   const uint32_t mask_check_iteration =
       (CAUSAL ? (min(chunk_len,
                      sub_if_greater_or_zero(kv_len - q_len, chunk_start)))
               : mask_offset ? 0
                             : chunk_len) /
-      (NUM_WARP_KV * num_frags_z * 16);
+      (BLOCK_SIZE);
 
   uint32_t k_smem_offset_r = smem_t::get_permuted_offset(
       wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
@@ -694,7 +694,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
         s_frag, o_frag, m_frag, d_frag);
 
     __syncthreads();
-    kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
+    kv_idx_base += BLOCK_SIZE;
     block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]);
     if (block_id < 0) {
       block_id = 0;

From 48b49159837e69f8e8e6407eb2fdc46e52f8c979 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cliuruian=E2=80=9D?=
Date: Mon, 30 Mar 2026 12:23:09 +0800
Subject: [PATCH 5/5] cpmmot

---
 .../gpu_ops/append_attn/multiquery_attention_c16_impl.cuh | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
index b535b69a0d4..cf283a617d2 100644
--- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
@@ -484,11 +484,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
   const uint32_t num_rows_per_block = num_frags_x * 16;
   const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
 
-  // When cudagraph capture prefill, may launch more gridDim.x
-  if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
-    return;
-  }
-
   const uint32_t q_len = seq_lens[batch_id];
   if (q_len <= 0) {
     return;
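
Note: taken together, PATCH 3/5 and PATCH 4/5 are pure refactors. Given the two static_asserts introduced in PATCH 3/5, every rewritten shared-memory offset and iteration bound is arithmetically identical to the expression it replaces. The standalone host-side sketch below checks that equivalence at compile time; the concrete parameter values (num_frags_x, num_frags_z, NUM_WARP_KV, HEAD_DIM, and a 2-byte element size) are illustrative assumptions, not the kernel's actual instantiation.

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative stand-ins for the kernel's template parameters (assumed,
  // not taken from the real launch configuration).
  constexpr uint32_t num_frags_x = 4;
  constexpr uint32_t num_frags_z = 2;
  constexpr uint32_t NUM_WARP_KV = 4;
  constexpr uint32_t HEAD_DIM = 128;
  constexpr uint32_t elem_size = 2;  // sizeof(T) for half/bfloat16

  // The invariants PATCH 3/5 encodes as static_asserts.
  constexpr uint32_t num_rows_per_block = num_frags_x * 16;
  constexpr uint32_t BLOCK_SIZE = NUM_WARP_KV * num_frags_z * 16;

  // Pre-patch expressions for the k/v shared-memory base offsets.
  constexpr uint32_t k_old = num_frags_x * 16 * HEAD_DIM * elem_size;
  constexpr uint32_t v_old =
      (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 * HEAD_DIM * elem_size;

  // Post-patch expressions.
  constexpr uint32_t k_new = num_rows_per_block * HEAD_DIM * elem_size;
  constexpr uint32_t v_new =
      (num_rows_per_block + BLOCK_SIZE) * HEAD_DIM * elem_size;

  static_assert(k_old == k_new, "k_smem base offset is unchanged");
  static_assert(v_old == v_new, "v_smem base offset is unchanged");

  // The KV advance rewritten in PATCH 4/5 is the same quantity.
  static_assert(BLOCK_SIZE == NUM_WARP_KV * num_frags_z * 16,
                "kv_idx_base step is unchanged");

  printf("k_smem offset = %u bytes, v_smem offset = %u bytes\n", k_new, v_new);
  return 0;
}

By contrast, PATCH 1/5 and 2/5 do change behavior (the Q load bound is widened from the per-tile clamp to the full q_len), as does PATCH 5/5 (the early-exit guard for grid blocks launched beyond num_blocks_x_cpu during CUDA graph capture is removed).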