
[BugFix] Fix kv cache int8 dynamic quant on flash and flash_mask backend#7028

Open
Wanglongzhi2001 wants to merge 5 commits into PaddlePaddle:develop from Wanglongzhi2001:flash_attn_dy_c8

Conversation

@Wanglongzhi2001
Collaborator

@Wanglongzhi2001 Wanglongzhi2001 commented Mar 26, 2026

Motivation

Fix kv cache int8 dynamic quant on flash and flash_mask backend

Modifications

Fix kv cache int8 dynamic quant on flash and flash_mask backend

Usage or Command

Add the following config to the model's config.json:

  "quantization_config": { 
      "quantization": "mix_quant",
      "kv_cache_quant_type": "block_wise_fp8",
      "dense_quant_type": "block_wise_fp8",
      "moe_quant_type": "block_wise_fp8"
  }

Then:

# Flash attn
# export FD_ATTENTION_BACKEND=FLASH_ATTN

# or flash mask attn
export FD_ATTENTION_BACKEND=FLASH_MASK_ATTN

Accuracy Tests

Checklist

  • Add at least one tag to the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings March 26, 2026 07:26
@paddle-bot

paddle-bot bot commented Mar 26, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR fixes the cache access and dequantization logic of the Flash / FlashMask attention backends under KV cache quantization (in particular the block_wise_fp8 dynamic-scale case), and hardens softmax numerical stability when an entire row is masked (max is -inf).

Changes:

  • In the Flash / FlashMask backends, forward_mixed now selects the source and indexing of the KV cache and scales based on cache_quant_type_str (2*layer_id vs 4*layer_id layout).
  • The cache read/backfill path (AppendCacheKV) in gqa_rope_write_cache.cu is extended to support block_wise_fp8, with dynamic scale reads.
  • flash_mask_attn/softmax.hpp guards the max == -INFINITY case so that inf - inf no longer produces NaN.
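The per-block dynamic scaling this PR targets can be illustrated with a small int8 analogue (a NumPy sketch under stated assumptions: the function names are hypothetical, and the real kernels quantize FP8/int8 KV cache blocks in CUDA):

```python
import numpy as np

def blockwise_quant(x, block=64):
    # Dynamic quantization: each block of `block` values gets its own
    # scale derived from that block's absolute maximum (no calibration).
    x = x.reshape(-1, block)
    scales = np.abs(x).max(axis=1, keepdims=True) / 127.0
    scales = np.where(scales == 0, 1.0, scales)  # avoid div-by-zero on all-zero blocks
    q = np.clip(np.round(x / scales), -127, 127).astype(np.int8)
    return q, scales

def blockwise_dequant(q, scales):
    # Inverse: int8 values times their per-block scale.
    return q.astype(np.float32) * scales

rng = np.random.default_rng(0)
x = rng.standard_normal(256).astype(np.float32)
q, s = blockwise_quant(x)
err = np.abs(blockwise_dequant(q, s).ravel() - x).max()
# Round-to-nearest bounds the error by half a quantization step per block.
```

Because the scales are computed from the data at write time rather than loaded as static layer parameters, they must live alongside K/V in the cache pool, which is why block_wise_fp8 needs the 4-entries-per-layer layout described above.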

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

File Description
fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py Under block_wise_fp8, read the cache and scales using the 4*layer_id layout and pass them through to the op
fastdeploy/model_executor/layers/attention/flash_attn_backend.py Same as above: fixes cache/scale selection and pass-through for the flash backend under block_wise_fp8
custom_ops/gpu_ops/flash_mask_attn/softmax.hpp Numerical-stability fix for softmax when max is -inf
custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu AppendCacheKV supports block_wise_fp8 dynamic scales; the kernel's scale arguments become optional

Comment on lines 1513 to 1514
cache_k_zp.get(),
cache_v_zp.get(),

Copilot AI Mar 26, 2026


When token_num < kv_token_num, cache_k_zp.get()/cache_v_zp.get() are called here unconditionally, but modes such as block_wise_fp8/cache_int8/cache_fp8 usually provide no zero points (the Python side passes None). This would directly trigger a host-side exception/crash in scenarios such as prefix caching. Recommend dereferencing the zp only when cache_quant_type == "cache_int4_zp" (or other modes that genuinely need a zp); other modes can pass a null pointer/default Tensor, or the zp arguments of AppendCacheKV can also be made optional with a mandatory check in the int4_zp branch.

Suggested change
cache_k_zp.get(),
cache_v_zp.get(),
cache_quant_type == "cache_int4_zp" ? cache_k_zp.get() : nullptr,
cache_quant_type == "cache_int4_zp" ? cache_v_zp.get() : nullptr,

Comment on lines +357 to +368
cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
if cache_quant_type_str == "block_wise_fp8":
cache_k = forward_meta.caches[4 * layer.layer_id]
cache_v = forward_meta.caches[4 * layer.layer_id + 1]
cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
else:
cache_k = forward_meta.caches[2 * layer.layer_id]
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
cache_k_scales = getattr(layer, "cache_k_scale", None)
cache_v_scales = getattr(layer, "cache_v_scale", None)


Copilot AI Mar 26, 2026


This adds the block_wise_fp8 cache-layout branch (4 * layer_id indexing of K/V/scales), but the current tests mainly cover flash_attn_func determinism at the op level and do not cover the cache indexing and scale-argument path of forward_mixed under block_wise_fp8. Recommend adding a minimal unit test (mock the 4*layer_id layout of ForwardMeta.caches and exercise the key forward_mixed branch once) to keep future changes from reintroducing indexing errors.
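The routing such a test should pin down can be sketched as follows (a minimal illustration; `cache_indices` is a hypothetical helper mirroring the branch in this diff, whereas a real test would mock `ForwardMeta.caches` and drive `forward_mixed` itself):

```python
def cache_indices(layer_id: int, cache_quant_type_str: str) -> list:
    # Mirrors the forward_mixed branch: block_wise_fp8 stores four entries
    # per layer (K, V, K-scales, V-scales); every other quant type stores
    # two (K, V) and takes static scales from the layer object.
    if cache_quant_type_str == "block_wise_fp8":
        return [4 * layer_id + i for i in range(4)]
    return [2 * layer_id, 2 * layer_id + 1]

def test_cache_index_routing():
    # Only block_wise_fp8 routes through the 4*layer_id layout ...
    assert cache_indices(3, "block_wise_fp8") == [12, 13, 14, 15]
    # ... every static quant type stays on the 2*layer_id layout.
    for qt in ("none", "cache_int8", "cache_fp8", "cache_int4_zp"):
        assert cache_indices(3, qt) == [6, 7]

test_cache_index_routing()
```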

Copilot generated this review using guidance from repository custom instructions.
Comment on lines +186 to +197
cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
if cache_quant_type_str == "block_wise_fp8":
cache_k = forward_meta.caches[4 * layer.layer_id]
cache_v = forward_meta.caches[4 * layer.layer_id + 1]
cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
else:
cache_k = forward_meta.caches[2 * layer.layer_id]
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
cache_k_scales = getattr(layer, "cache_k_scale", None)
cache_v_scales = getattr(layer, "cache_v_scale", None)


Copilot AI Mar 26, 2026


The flash_mask backend likewise adds block_wise_fp8 cache-layout/scale selection logic here, but no unit test covering flash_mask_attn_backend.forward_mixed was found in the repository. Recommend adding a minimal case covering block_wise_fp8 (at least verifying the 4*layer_id indexing of caches and that the scale tensors are passed through to gqa_rope_write_cache/append_attention) to guard against regressions.

Comment on lines 1241 to 1245
} else if (cache_quant_type == "cache_int8" ||
cache_quant_type == "cache_fp8") {
cache_quant_type == "cache_fp8" ||
cache_quant_type == "block_wise_fp8") {
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;


Copilot AI Mar 26, 2026


cache_k_quant_scales/cache_k_dequant_scales are made optional here, and nullptr is allowed at launch time, but the corresponding CUDA kernels unconditionally dereference one of them depending on cache_quant_type (e.g. cache_int8/cache_fp8 need dequant_scales, block_wise_fp8 needs quant_scales). Recommend explicitly validating the required optionals on entry to this branch (e.g. PADDLE_ENFORCE/PD_THROW) to avoid hard-to-diagnose crashes from nullptr dereference.

@codecov-commenter

codecov-commenter commented Mar 26, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@e6804ba). Learn more about missing BASE report.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7028   +/-   ##
==========================================
  Coverage           ?   73.64%           
==========================================
  Files              ?      399           
  Lines              ?    56432           
  Branches           ?     8921           
==========================================
  Hits               ?    41561           
  Misses             ?    11926           
  Partials           ?     2945           
Flag Coverage Δ
GPU 73.64% <100.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
scores_scale(mi) = (scores_max_prev(mi) == -INFINITY && scores_max_cur == -INFINITY)
? 1.f
: exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
Collaborator Author


When seq_len_k - seq_len_q in a batch is not aligned to kBlockN=128, all K scores of some Q rows are masked to -INF; online softmax then computes exp(-INF - (-INF)) = NaN, which corrupts the output into garbage.
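The failure mode and the guard can be reproduced in a few lines (a Python sketch; `rescale_factor` is a hypothetical stand-in for the `scores_scale` computation in `softmax.hpp`):

```python
import math

def rescale_factor(max_prev, max_cur, softmax_scale_log2=1.0):
    # Online softmax rescales the running row accumulator by
    # exp2((m_prev - m_cur) * scale). When an entire row is masked,
    # both maxima are -inf and the naive form evaluates exp2(nan);
    # the guard (mirroring the fix) returns 1.0 instead.
    if max_prev == -math.inf and max_cur == -math.inf:
        return 1.0
    return 2.0 ** ((max_prev - max_cur) * softmax_scale_log2)

# Without the guard: -inf minus -inf is NaN, which poisons the whole row.
assert math.isnan(-math.inf - -math.inf)
assert rescale_factor(-math.inf, -math.inf) == 1.0
```

The guarded value of 1.0 is safe because a fully-masked row contributes zero probability mass either way; only the NaN propagation had to be stopped.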

const T *cur_cache_v_scales;
T cache_k_scale = 0;
T cache_v_scale = 0;
if (dynamic_quant) {
Collaborator


Use if constexpr () here so the branch is resolved at compile time.

T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;
T cache_k_scale_0 = cache_k_scale;
T cache_k_scale_1 = cache_k_scale;
if (dynamic_quant) {
Collaborator


Same as above.

T cache_v_scale_2 = cache_v_scale;
T cache_v_scale_3 = cache_v_scale;

if (dynamic_quant) {

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

Comment on lines +161 to +162
t = paddle.zeros([1], dtype="float32")
t._sentinel_name = name

Copilot AI Mar 27, 2026


_make_sentinel() sets a custom attribute on a paddle.Tensor (t._sentinel_name = ...), which is likely to fail outright at import/test time because Tensor does not support dynamic attributes; the attribute is also never used in later assertions. Recommend removing the assignment, or using a separate Python-side mapping (e.g. dict[id(tensor)] -> name) to identify sentinel objects.

Suggested change
t = paddle.zeros([1], dtype="float32")
t._sentinel_name = name
# Note: we avoid attaching custom attributes to paddle.Tensor to keep it backend-safe.
t = paddle.zeros([1], dtype="float32")

Comment on lines +888 to +892
def test_only_block_wise_fp8_triggers_4x(self):
all_types = ["none", "cache_int8", "cache_fp8", "cache_int4_zp", "block_wise_fp8"]
for qt in all_types:
self.assertEqual(qt == "block_wise_fp8", qt == "block_wise_fp8")


Copilot AI Mar 27, 2026


The assertion in test_only_block_wise_fp8_triggers_4x is a tautology (both sides are identical), so the test can never fail and verifies nothing. Recommend changing it to actually check the cache indices chosen for each quant type (e.g. compare the 2x/4x routing results), or deleting the no-op case to avoid a misleading "covered but untested" signal.

Comment on lines +17 to +30
"""
Unit tests for the KV cache int8 dynamic quant fix on flash_attn_backend
and flash_mask_attn_backend (commit 584df2ba8).

The fix ensures that when cache_quant_type_str == "block_wise_fp8":
- cache_k/v are taken from caches[4*layer_id : 4*layer_id+2]
- cache_k/v_scales are taken from caches[4*layer_id+2 : 4*layer_id+4]
Otherwise (non-dynamic-quant):
- cache_k/v are taken from caches[2*layer_id : 2*layer_id+2]
- cache_k/v_scales are taken from layer.cache_k_scale / layer.cache_v_scale

Strategy: We mock the entire fastdeploy import chain and the external op
functions, then verify the correct cache tensors are routed through.
"""

Copilot AI Mar 27, 2026


The PR description currently only fills in Motivation; the Modifications / Usage or Command / Accuracy Tests sections are still empty, and the Checklist is not updated. Since this PR changes CUDA kernels and attention-backend behavior, recommend adding: a summary of the concrete changes, reproduction/verification commands, and at least one set of accuracy-alignment/regression results (or an explanation of why none can be provided).

Copilot AI review requested due to automatic review settings March 27, 2026 07:23
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.



# Mock problematic transitive dependencies that may be missing in some environments
_ensure_mock_module("aistudio_sdk.snapshot_download", {"snapshot_download": lambda *a, **kw: None})

Copilot AI Mar 27, 2026


aistudio_sdk.snapshot_download is mocked here, but the parent module aistudio_sdk is not created. When code executes from aistudio_sdk.snapshot_download import ..., the aistudio_sdk package is normally imported first; a missing parent module can raise ImportError and cause most tests in this file to be skipped. Recommend also calling _ensure_mock_module("aistudio_sdk") and attaching the submodule to the parent, or mocking the aistudio_sdk package directly.

Suggested change
_ensure_mock_module("aistudio_sdk.snapshot_download", {"snapshot_download": lambda *a, **kw: None})
# Ensure both the parent package and the snapshot_download submodule exist.
_aistudio_pkg = _ensure_mock_module("aistudio_sdk")
_aistudio_snapshot_module = _ensure_mock_module(
"aistudio_sdk.snapshot_download",
{"snapshot_download": lambda *args, **kwargs: None},
)
setattr(_aistudio_pkg, "snapshot_download", _aistudio_snapshot_module)

Comment on lines +858 to +863
def test_static_quant_null_quant_scales(self):
"""Static quant: quant_scales=None, dequant_scales provided."""
self.assertIsNone(None) # quant_scales
self.assertIsNotNone(np.ones(4)) # dequant_scales



Copilot AI Mar 27, 2026


This group of assertions is essentially self-proving (e.g. assertIsNone(None) and assertIsNotNone(np.ones(4))) and does not verify any code path or behavior under test; it adds test noise and can mask real regressions. Recommend deleting these assertions, or replacing them with assertions on observable behavior of the actual function/argument mapping.

Suggested change
def test_static_quant_null_quant_scales(self):
"""Static quant: quant_scales=None, dequant_scales provided."""
self.assertIsNone(None) # quant_scales
self.assertIsNotNone(np.ones(4)) # dequant_scales

Comment on lines +1275 to +1296
launchWithPdlWhenEnabled(
kernel_func,
grids,
blocks,
smem_size,
stream,
cache_k.data<uint8_t>(),
cache_v.data<uint8_t>(),
reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
cache_k_quant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
cache_k_quant_scales.get().data<T>()))
: nullptr,
cache_v_quant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
cache_v_quant_scales.get().data<T>()))
: nullptr,
cache_k_dequant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
cache_k_dequant_scales.get().data<T>()))
: nullptr,
cache_v_dequant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
cache_v_dequant_scales.get().data<T>()))
: nullptr,

Copilot AI Mar 27, 2026


AppendCacheKV makes the quant/dequant scales optional here and may pass nullptr at launch time. But in the dynamic_quant=false compile branch, append_cache_kv_c8 unconditionally reads cache_*_dequant_scales[kv_head_idx]; if the caller passes None, the CUDA kernel dereferences a null pointer (which is even harder to diagnose). Recommend explicitly validating the required optionals with PADDLE_ENFORCE/PD_CHECK on entry to this branch:

  • cache_int8/cache_fp8 must provide dequant_scales
  • block_wise_fp8 must provide quant_scales

and emit a clear error message when the check fails.

Comment on lines 1242 to +1244
} else if (cache_quant_type == "cache_int8" ||
cache_quant_type == "cache_fp8") {
cache_quant_type == "cache_fp8" ||
cache_quant_type == "block_wise_fp8") {

Copilot AI Mar 27, 2026


Support for block_wise_fp8 is added here, but if the later "cache_quant_type_str should be one of ..." error message is not updated to include block_wise_fp8, users passing this type will see a misleading list of options. Recommend updating the corresponding PD_THROW/PADDLE_THROW option hints so behavior and error messages stay consistent.

Copilot uses AI. Check for mistakes.
Comment on lines +17 to +20
"""
Unit tests for the KV cache int8 dynamic quant fix on flash_attn_backend
and flash_mask_attn_backend (commit 584df2ba8).


Copilot AI Mar 27, 2026


Two consecutive triple-quoted strings appear here: the license text at the top of the file already occupies the module docstring, so this later block becomes a no-op top-level string statement (evaluated once as a string constant), which can trip lint/format rules and confuse readers. Recommend turning the license into regular comments and keeping a single real module docstring.
