
Conversation

bukejiyu (Collaborator) commented Dec 25, 2025

Motivation

Originates from PR #5194.
To integrate third-party attention backends more cleanly, the inputs need to be reordered so that prefill tokens and decode tokens are separated. This PR adds that reordering capability, currently covering the basic scenario and the speculative-decoding scenario.

Modifications

P/D reordering is currently supported only on the CUDA backend.
1. Add an InputBatch structure to manage the share_inputs of gpu_model_runner and a ProposerInputBatch structure to manage the share_inputs of MTP, and add reorder_split_prefill_and_decode and condense functions to support reordering (see the sketch after this list).
2. Merge develop.
3. Add a req_id -> img_features mapping for each VL request to make reordering easier.
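
Conceptually, the reorder computes a permutation over the per-request slots so that prefill requests and decode requests each occupy a contiguous range, and condense compacts the batch once requests finish. Below is a minimal standalone sketch of that idea; the signatures, the decode-first ordering, and the NumPy types are assumptions for illustration, while the real reorder_split_prefill_and_decode and condense live on InputBatch and permute all share_inputs buffers together.

```python
import numpy as np

def reorder_split_prefill_and_decode(seq_lens_this_time: np.ndarray) -> np.ndarray:
    """Illustrative only: permutation grouping decode slots (seq_len == 1)
    apart from prefill slots (seq_len > 1)."""
    is_decode = seq_lens_this_time == 1
    # Assumed ordering for illustration: decode slots first, then prefill;
    # the order actually used by the PR may differ.
    return np.concatenate([np.where(is_decode)[0], np.where(~is_decode)[0]])

def condense(buffer: np.ndarray, num_running_requests: int) -> np.ndarray:
    """Illustrative only: drop trailing finished slots, keeping active requests."""
    return buffer[:num_running_requests]

seq_lens = np.array([1, 7, 1, 3])  # requests 0 and 2 are decoding
perm = reorder_split_prefill_and_decode(seq_lens)
print(perm)            # [0 2 1 3]
print(seq_lens[perm])  # [1 1 7 3] -> decode tokens first, prefill after
```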

Usage or Command

Add the class variable enable_ids_reorder to your AttentionBackend and set it to True to enable P/D reordering.
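
A minimal sketch of the opt-in follows; the backend class here is hypothetical, and only the enable_ids_reorder class variable comes from this PR.

```python
class MyAttentionBackend:  # hypothetical; in practice subclass FastDeploy's attention backend base class
    # Class variable introduced by this PR: when True, the model runner
    # reorders input ids so prefill and decode tokens form contiguous ranges.
    enable_ids_reorder = True
```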

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. If there are no unit tests, please state the reason in this PR.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

CLAassistant commented Dec 25, 2025

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
2 out of 3 committers have signed the CLA.

✅ EmmonsCurse
✅ bukejiyu
❌ root


root does not appear to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

paddle-bot (bot) commented Dec 25, 2025

Thanks for your contribution!

codecov-commenter commented Jan 5, 2026

Codecov Report

❌ Patch coverage is 74.24512% with 145 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@9fc2400).

Files with missing lines                               | Patch % | Lines
fastdeploy/worker/input_batch.py                       | 73.67%  | 105 Missing and 9 partials ⚠️
fastdeploy/worker/gpu_model_runner.py                  | 79.26%  | 14 Missing and 3 partials ⚠️
fastdeploy/spec_decode/mtp.py                          | 76.92%  | 6 Missing ⚠️
fastdeploy/model_executor/pre_and_post_process.py      | 57.14%  | 3 Missing ⚠️
...tdeploy/model_executor/xpu_pre_and_post_process.py | 0.00%   | 3 Missing ⚠️
...el_executor/layers/attention/flash_attn_backend.py | 50.00%  | 1 Missing ⚠️
...xecutor/layers/attention/moba_attention_backend.py | 50.00%  | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #5779   +/-   ##
==========================================
  Coverage           ?   66.48%           
==========================================
  Files              ?      348           
  Lines              ?    44749           
  Branches           ?     6867           
==========================================
  Hits               ?    29753           
  Misses             ?    12806           
  Partials           ?     2190           
Flag Coverage Δ
GPU 66.48% <74.24%> (?)

Flags with carried forward coverage won't be shown.


self.model_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.model_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = input_length
Collaborator:

model_inputs is already an object here; is it reasonable to keep the dict-style key-value access alongside it?
Previously, seq_lens_this_time_buffer was a member variable of MTPProposer; now it has been merged back into model_inputs. Are there any other side effects?

bukejiyu (Collaborator, Author) replied Jan 6, 2026:

There were too many call sites to change, so the key-based access interface is kept. gpu_model_runner uses InputBatch and MTP uses ProposerInputBatch; each is independent. A member variable on MTPProposer versus a member variable on ProposerInputBatch should differ only by one extra level of indirection. Also, seq_lens_this_time_buffer is strongly tied to the req_id, so it has to live inside InputBatch to participate in the reordering.
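
For illustration, a minimal sketch of how an object can keep the legacy key-based interface; seq_lens_this_time_buffer is a real buffer name from the PR, everything else here is assumed.

```python
import numpy as np

class InputBatch:
    """Sketch only: typed members plus dict-style access for old call sites."""

    def __init__(self, max_num_seqs: int):
        # Buffer tied to req_id ordering, so it must live inside the batch
        # object to be permuted together with the requests.
        self.seq_lens_this_time_buffer = np.zeros([max_num_seqs], dtype="int32")

    # Legacy key-based interface, kept so existing share_inputs["..."]
    # call sites keep working unchanged.
    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

batch = InputBatch(max_num_seqs=4)
batch["seq_lens_this_time_buffer"][0:1] = 8     # dict-style access
assert batch.seq_lens_this_time_buffer[0] == 8  # attribute access
```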

req_len = len(req_dicts)

self.model_inputs["num_running_requests"] = num_running_requests
self.model_inputs["running_requests_ids"] = range(num_running_requests)
Collaborator:

Same as above.

# self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
# self.model_inputs["seq_lens_this_time"] = self.model_inputs["seq_lens_this_time_buffer"][:num_running_requests]
self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]
Collaborator:

Same as above.

self.proposer = NgramProposer(self.fd_config)
elif self.speculative_method == "mtp":
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
Collaborator:

Same question as in MTPProposer.

CSWYF3634076 (Collaborator) commented:

The changes are fairly large; unit-test coverage needs to be filled in, especially for input_batch.py.

image_features_list.append(paddle.concat(merge_image_features, axis=0))
for _, index in req_idx_img_index_map.items():
    if index != -1:
        self.share_inputs["image_features_list"][idx] = image_features_list[index]
Collaborator:

Quick question: was the earlier shape mismatch because this only loops once here, or was it something else?

bukejiyu (Collaborator, Author) replied:

Right, you can think of it as looping only once. Also, video inputs were previously not bound to a req_id, which made reordering very difficult. Now a new list binds them to req_id, so every append into image_features_list is the image feature of one specific req_id.
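
A toy sketch of that binding follows; the names echo the quoted hunk above, but the req ids, shapes, and map layout are made up.

```python
import numpy as np

# Each append into image_features_list is the feature tensor of exactly one
# req_id; -1 marks a request that carries no image input.
req_idx_img_index_map = {"req-0": 0, "req-1": -1, "req-2": 1}
image_features_list = [np.ones([3, 8]), np.ones([5, 8])]

# After a reorder, each slot looks up its own request's features by req_id
# instead of depending on the original append order.
for req_id, index in req_idx_img_index_map.items():
    if index != -1:
        print(req_id, "->", image_features_list[index].shape)
```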

)
)

if self.encoder_cache is not None:
Collaborator:

Same as above.

img_index = img_index + 1
inputs = request.multimodal_inputs
if self.encoder_cache is not None:
if envs.FD_ENABLE_MAX_PREFILL:
Collaborator:

It feels like the encoder_cache path also needs feature_position_list_batches to record per-request position information?

bukejiyu (Collaborator, Author) replied:

I'm not too sure about this one. Since everything is appended, it should carry the position information implicitly, right?

bukejiyu (Collaborator, Author) commented Jan 7, 2026:

> The changes are fairly large; unit-test coverage needs to be filled in, especially for input_batch.py.

Yes, unit tests have been added. The remaining uncovered spot is the get_attention_meta function, which is not hit.
