
[XPU] Refactor get_padding_offset to single kernel.#7029

Open
Jiajun-Ji wants to merge 3 commits into PaddlePaddle:develop from Jiajun-Ji:get_padding_offset

Conversation

Contributor

@Jiajun-Ji Jiajun-Ji commented Mar 26, 2026

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)

Rewrite get_padding_offset kernel to align with GPU implementation

Modifications

Usage or Command

Accuracy Tests

  • Benchmark lengths showed no obvious anomalies (21B A3B, single card).
  • No obvious anomalies were found in the decoding results.

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings March 26, 2026 08:14

paddle-bot bot commented Mar 26, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR refactors the XPU get_padding_offset into a single-kernel implementation, aligning with the GPU-side approach and reducing the overhead of multiple kernel launches.

Changes:

  • Adjust the XPU Python-side get_padding_offset call signature: it now takes only input_ids/seq_len/cpu_token_num
  • Update the XPU custom-op/plugin interface: remove the dependencies on the cum_offsets/padding_offset inputs and outputs; the outputs become ids_remove_padding + batch_id_per_token + cum_offsets_out + cu_seqlens_*
  • Merge the padding_offset computation and remove_padding logic on Kunlun3 (XPU3) into a single XPU kernel.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.

File — Description
fastdeploy/model_executor/xpu_pre_and_post_process.py — Update how the XPU pre-processing flow calls get_padding_offset and the arguments it passes
custom_ops/xpu_ops/src/plugin/src/wrapper/get_padding_offset.cpp — Plugin wrapper changes: single kernel call, interface parameters reordered/trimmed
custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu — New Kunlun3 single-kernel implementation (computes cu_seqlens/cum_offsets, removes padding, builds the batch_id mapping)
custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h — Update the exported get_padding_offset interface signature in the plugin header
custom_ops/xpu_ops/src/ops/pybind/pybind.cc — Update the get_padding_offset Python signature exposed through pybind
custom_ops/xpu_ops/src/ops/get_padding_offset.cc — Update the Paddle-extension op implementation and OP registration: inputs removed, cpu_token_num attribute/parameter added

Comment on lines 113 to +155
@@ -148,10 +149,10 @@
 (
     ids_remove_padding,
     cum_offsets,
     batch_id_per_token,
     cu_seqlens_q,
     cu_seqlens_k,
-) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
+) = get_padding_offset(input_ids, seq_lens_this_time, token_num_cpu)

Copilot AI Mar 26, 2026


token_num_cpu = paddle.sum(...).cpu() here yields a CPU Tensor, not a Python scalar. The XPU get_padding_offset binding's third parameter is now int64_t cpu_token_num, so passing a Tensor raises a type error during pybind conversion and crashes at runtime. Consider extracting a scalar from the Tensor (e.g. int(paddle.sum(seq_lens_this_time).item())) and avoiding the duplicate sum computation.
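The scalar-extraction fix the reviewer describes can be sketched with numpy standing in for paddle (paddle's `Tensor.item()` behaves analogously; the `seq_lens_this_time` values below are illustrative):

```python
import numpy as np

# Illustrative stand-in for the device tensor seq_lens_this_time.
seq_lens_this_time = np.array([3, 5, 2], dtype=np.int32)

# Sum once and extract a plain Python int, instead of computing the sum
# twice and passing a Tensor where the binding expects int64_t.
token_num = seq_lens_this_time.sum()
token_num_cpu = int(token_num.item())

# token_num_cpu is now a plain int, safe to pass across pybind as int64_t.
assert isinstance(token_num_cpu, int)
```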

Comment on lines 112 to 116
 cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
 token_num = paddle.sum(seq_lens_this_time)
+token_num_cpu = paddle.sum(seq_lens_this_time).cpu()

 if use_speculate_method:

Copilot AI Mar 26, 2026


cum_offsets_now and token_num are no longer used in the non-speculate branch, yet the cumsum/sum is computed on every call, plus an extra paddle.sum(...).cpu(), introducing unnecessary XPU computation and (once the type bug is fixed) synchronization overhead. Consider moving cum_offsets_now/token_num into the use_speculate_method branch and reusing a single sum result.

Suggested change
-cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
-token_num = paddle.sum(seq_lens_this_time)
-token_num_cpu = paddle.sum(seq_lens_this_time).cpu()
-if use_speculate_method:
+token_num = paddle.sum(seq_lens_this_time)
+token_num_cpu = token_num.cpu()
+if use_speculate_method:
+    cum_offsets_now = paddle.cumsum(
+        max_len - seq_lens_this_time, dtype="int32"
+    )
Comment on lines +57 to +63
__shared__ int sm_seq_lens[MAX_BATCH_SIZE];
__shared__ int sm_cum_seq_len;
__simd__ __shared__ int buffer_cu_seqlens[64];

if (cid == 0) {
  GM2SM(seq_lens, sm_seq_lens, sizeof(int) * bs);
}

Copilot AI Mar 26, 2026


In the kernel, sm_seq_lens is a fixed-size shared array of MAX_BATCH_SIZE=1024, but nothing guarantees or checks bs <= 1024. Once bs exceeds that value, GM2SM(seq_lens, sm_seq_lens, sizeof(int)*bs) writes out of bounds, corrupting results or even crashing the kernel. Consider adding an explicit assertion/error return on the wrapper side, or switching the shared buffer to a scheme that covers arbitrary bs (e.g. chunked loading or a dynamic strategy).

Comment on lines +66 to +88
 for (int bi = clusterid; bi < bs; bi += nclusters) {
   int cum_seq_len = 0;
   for (int i = cid; i <= bi; i += ncores) {
     cum_seq_len += sm_seq_lens[i];
   }
   buffer_cu_seqlens[cid] = cum_seq_len;
   mfence();
   sync_all();

   if (cid == 0) {
-    int cum_seq_len = (i + 1) * max_seq_len - cum_offsets_lm[1];
-    mfence_lm();
-    LM2GM_ASYNC(cum_offsets_lm, cum_offsets_out + i, sizeof(int));
-    LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + i + 1, sizeof(int));
-    LM2GM(&cum_seq_len, cu_seqlens_k + i + 1, sizeof(int));
+    cum_seq_len =
+        primitive_reduce_sum_sm(buffer_cu_seqlens, min(bi + 1, ncores));
+
+    LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + bi + 1, sizeof(int));
+    LM2GM_ASYNC(&cum_seq_len, cu_seqlens_k + bi + 1, sizeof(int));
+
+    int cum_offset = bi * max_seq_len - (cum_seq_len - sm_seq_lens[bi]);
+    LM2GM(&cum_offset, cum_offsets_out + bi, sizeof(int));
+
+    sm_cum_seq_len = cum_seq_len;
   }
   mfence();
   sync_all();

Copilot AI Mar 26, 2026


The current implementation recomputes cum_seq_len for every bi by iterating over seq_lens[0..bi] (plus a reduce), so the overall complexity is O(bs^2); compared with the old implementation's O(1) indexing via cum_offsets, this noticeably amplifies overhead at large batch sizes. Consider doing a single prefix-sum (scan) over shared/GM memory and then reading cum_seq_len[bi] directly, or at least avoid re-accumulating from 0 for each bi.
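The single-scan alternative the reviewer suggests can be checked against a host-side reference (pure Python; this illustrates the arithmetic only, not the XPU kernel itself):

```python
def padding_offsets_ref(seq_lens, max_seq_len):
    """O(bs) reference: one prefix-sum pass yields both cu_seqlens and
    cum_offsets, instead of re-accumulating seq_lens[0..bi] per batch."""
    bs = len(seq_lens)
    cu_seqlens = [0] * (bs + 1)
    cum_offsets = [0] * bs
    for bi in range(bs):
        cu_seqlens[bi + 1] = cu_seqlens[bi] + seq_lens[bi]
        # Same quantity the kernel computes as
        # bi * max_seq_len - (inclusive_sum - seq_lens[bi]).
        cum_offsets[bi] = bi * max_seq_len - cu_seqlens[bi]
    return cu_seqlens, cum_offsets

# Example: three sequences padded to max_seq_len = 8.
cu, off = padding_offsets_ref([3, 5, 2], 8)
# cu  -> token offsets once padding is removed
# off -> padding skipped before each batch's first token
```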

Comment on lines 149 to +155
 (
     ids_remove_padding,
     cum_offsets,
     batch_id_per_token,
     cu_seqlens_q,
     cu_seqlens_k,
-) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
+) = get_padding_offset(input_ids, seq_lens_this_time, token_num_cpu)

Copilot AI Mar 26, 2026


The Modifications / Usage or Command / Accuracy Tests sections of the PR description are currently empty, yet this PR changes core XPU kernel behavior. Please add the specific changes, how to reproduce/benchmark, and at least one set of accuracy-alignment or regression results, to aid later review and rollback triage.


codecov-commenter commented Mar 26, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@6cff780). Learn more about missing BASE report.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7029   +/-   ##
==========================================
  Coverage           ?   73.63%           
==========================================
  Files              ?      399           
  Lines              ?    56412           
  Branches           ?     8919           
==========================================
  Hits               ?    41538           
  Misses             ?    11920           
  Partials           ?     2954           
Flag Coverage Δ
GPU 73.63% <ø> (?)

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.


Copilot AI review requested due to automatic review settings March 27, 2026 02:16
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

Comment on lines +64 to +74
sync_all();

for (int bi = clusterid; bi < bs; bi += nclusters) {
  int cum_seq_len = 0;
  for (int i = cid; i <= bi; i += ncores) {
    cum_seq_len += sm_seq_lens[i];
  }
  buffer_cu_seqlens[cid] = cum_seq_len;
  mfence();
  sync_all();


Copilot AI Mar 27, 2026


The kernel currently calls sync_all() multiple times inside the clusterid-sharded for loop. If sync_all is a cross-cluster global barrier (other kernels in this repo typically use it only where every cluster executes the same number of barriers), differing iteration counts across clusters will deadlock or hang. Consider using sync_cluster() (if only intra-cluster synchronization is needed), or restructuring so that all clusters execute a consistent number of barriers (e.g. global parallelism via tid/nthreads instead of a clusterid-stride loop).

__shared__ int sm_seq_lens[MAX_BATCH_SIZE];
__shared__ int sm_cum_seq_len;
__simd__ __shared__ int buffer_cu_seqlens[64];


Copilot AI Mar 27, 2026


sm_seq_lens uses a fixed-size MAX_BATCH_SIZE=1024 shared buffer without any upper-bound check on bs; when bs > 1024, GM2SM writes past the end of shared memory, causing undefined behavior. Consider adding an explicit check/assertion (bs <= MAX_BATCH_SIZE) at the start of the wrapper or kernel, or switching to chunked loading / dynamic allocation.

Suggested change
+// Ensure bs does not exceed the shared memory buffer capacity
+if (bs > MAX_BATCH_SIZE) {
+  return;
+}
Comment on lines 26 to 30
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
const int token_num_data = static_cast<int>(cpu_token_num);


Copilot AI Mar 27, 2026


Directly static_cast-ing cpu_token_num truncates values outside the int range, which then corrupts the output tensor shape allocation or even causes out-of-bounds access. Consider range-validating cpu_token_num (>= 0, <= bsz*seq_length, <= INT_MAX) and keeping int64_t in the subsequent logic to avoid implicit overflow.
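The range validation the reviewer asks for can be sketched in Python (the actual wrapper is C++; the function name and bounds here mirror the comment and are illustrative, not the real code):

```python
INT32_MAX = 2**31 - 1

def validated_token_num(cpu_token_num, bsz, seq_length):
    """Reject values that would truncate under static_cast<int> or imply
    an impossible token count before output shapes are allocated from it."""
    if not 0 <= cpu_token_num <= bsz * seq_length:
        raise ValueError(
            f"token_num {cpu_token_num} outside [0, {bsz * seq_length}]"
        )
    if cpu_token_num > INT32_MAX:
        raise OverflowError("token_num exceeds int32; keep it int64_t downstream")
    return cpu_token_num

# A count of 10 tokens fits inside a 2 x 8 padded layout.
ok = validated_token_num(10, bsz=2, seq_length=8)
```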

Collaborator

@EmmonsCurse EmmonsCurse left a comment


LGTM~ Skip coverage check as it mainly relies on XPU end-to-end tests.


4 participants