[BugFix] Fix kv cache int8 dynamic quant on flash and flash_mask backend #7028
Wanglongzhi2001 wants to merge 5 commits into PaddlePaddle:develop
Conversation
Thanks for your contribution!
Pull request overview
This PR fixes the KV cache access and dequantization logic of the Flash / FlashMask attention backends under KV cache quantization (in particular the block_wise_fp8 dynamic-scale case), and hardens softmax numerical stability when a row's max is -inf (fully masked rows).
Changes:
- In `forward_mixed`, the Flash / FlashMask backends now select the source and indexing of the KV cache and scales based on `cache_quant_type_str` (`2*layer_id` vs `4*layer_id` layouts).
- The cache-read backfill path (`AppendCacheKV`) in `gqa_rope_write_cache.cu` is extended to support `block_wise_fp8`, and dynamic scale reads are introduced.
- `flash_mask_attn/softmax.hpp` guards the `max == -INFINITY` case to avoid `inf - inf` / NaN anomalies.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py | Under block_wise_fp8, read caches and scales with the 4*layer_id layout and pass them through to the op |
| fastdeploy/model_executor/layers/attention/flash_attn_backend.py | Same fix for the flash backend: cache/scale selection and passthrough under block_wise_fp8 |
| custom_ops/gpu_ops/flash_mask_attn/softmax.hpp | Numerical-stability fix for softmax when max is -inf |
| custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu | AppendCacheKV supports block_wise_fp8 dynamic scales; kernel scale arguments become optional |
```cpp
cache_k_zp.get(),
cache_v_zp.get(),
```
When token_num < kv_token_num, this unconditionally calls cache_k_zp.get()/cache_v_zp.get(), but modes such as block_wise_fp8/cache_int8/cache_fp8 usually do not provide zp (the Python side passes None). This will directly trigger a host-side exception/crash in scenarios like prefix caching. Suggest dereferencing zp only when cache_quant_type == "cache_int4_zp" (or other modes that actually need zp); for the remaining modes pass a null pointer/default Tensor, or also make AppendCacheKV's zp arguments optional and enforce validation in the int4_zp branch.
Suggested change:

```diff
-cache_k_zp.get(),
-cache_v_zp.get(),
+cache_quant_type == "cache_int4_zp" ? cache_k_zp.get() : nullptr,
+cache_quant_type == "cache_int4_zp" ? cache_v_zp.get() : nullptr,
```
```python
cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
if cache_quant_type_str == "block_wise_fp8":
    cache_k = forward_meta.caches[4 * layer.layer_id]
    cache_v = forward_meta.caches[4 * layer.layer_id + 1]
    cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
    cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
else:
    cache_k = forward_meta.caches[2 * layer.layer_id]
    cache_v = forward_meta.caches[2 * layer.layer_id + 1]
    cache_k_scales = getattr(layer, "cache_k_scale", None)
    cache_v_scales = getattr(layer, "cache_v_scale", None)
```
This adds the block_wise_fp8 cache-layout branch (4 * layer_id indexing for K/V/scales), but the current tests mainly cover flash_attn_func at the determinism/op level and do not exercise forward_mixed's cache indexing and scale passing under block_wise_fp8. Suggest adding a minimal unit test (mock the 4*layer_id layout of ForwardMeta.caches and run the key forward_mixed branch once) so future changes cannot silently reintroduce indexing errors.
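A minimal, dependency-free sketch of such a test, assuming a helper `select_kv_caches` that mirrors the branch added in `forward_mixed` (the helper name and the use of `SimpleNamespace` stand-ins for the layer and `ForwardMeta.caches` are illustrative, not the repository's API):

```python
from types import SimpleNamespace


def select_kv_caches(layer, caches):
    # Mirrors the PR's selection logic: block_wise_fp8 uses a 4*layer_id layout
    # (K, V, k_scales, v_scales per layer); other modes use 2*layer_id and take
    # scales from the layer object.
    if getattr(layer, "cache_quant_type_str", "none") == "block_wise_fp8":
        base = 4 * layer.layer_id
        return caches[base], caches[base + 1], caches[base + 2], caches[base + 3]
    base = 2 * layer.layer_id
    return (
        caches[base],
        caches[base + 1],
        getattr(layer, "cache_k_scale", None),
        getattr(layer, "cache_v_scale", None),
    )


# Mocked caches list with the 4*layer_id layout for two layers.
caches = [f"tensor_{i}" for i in range(8)]

# block_wise_fp8: K/V and both dynamic scale tensors come from the caches list.
layer = SimpleNamespace(layer_id=1, cache_quant_type_str="block_wise_fp8")
assert select_kv_caches(layer, caches) == ("tensor_4", "tensor_5", "tensor_6", "tensor_7")

# Static quant: K/V come from 2*layer_id, scales from the layer itself.
layer = SimpleNamespace(layer_id=1, cache_quant_type_str="cache_int8",
                        cache_k_scale="k_scale", cache_v_scale="v_scale")
assert select_kv_caches(layer, caches) == ("tensor_2", "tensor_3", "k_scale", "v_scale")
```

In the real test the same assertions would run against the backend's `forward_mixed` with a mocked `ForwardMeta`, but the routing invariant being pinned down is exactly the one above.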
```python
cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
if cache_quant_type_str == "block_wise_fp8":
    cache_k = forward_meta.caches[4 * layer.layer_id]
    cache_v = forward_meta.caches[4 * layer.layer_id + 1]
    cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
    cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
else:
    cache_k = forward_meta.caches[2 * layer.layer_id]
    cache_v = forward_meta.caches[2 * layer.layer_id + 1]
    cache_k_scales = getattr(layer, "cache_k_scale", None)
    cache_v_scales = getattr(layer, "cache_v_scale", None)
```
The flash_mask backend adds the same block_wise_fp8 cache-layout/scale selection logic, but no unit test covering flash_mask_attn_backend.forward_mixed was found in the repository. Suggest adding a minimal block_wise_fp8 case (at least verifying the 4*layer_id cache indexing and that the scale tensors are passed through to gqa_rope_write_cache/append_attention) to guard against regressions.
```diff
 } else if (cache_quant_type == "cache_int8" ||
-           cache_quant_type == "cache_fp8") {
+           cache_quant_type == "cache_fp8" ||
+           cache_quant_type == "block_wise_fp8") {
   const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;
```
cache_k_quant_scales/cache_k_dequant_scales were changed to optional here and may be nullptr at launch, but the corresponding CUDA kernels unconditionally dereference one of them depending on cache_quant_type (e.g. cache_int8/cache_fp8 need dequant_scales, block_wise_fp8 needs quant_scales). Suggest explicitly validating the required optionals when entering this branch (e.g. PADDLE_ENFORCE/PD_THROW) to avoid hard-to-locate crashes from nullptr dereference.
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@ Coverage Diff @@
##           develop    #7028   +/-   ##
==========================================
  Coverage         ?   73.64%
==========================================
  Files            ?      399
  Lines            ?    56432
  Branches         ?     8921
==========================================
  Hits             ?    41561
  Misses           ?    11926
  Partials         ?     2945
==========================================
```
```diff
-      exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
+  scores_scale(mi) = (scores_max_prev(mi) == -INFINITY && scores_max_cur == -INFINITY)
+                         ? 1.f
+                         : exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
```
When seq_len_k - seq_len_q in a batch is not aligned to kBlockN=128, all K scores for some Q rows are masked to -INF; the online softmax then computes exp(-INF - (-INF)) = NaN, which corrupts the output into garbage.
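A small numeric sketch of the failure mode and the guard, using plain Python floats in place of the CUDA `exp2f` (`rescale_factor` is an illustrative stand-in for the kernel's `scores_scale` update):

```python
import math


def rescale_factor(max_prev, max_cur, softmax_scale_log2=1.0):
    # Guarded version, matching the shape of the patch: a fully masked row
    # (both maxima -inf) keeps a rescale factor of 1 instead of producing NaN.
    if max_prev == -math.inf and max_cur == -math.inf:
        return 1.0
    return 2.0 ** ((max_prev - max_cur) * softmax_scale_log2)


# The unguarded arithmetic yields NaN, which then poisons the whole output row:
assert math.isnan((-math.inf) - (-math.inf))

# The guarded version stays finite for masked rows and unchanged elsewhere:
assert rescale_factor(-math.inf, -math.inf) == 1.0
assert rescale_factor(0.0, 1.0) == 0.5  # exp2(-1) with scale 1
```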
```cpp
const T *cur_cache_v_scales;
T cache_k_scale = 0;
T cache_v_scale = 0;
if (dynamic_quant) {
```
Use `if constexpr ()` here so this is resolved at compile time.
```cpp
T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;
T cache_k_scale_0 = cache_k_scale;
T cache_k_scale_1 = cache_k_scale;
if (dynamic_quant) {
// ...
T cache_v_scale_2 = cache_v_scale;
T cache_v_scale_3 = cache_v_scale;
// ...
if (dynamic_quant) {
```
```python
t = paddle.zeros([1], dtype="float32")
t._sentinel_name = name
```
_make_sentinel() sets a custom attribute on a paddle.Tensor (t._sentinel_name = ...), which is likely to fail outright at import/test time because Tensor may not support dynamic attributes; the attribute is also never used in any later assertion. Suggest removing this assignment, or tracking sentinels with a separate Python-side mapping (e.g. dict[id(tensor)] -> name).
Suggested change:

```diff
-t = paddle.zeros([1], dtype="float32")
-t._sentinel_name = name
+# Note: we avoid attaching custom attributes to paddle.Tensor to keep it backend-safe.
+t = paddle.zeros([1], dtype="float32")
```
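The registry alternative mentioned above could look like the following sketch. `make_sentinel`/`sentinel_name` are hypothetical names, and a plain `object()` stands in for the paddle.Tensor so the pattern is demonstrable without paddle:

```python
# Python-side registry: id(obj) -> name. No attribute is ever set on the
# tensor itself, so this works regardless of the tensor backend.
# Caveat (worth noting in real code): ids can be reused after an object is
# garbage-collected, so sentinels must be kept alive while registered.
_SENTINEL_NAMES = {}


def make_sentinel(factory, name):
    t = factory()
    _SENTINEL_NAMES[id(t)] = name
    return t


def sentinel_name(t):
    return _SENTINEL_NAMES.get(id(t))


k = make_sentinel(lambda: object(), "cache_k")  # in tests: lambda: paddle.zeros([1])
assert sentinel_name(k) == "cache_k"
assert sentinel_name(object()) is None  # unregistered objects have no name
```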
```python
def test_only_block_wise_fp8_triggers_4x(self):
    all_types = ["none", "cache_int8", "cache_fp8", "cache_int4_zp", "block_wise_fp8"]
    for qt in all_types:
        self.assertEqual(qt == "block_wise_fp8", qt == "block_wise_fp8")
```
The assertion in test_only_block_wise_fp8_triggers_4x is a tautology (both sides are identical), so the test can never fail and verifies nothing. Suggest actually checking which cache index is chosen for each quant type (e.g. comparing the 2x vs 4x routing results), or deleting this no-op case so it does not create a misleading "covered but untested" signal.
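A non-tautological version might assert the actual per-layer cache stride, i.e. the index multiplier each quant type uses (`cache_stride` is an illustrative helper mirroring the backend's branch, not repository code):

```python
def cache_stride(quant_type):
    # block_wise_fp8 stores four tensors per layer in forward_meta.caches
    # (K, V, k_scales, v_scales); every other mode stores two (K, V).
    return 4 if quant_type == "block_wise_fp8" else 2


# Only block_wise_fp8 routes through the 4*layer_id layout.
for qt in ["none", "cache_int8", "cache_fp8", "cache_int4_zp"]:
    assert cache_stride(qt) == 2, qt
assert cache_stride("block_wise_fp8") == 4
```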
| """ | ||
| Unit tests for the KV cache int8 dynamic quant fix on flash_attn_backend | ||
| and flash_mask_attn_backend (commit 584df2ba8). | ||
|
|
||
| The fix ensures that when cache_quant_type_str == "block_wise_fp8": | ||
| - cache_k/v are taken from caches[4*layer_id : 4*layer_id+2] | ||
| - cache_k/v_scales are taken from caches[4*layer_id+2 : 4*layer_id+4] | ||
| Otherwise (non-dynamic-quant): | ||
| - cache_k/v are taken from caches[2*layer_id : 2*layer_id+2] | ||
| - cache_k/v_scales are taken from layer.cache_k_scale / layer.cache_v_scale | ||
|
|
||
| Strategy: We mock the entire fastdeploy import chain and the external op | ||
| functions, then verify the correct cache tensors are routed through. | ||
| """ |
The PR description currently only fills in Motivation; the Modifications / Usage or Command / Accuracy Tests sections are still empty and the Checklist has not been updated. Since this PR changes CUDA kernels and attention-backend behavior, suggest adding: a summary of the concrete changes, reproduction/verification commands, and at least one set of accuracy-alignment/regression results (or an explanation of why they cannot be provided).
```python
# Mock problematic transitive dependencies that may be missing in some environments
_ensure_mock_module("aistudio_sdk.snapshot_download", {"snapshot_download": lambda *a, **kw: None})
```
This mocks aistudio_sdk.snapshot_download but does not also create the parent module aistudio_sdk. When code executes from aistudio_sdk.snapshot_download import ..., the import system normally imports aistudio_sdk first, and the missing parent module can raise ImportError, causing most tests in this file to be skipped. Suggest also calling _ensure_mock_module("aistudio_sdk") and attaching the submodule to the parent, or mocking the aistudio_sdk package directly.
Suggested change:

```diff
-_ensure_mock_module("aistudio_sdk.snapshot_download", {"snapshot_download": lambda *a, **kw: None})
+# Ensure both the parent package and the snapshot_download submodule exist.
+_aistudio_pkg = _ensure_mock_module("aistudio_sdk")
+_aistudio_snapshot_module = _ensure_mock_module(
+    "aistudio_sdk.snapshot_download",
+    {"snapshot_download": lambda *args, **kwargs: None},
+)
+setattr(_aistudio_pkg, "snapshot_download", _aistudio_snapshot_module)
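A self-contained sketch of a helper that handles this in one call, registering every parent package and binding each child onto its parent (the helper name `ensure_mock_module` follows the test file's `_ensure_mock_module`, but this body is illustrative):

```python
import sys
import types


def ensure_mock_module(dotted_name, attrs=None):
    """Register a (possibly dotted) mock module path in sys.modules."""
    parts = dotted_name.split(".")
    parent = None
    for i in range(len(parts)):
        name = ".".join(parts[: i + 1])
        mod = sys.modules.get(name)
        if mod is None:
            mod = types.ModuleType(name)
            sys.modules[name] = mod
        if parent is not None:
            # 'from pkg.sub import x' also resolves sub as an attribute of pkg.
            setattr(parent, parts[i], mod)
        parent = mod
    for key, value in (attrs or {}).items():
        setattr(parent, key, value)
    return parent


ensure_mock_module("aistudio_sdk.snapshot_download",
                   {"snapshot_download": lambda *a, **kw: None})

# Both import styles now succeed because parent and child are in sys.modules.
from aistudio_sdk.snapshot_download import snapshot_download
assert snapshot_download() is None
```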
```python
def test_static_quant_null_quant_scales(self):
    """Static quant: quant_scales=None, dequant_scales provided."""
    self.assertIsNone(None)  # quant_scales
    self.assertIsNotNone(np.ones(4))  # dequant_scales
```
These assertions are self-evidently true (e.g. assertIsNone(None), assertIsNotNone(np.ones(4))) and do not verify any code path or behavior of the code under test; they add test noise and can mask real regressions. Suggest deleting them, or replacing them with assertions on the observable behavior of the actual function/argument mapping.
Suggested change:

```diff
-def test_static_quant_null_quant_scales(self):
-    """Static quant: quant_scales=None, dequant_scales provided."""
-    self.assertIsNone(None)  # quant_scales
-    self.assertIsNotNone(np.ones(4))  # dequant_scales
```
```cpp
launchWithPdlWhenEnabled(
    kernel_func,
    grids,
    blocks,
    smem_size,
    stream,
    cache_k.data<uint8_t>(),
    cache_v.data<uint8_t>(),
    reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
    reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
    cache_k_quant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                               cache_k_quant_scales.get().data<T>()))
                         : nullptr,
    cache_v_quant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                               cache_v_quant_scales.get().data<T>()))
                         : nullptr,
    cache_k_dequant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                                 cache_k_dequant_scales.get().data<T>()))
                           : nullptr,
    cache_v_dequant_scales ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                                 cache_v_dequant_scales.get().data<T>()))
                           : nullptr,
```
AppendCacheKV now makes the quant/dequant scales optional and may pass nullptr at launch. However, in the dynamic_quant=false compile branch, append_cache_kv_c8 unconditionally reads cache_*_dequant_scales[kv_head_idx]; if the caller passes None, the CUDA kernel dereferences a null pointer (which is much harder to diagnose). Suggest explicitly validating with PADDLE_ENFORCE/PD_CHECK when entering this branch:
- cache_int8/cache_fp8 must provide dequant_scales
- block_wise_fp8 must provide quant_scales
and emit a clear error message where needed.
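The requested host-side precondition can be sketched in Python terms (a hypothetical `check_scales` mirroring what a PADDLE_ENFORCE/PD_CHECK on the C++ side would enforce; the rule per quant type is taken from the review comment above):

```python
def check_scales(cache_quant_type, quant_scales, dequant_scales):
    # Fail fast on the host with a clear message instead of letting the
    # CUDA kernel dereference a null scales pointer.
    if cache_quant_type in ("cache_int8", "cache_fp8") and dequant_scales is None:
        raise ValueError(
            f"{cache_quant_type} requires cache_k/v_dequant_scales, got None"
        )
    if cache_quant_type == "block_wise_fp8" and quant_scales is None:
        raise ValueError("block_wise_fp8 requires cache_k/v_quant_scales, got None")


# Static quant with dequant scales present: passes silently.
check_scales("cache_int8", None, dequant_scales="dequant_tensor")

# Dynamic quant without quant scales: rejected before any kernel launch.
try:
    check_scales("block_wise_fp8", None, None)
except ValueError as e:
    msg = str(e)
assert "block_wise_fp8" in msg
```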
```diff
 } else if (cache_quant_type == "cache_int8" ||
-           cache_quant_type == "cache_fp8") {
+           cache_quant_type == "cache_fp8" ||
+           cache_quant_type == "block_wise_fp8") {
```
block_wise_fp8 support is added here, but if the later "cache_quant_type_str should be one of ..." error message in this file is not updated to include block_wise_fp8, users passing that type will see a misleading list of valid options. Suggest updating the corresponding PD_THROW/PADDLE_THROW option hints in sync so behavior and error messages stay consistent.
| """ | ||
| Unit tests for the KV cache int8 dynamic quant fix on flash_attn_backend | ||
| and flash_mask_attn_backend (commit 584df2ba8). | ||
|
|
Two consecutive triple-quoted strings appear here: the license text at the top of the file already occupies the module docstring slot, so this later block becomes an ineffective top-level string statement (it is only evaluated once as a string constant), which tends to trip lint/format rules and confuses readers. Suggest turning the license into regular comments and keeping only one real module docstring.
Motivation
Fix kv cache int8 dynamic quant on flash and flash_mask backend
Modifications
Fix kv cache int8 dynamic quant on flash and flash_mask backend
Usage or Command
Add the following config to model's config.json:
Then:
Accuracy Tests
Checklist
- Choose one or more appropriate PR tags from: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run `pre-commit` before commit.
- For changes targeting a `release` branch, make sure the PR has been submitted to the `develop` branch first, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.