-
Notifications
You must be signed in to change notification settings - Fork 731
[XPU] Refactor get_padding_offset to single kernel. #7029
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,50 +1,123 @@ | ||||||||||||||
| #include "xpu/kernel/cluster.h" | ||||||||||||||
| #include "xpu/kernel/cluster_partition.h" | ||||||||||||||
| #include "xpu/kernel/cluster_primitive.h" | ||||||||||||||
| #include "xpu/kernel/cluster_simd.h" | ||||||||||||||
|
|
||||||||||||||
| namespace fd_xpu3 { | ||||||||||||||
|
|
||||||||||||||
| __global__ void get_padding_offset(int *batch_id_per_token, | ||||||||||||||
| int *cum_offsets_out, | ||||||||||||||
| int *cu_seqlens_q, | ||||||||||||||
| int *cu_seqlens_k, | ||||||||||||||
| const int *cum_offsets, | ||||||||||||||
| const int *seq_lens, | ||||||||||||||
| #define MAX_BATCH_SIZE 1024 | ||||||||||||||
|
|
||||||||||||||
| static inline __device__ int v_reduce_sum_int32(int32x16_t& v0) { | ||||||||||||||
| auto v1 = vsrlp_int32x16(1 << 8, v0); | ||||||||||||||
| v0 = vvadd_int32x16(v0, v1); | ||||||||||||||
| v1 = vsrlp_int32x16(1 << 7, v0); | ||||||||||||||
| v0 = vvadd_int32x16(v0, v1); | ||||||||||||||
| v1 = vsrlp_int32x16(1 << 6, v0); | ||||||||||||||
| v0 = vvadd_int32x16(v0, v1); | ||||||||||||||
| v1 = vsrlp_int32x16(1 << 5, v0); | ||||||||||||||
| v0 = vvadd_int32x16(v0, v1); | ||||||||||||||
| return vextract_int32x16(v0, 1); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| inline __device__ int primitive_reduce_sum_sm(__shared_ptr__ const int* x, | ||||||||||||||
| int64_t len) { | ||||||||||||||
| int32x16_t x_l, x_h; | ||||||||||||||
| int32x16_t sum = vset_zero_int(); | ||||||||||||||
| const auto rounddown_len = rounddown32(len); | ||||||||||||||
|
|
||||||||||||||
| for (int64_t i = 0; i < rounddown_len; i += 32) { | ||||||||||||||
| vload2_sm(x + i, x_l, x_h); | ||||||||||||||
| sum = vvadd_int32x16(sum, x_l); | ||||||||||||||
| sum = vvadd_int32x16(sum, x_h); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| if (rounddown_len < len) { | ||||||||||||||
| const auto mask = ~(-1 << (len - rounddown_len)); | ||||||||||||||
| vload2_sm_mz(x + rounddown_len, x_l, x_h, mask); | ||||||||||||||
| sum = vvadd_int32x16(sum, x_l); | ||||||||||||||
| sum = vvadd_int32x16(sum, x_h); | ||||||||||||||
| } | ||||||||||||||
| return v_reduce_sum_int32(sum); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| __global__ void get_padding_offset(int64_t* ids_remove_padding, | ||||||||||||||
| int* batch_id_per_token, | ||||||||||||||
| int* cum_offsets_out, | ||||||||||||||
| int* cu_seqlens_q, | ||||||||||||||
| int* cu_seqlens_k, | ||||||||||||||
| const int64_t* input_data, | ||||||||||||||
| const int* seq_lens, | ||||||||||||||
| const int max_seq_len, | ||||||||||||||
| const int bs) { | ||||||||||||||
| int cid = core_id(); | ||||||||||||||
| int ncores = core_num(); | ||||||||||||||
| int clusterid = cluster_id(); | ||||||||||||||
| int nclusters = cluster_num(); | ||||||||||||||
| int tid = clusterid * ncores + cid; | ||||||||||||||
|
|
||||||||||||||
| int buf_len = 32; | ||||||||||||||
| __simd__ int batch_id_per_token_lm[buf_len]; | ||||||||||||||
| __simd__ int cum_offsets_lm[16]; | ||||||||||||||
| int seq_len_lm; | ||||||||||||||
| for (int i = clusterid; i < bs; i += nclusters) { | ||||||||||||||
| GM2LM_ASYNC(seq_lens + i, &seq_len_lm, sizeof(int)); | ||||||||||||||
| GM2LM(cum_offsets + i - 1, cum_offsets_lm, 2 * sizeof(int)); | ||||||||||||||
| if (i == 0) { | ||||||||||||||
| cum_offsets_lm[0] = 0; | ||||||||||||||
| } | ||||||||||||||
| for (int j = cid * buf_len; j < seq_len_lm; j += ncores * buf_len) { | ||||||||||||||
| int cur_len = min(seq_len_lm - j, buf_len); | ||||||||||||||
| for (int k = 0; k < cur_len; k++) { | ||||||||||||||
| batch_id_per_token_lm[k] = i; | ||||||||||||||
| } | ||||||||||||||
| mfence_lm(); | ||||||||||||||
| LM2GM(batch_id_per_token_lm, | ||||||||||||||
| batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j, | ||||||||||||||
| cur_len * sizeof(int)); | ||||||||||||||
|
|
||||||||||||||
| __shared__ int sm_seq_lens[MAX_BATCH_SIZE]; | ||||||||||||||
| __shared__ int sm_cum_seq_len; | ||||||||||||||
| __simd__ __shared__ int buffer_cu_seqlens[64]; | ||||||||||||||
|
|
||||||||||||||
|
||||||||||||||
| // Ensure bs does not exceed the shared memory buffer capacity | |
| if (bs > MAX_BATCH_SIZE) { | |
| return; | |
| } |
Copilot
AI
Mar 26, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kernel 里 sm_seq_lens 使用固定大小 MAX_BATCH_SIZE=1024 的 shared 数组,但没有任何地方保证/检查 bs <= 1024。一旦 bs 超过该值,GM2SM(seq_lens, sm_seq_lens, sizeof(int)*bs) 会发生越界写入,导致结果错误甚至 kernel 崩溃。建议在 wrapper 侧增加显式断言/返回错误,或把 shared 缓冲改为能覆盖任意 bs 的实现(例如分块加载或动态策略)。
Copilot
AI
Mar 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
当前 kernel 在 clusterid 分片的 for 循环里多次调用 sync_all()。如果 sync_all 是跨 cluster 的全局 barrier(本仓库其他 kernel 通常只在所有 cluster 都会执行相同 barrier 次数的场景才用),则不同 cluster 循环迭代次数不一致会导致死锁/卡住。建议改用 sync_cluster()(若只需 cluster 内同步),或重构为所有 cluster 执行一致的 barrier 次数(例如用 tid/nthreads 的全局并行方式而不是 clusterid-stride 循环)。
Copilot
AI
Mar 26, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
当前实现对每个 bi 都通过遍历 seq_lens[0..bi] 来重新计算 cum_seq_len(并做一次 reduce),整体复杂度是 O(bs^2);相比旧实现利用 cum_offsets 做 O(1) 索引,会在大 batch 时明显放大开销。建议改成一次性在 shared/GM 上做 prefix-sum(scan)后直接读 cum_seq_len[bi],或至少避免在每个 bi 上从 0 重新累加。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cpu_token_num 直接 static_cast 会在超出 int 范围时截断,进而造成输出张量 shape 分配错误甚至越界访问。建议对 cpu_token_num 做范围校验(>=0、<=bsz*seq_length、<=INT_MAX),并尽量在后续逻辑里保持 int64_t 以避免隐式溢出。