Skip to content
Open

commit #7060

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,21 @@ __global__ void multi_query_append_attention_kernel(
#endif
}

template<typename T>
inline __device__ static void ifnan_set0(T *dst, int num)
{
__syncthreads();
int start_idx = threadIdx.x + threadIdx.y * blockDim.x;
for (int i = start_idx; i < num; i += blockDim.x * blockDim.y) {
if (isnan((float)(dst[i]))) {
printf("erros: nan found at \n");
dst[i] = 0.0;
}
}
__syncthreads();
}


template <typename T,
bool partition_kv,
uint32_t GROUP_SIZE,
Expand Down Expand Up @@ -517,6 +532,16 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t chunk_len = chunk_end - chunk_start;

extern __shared__ uint8_t smem[];

constexpr uint32_t smem_size = (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM * sizeof(T);
T *haha = (T*)(smem);
for (int i = 0; i < smem_size / sizeof(T); ++i) {
haha[i] = 0;
}
__syncthreads();

const int smem_kv_num = NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM;

float s_frag[num_frags_x][num_frags_z][8];
float o_frag[num_frags_x][num_frags_y][8];
float m_frag[num_frags_x][2];
Expand Down Expand Up @@ -657,6 +682,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
wait_group<1>();
__syncthreads();

ifnan_set0((T*)(k_smem.base), smem_kv_num);

// s = qk
compute_qk<num_frags_x, num_frags_y, num_frags_z, T>(
Expand Down Expand Up @@ -713,6 +740,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
wait_group<1>();
__syncthreads();

ifnan_set0((T*)(v_smem.base), smem_kv_num);

// compute sfm*v
compute_sfm_v<num_frags_x, num_frags_y, num_frags_z, T>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag);
Expand Down
Loading