[Feature] support fla triton kernel for qwen3.5 #7024
wanderHZ wants to merge 3 commits into PaddlePaddle:develop from
Conversation
Thanks for your contribution!
Pull request overview
This PR introduces a Triton-based FLA (Flash Linear Attention) inference kernel implementation for the GatedDeltaNet Attention in Qwen3.5, together with corresponding unit tests, providing high-performance GDN SSM + causal conv1d computation for both the Prefill and Decode paths.
Changes:
- Added the `fastdeploy/model_executor/ops/triton_ops/fla/` FLA Triton kernel package: core operators and index/utility functions covering both the chunked prefill (WY 6-step) path and the fused recurrent decode path.
- Added `causal_conv1d.py`: Triton causal conv1d implementations for Prefill (varlen) and Decode (single-token, pool-index).
- Added `tests/model_executor/ops/triton_ops/test_gdn_kernels.py`: baseline-alignment verification of the GDN recurrent/chunk and causal conv1d kernel outputs.
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tests/model_executor/ops/triton_ops/test_gdn_kernels.py | Adds correctness-alignment unit tests for the GDN/conv1d Triton kernels, plus a pure-Paddle baseline. |
| fastdeploy/model_executor/ops/triton_ops/fla/\_\_init\_\_.py | Exports the public API of the FLA kernel package. |
| fastdeploy/model_executor/ops/triton_ops/fla/utils.py | Provides input_guard (contiguous) and a simple tensor_cache, plus Triton environment capability probing. |
| fastdeploy/model_executor/ops/triton_ops/fla/op.py | Triton-side basic math/safety functions and gather capability adaptation. |
| fastdeploy/model_executor/ops/triton_ops/fla/index.py | varlen chunk index/offset generation utilities. |
| fastdeploy/model_executor/ops/triton_ops/fla/cumsum.py | chunk-local cumsum Triton kernel (scalar/vector). |
| fastdeploy/model_executor/ops/triton_ops/fla/l2norm.py | L2Norm Triton kernel (inference-time forward). |
| fastdeploy/model_executor/ops/triton_ops/fla/chunk_scaled_dot_kkt.py | Triton kernel computing beta·K·Kᵀ within a chunk. |
| fastdeploy/model_executor/ops/triton_ops/fla/solve_tril.py | Blocked Triton implementation of the lower-triangular inverse (I + A)^{-1}. |
| fastdeploy/model_executor/ops/triton_ops/fla/wy_fast.py | Triton kernel + wrapper recomputing W/U in the WY decomposition. |
| fastdeploy/model_executor/ops/triton_ops/fla/chunk_delta_h.py | Inter-chunk state propagation Triton kernel + wrapper. |
| fastdeploy/model_executor/ops/triton_ops/fla/chunk_o.py | Chunk output computation Triton kernel + wrapper. |
| fastdeploy/model_executor/ops/triton_ops/fla/fused_recurrent.py | Decode-path fused recurrent kernels (standard interface and pool-index interface). |
| fastdeploy/model_executor/ops/triton_ops/fla/chunk.py | Prefill-path 6-step chunked WY algorithm orchestration and public API. |
| fastdeploy/model_executor/ops/triton_ops/causal_conv1d.py | Triton Prefill/Decode interface implementations for causal conv1d. |
```python
    beta: Optional[paddle.Tensor] = None,
    scale: Optional[float] = None,
    ssm_pool: Optional[paddle.Tensor] = None,
    ssm_indices: Optional[paddle.Tensor] = None,
    cu_seqlens: Optional[paddle.Tensor] = None,
```
fused_recurrent_gated_delta_rule_update declares ssm_pool/ssm_indices as Optional with a default of None, but the implementation passes them straight into the Triton kernel (which performs tl.load(h0_indices) and writes back through the pointers). If a caller omits these two arguments, this triggers a runtime crash rather than a readable exception. Consider explicitly checking at this public API entry point that ssm_pool/ssm_indices are non-None and shape-matched, and raising a ValueError otherwise.
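A minimal sketch of the suggested entry-point guard (the helper name is hypothetical; only the ssm_pool/ssm_indices parameter names come from the API above):

```python
from typing import Any, Optional


def check_ssm_pool_args(ssm_pool: Optional[Any], ssm_indices: Optional[Any]) -> None:
    """Hypothetical guard: fail fast with a readable error at the public
    API boundary instead of crashing inside the Triton kernel."""
    if ssm_pool is None or ssm_indices is None:
        raise ValueError(
            "fused_recurrent_gated_delta_rule_update requires both ssm_pool "
            "and ssm_indices to be provided; got "
            f"ssm_pool={'None' if ssm_pool is None else 'tensor'}, "
            f"ssm_indices={'None' if ssm_indices is None else 'tensor'}"
        )
```

A shape check against the kernel's expected pool layout could be added in the same place once the pool layout is fixed.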
```
    Returns:
        x: [dim, total_tokens] (channel-last layout)
        weight: [dim, kernel_width]
        bias: [dim,]
        conv_pool: [max_seqs, dim, state_len]
```
The docstring of _make_varlen_inputs says "channel-last layout", but x's actual shape is [dim, total_tokens] (dim leading, i.e. channel-first). Consider unifying the wording so the layout is not misread later, leading to wrong shapes/strides when calling causal_conv1d_fn.
```python
    _initial_state = paddle.zeros([B, H, K, V], dtype=k.dtype)
    _initial_state_indices = paddle.arange(B, dtype=paddle.int32)
```
In varlen mode (cu_seqlens != None) with initial_state=None, the dummy initial_state / initial_state_indices are built from B: _initial_state = zeros([B, H, K, V]) and _initial_state_indices = arange(B). But chunk_gated_delta_rule_fwd_h launches its grid with N = cu_seqlens.shape[0] - 1 and performs tl.load on initial_state_indices[i_n]; when N > 1 and B == 1 this reads/writes out of bounds, producing wrong results or illegal memory access. Consider sizing the dummies by N (shape=[N, H, K, V] with indices=arange(N)), or raising explicitly in the varlen + initial_state=None case and requiring the caller to pass state/indices.
Suggested change:

```diff
-    _initial_state = paddle.zeros([B, H, K, V], dtype=k.dtype)
-    _initial_state_indices = paddle.arange(B, dtype=paddle.int32)
+    if cu_seqlens is not None:
+        # varlen mode: grid size is N = cu_seqlens.shape[0] - 1, so
+        # dummy initial_state/indices must be sized by N instead of B.
+        N = cu_seqlens.shape[0] - 1
+        _initial_state = paddle.zeros([N, H, K, V], dtype=k.dtype)
+        _initial_state_indices = paddle.arange(N, dtype=paddle.int32)
+    else:
+        _initial_state = paddle.zeros([B, H, K, V], dtype=k.dtype)
+        _initial_state_indices = paddle.arange(B, dtype=paddle.int32)
```
```python
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if beta is None:
        beta = paddle.ones(q.shape[:-1], dtype=q.dtype)  # [B, T, HV]
```
When beta is None, the default uses paddle.ones(q.shape[:-1]), whose shape is [B, T, H]. But the interface docs and the kernel actually expect beta to be [B, T, HV] (HV may exceed H, e.g. for GQA/GVA), and the kernel strides its pointers by HV. If HV != H this causes out-of-bounds reads or wrong results. Consider deriving the default beta from v's head dimension (i.e. [B, T, HV]) and, ideally, adding a shape assertion (beta.shape[2] == HV).
Suggested change:

```diff
-        beta = paddle.ones(q.shape[:-1], dtype=q.dtype)  # [B, T, HV]
+        # When beta is not provided, create an all-ones tensor with shape [B, T, HV].
+        # HV is derived from v to properly support HV != H (e.g. GQA/GVA scenarios).
+        beta = paddle.ones(v.shape[:3], dtype=v.dtype)
+    else:
+        # Validate that beta matches [B, T, HV] derived from v to avoid kernel shape mismatch.
+        if (
+            beta.shape[0] != v.shape[0]
+            or beta.shape[1] != v.shape[1]
+            or beta.shape[2] != v.shape[2]
+        ):
+            raise ValueError(
+                f"beta must have shape [B, T, HV] matching v, but got "
+                f"beta.shape={beta.shape}, v.shape[:3]={v.shape[:3]}"
+            )
```
```python
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if beta is None:
        beta = paddle.ones(q.shape[:-1], dtype=q.dtype)
```
When beta is None the default is paddle.ones(q.shape[:-1]) (shape [B, T, H]), but the update kernel expects beta to be [B, T, HV] (its pointer strides use HV). When HV != H (e.g. GVA/GQA) this causes out-of-bounds reads or wrong results. Consider generating the default beta from v's head dimension and asserting beta.shape[2] == HV.
Suggested change:

```diff
-        beta = paddle.ones(q.shape[:-1], dtype=q.dtype)
+        # Default beta should match v's head dimension: [B, T, HV]
+        beta = paddle.ones(v.shape[:-1], dtype=v.dtype)
+    else:
+        # Validate beta shape to prevent out-of-bounds access in the kernel
+        assert beta.ndim == 3, "beta must be 3D tensor of shape [B, T, HV]"
+        assert beta.shape == v.shape[:-1], (
+            f"beta shape {beta.shape} must match v.shape[:-1] {v.shape[:-1]} "
+            "for fused_recurrent_gated_delta_rule_update"
+        )
```
Motivation
To support the GatedDeltaNet Attention computation in the Qwen3.5 model, this PR adds Triton kernel implementations based on FLA (Flash Linear Attention).
Modifications
New files

FLA Triton kernel package (fastdeploy/model_executor/ops/triton_ops/fla/, 13 files):
- \_\_init\_\_.py
- utils.py (is_nvidia_hopper, is_gather_supported, the input_guard decorator, Triton kernel compile-cache management)
- op.py (exp, log, safe_exp, gather)
- index.py (prepare_lens, prepare_chunk_indices, prepare_chunk_offsets)
- cumsum.py
- l2norm.py
- chunk_scaled_dot_kkt.py
- solve_tril.py
- wy_fast.py
- chunk_delta_h.py
- chunk_o.py
- fused_recurrent.py
- chunk.py

Causal conv1d Triton kernel (fastdeploy/model_executor/ops/triton_ops/causal_conv1d.py, 1 file):
- causal_conv1d_update
- causal_conv1d_fn (with query_start_loc/has_initial_state parameters)

Unit tests (tests/model_executor/ops/triton_ops/test_gdn_kernels.py, 1 file): 11 test cases covering 4 classes:
- TestFusedRecurrentGDN
- TestChunkGDN
- TestCausalConv1dUpdate
- TestCausalConv1dFn

Usage or Command
1. GDN SSM Kernel — Prefill (chunk algorithm)
2. GDN SSM Kernel — Decode (fused recurrent, pool-index)
3. Causal Conv1d — Decode (single-token update)
4. Causal Conv1d — Prefill (varlen)
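For intuition, the Decode-path single-token causal conv1d update can be sketched as a naive NumPy baseline. This is illustrative only: it assumes the rolling state buffer has length equal to kernel_width and omits any activation or pool indexing the real Triton kernel may fuse.

```python
import numpy as np


def causal_conv1d_update_reference(x, conv_state, weight, bias=None):
    """Naive single-token depthwise causal conv1d update (sketch).

    x:          [dim]                 new token's features
    conv_state: [dim, kernel_width]   rolling buffer of recent inputs
    weight:     [dim, kernel_width]   depthwise conv weights
    """
    # Shift the rolling buffer left by one position and append the new token.
    conv_state[:, :-1] = conv_state[:, 1:]
    conv_state[:, -1] = x
    # Depthwise causal convolution over the buffered window.
    out = (conv_state * weight).sum(axis=-1)
    if bias is not None:
        out = out + bias
    return out, conv_state
```

The Prefill (varlen) path computes the same convolution over whole sequences at once; the single-token form above is what the fused decode kernel amortizes per step.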
Run the unit tests:

```shell
cd FastDeploy
python -m pytest tests/model_executor/ops/triton_ops/test_gdn_kernels.py -v
```

Accuracy Tests
The baseline reference implementation (pure Paddle, ported from the PyTorch reference implementation in HuggingFace Transformers) covers all three operation classes: GDN recurrent / chunk / conv1d.
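For reference, the per-token gated delta rule recurrence that such a baseline implements can be sketched in NumPy. Layouts ([T, K] / [T, V]) and names here are illustrative, not the actual Paddle baseline code; g is the log decay gate and beta the per-token write strength.

```python
import numpy as np


def gated_delta_rule_reference(q, k, v, g, beta, scale):
    """Naive token-by-token gated delta rule (baseline sketch).

    q, k: [T, K]; v: [T, V]; g, beta: [T]; state S: [K, V].
    """
    T, K = k.shape
    V = v.shape[1]
    S = np.zeros((K, V), dtype=np.float64)
    out = np.empty((T, V), dtype=np.float64)
    for t in range(T):
        S *= np.exp(g[t])                    # gated decay of the state
        delta = beta[t] * (v[t] - k[t] @ S)  # delta-rule correction term
        S += np.outer(k[t], delta)           # rank-1 state update
        out[t] = (q[t] * scale) @ S          # readout against the state
    return out
```

The chunked prefill path computes the same recurrence blockwise via the WY decomposition instead of one token at a time.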
Test accuracy (bf16 inputs):
Checklist
- PR title prefixed with [Feature]
- Ran pre-commit before commit.