[Feature] support fla triton kernel for qwen3.5 #7024

Open

wanderHZ wants to merge 3 commits into PaddlePaddle:develop from wanderHZ:add_triton_kernel

Conversation


@wanderHZ wanderHZ commented Mar 26, 2026

Motivation

To support the GatedDeltaNet attention computation in the Qwen3.5 model, this PR adds a Triton kernel implementation based on FLA (Flash Linear Attention).

Modifications

New files

FLA Triton kernel package (`fastdeploy/model_executor/ops/triton_ops/fla/`, 13 files):

| File | Description |
| --- | --- |
| `__init__.py` | Package entry; exports all public APIs |
| `utils.py` | Environment detection (`is_nvidia_hopper`, `is_gather_supported`), the `input_guard` decorator, and Triton kernel compilation cache management |
| `op.py` | Basic Triton op helpers (`exp`, `log`, `safe_exp`, `gather`) |
| `index.py` | Index utilities for varlen mode (`prepare_lens`, `prepare_chunk_indices`, `prepare_chunk_offsets`) |
| `cumsum.py` | Chunk-local prefix-sum Triton kernels (scalar 3D and vector 4D variants) |
| `l2norm.py` | L2 normalization Triton kernel |
| `chunk_scaled_dot_kkt.py` | Intra-chunk β·K·Kᵀ Triton kernel |
| `solve_tril.py` | Lower-triangular matrix inversion Triton kernel (16×16 blocks) |
| `wy_fast.py` | Triton kernel recomputing w/u in the WY representation |
| `chunk_delta_h.py` | Inter-chunk state propagation Triton kernel (64-block unrolling along K) |
| `chunk_o.py` | Chunk output computation Triton kernel |
| `fused_recurrent.py` | Fused recurrent kernel for the decode path (standard and pool-index interfaces), with `PAD_SLOT_ID=-1` sentinel support |
| `chunk.py` | 6-step chunked WY algorithm orchestration for the prefill path (composes the kernels above) |

Causal Conv1d Triton kernel (`fastdeploy/model_executor/ops/triton_ops/causal_conv1d.py`, 1 file):

| Function | Description |
| --- | --- |
| `causal_conv1d_update` | Single-token causal convolution update for decode; in-place operation based on a conv_state pool plus slot_ids |
| `causal_conv1d_fn` | Varlen causal convolution for prefill; supports the `query_start_loc` / `has_initial_state` parameters |

Unit tests (`tests/model_executor/ops/triton_ops/test_gdn_kernels.py`, 1 file):

11 test cases, covering 4 groups:

| Test class | Cases | Description |
| --- | --- | --- |
| TestFusedRecurrentGDN | 4 | Decode fused recurrent: no state, L2 norm, final-state output, with initial state |
| TestChunkGDN | 3 | Prefill chunk: no state, L2 norm, chunk-vs-recurrent consistency cross-check |
| TestCausalConv1dUpdate | 3 | Decode conv: no bias, with bias, state-pool in-place update verification |
| TestCausalConv1dFn | 1 | Prefill varlen conv: no initial state |

Usage or Command

1. GDN SSM Kernel — Prefill (chunk algorithm)

```python
from fastdeploy.model_executor.ops.triton_ops.fla import chunk_gated_delta_rule

# q, k: [B, T, H, K]     — query / key
# v:    [B, T, HV, V]    — value (HV >= H for GVA)
# g:    [B, T, H]        — log decay (negative values)
# beta: [B, T, H]        — write gate in [0, 1]
# initial_state: [N, H, K, V] or None — initial SSM state
o, final_state = chunk_gated_delta_rule(
    q, k, v, g, beta,
    scale=None,                    # defaults to 1/sqrt(K)
    initial_state=initial_state,
    output_final_state=True,
    use_qk_l2norm_in_kernel=True,
)
# o: [B, T, HV, V]
# final_state: [N, H, K, V]
```

2. GDN SSM Kernel — Decode (fused recurrent, pool-index)

```python
from fastdeploy.model_executor.ops.triton_ops.fla import fused_recurrent_gated_delta_rule_update

# ssm_pool: [max_seqs, HV, K, V] — SSM state pool (read and written in place)
# ssm_indices: [N] int32 — pool slot index for each request; PAD_SLOT_ID=-1 is safe
o = fused_recurrent_gated_delta_rule_update(
    q, k, v, g, beta,
    ssm_pool=ssm_pool,
    ssm_indices=ssm_indices,
)
# o: [B, T, HV, V]   (T=1 for decode)
# ssm_pool has been updated in place
```
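As a mental model for what one decode step of this recurrence computes, here is a schematic pure-NumPy sketch for a single head, assuming the standard gated-delta-rule formulation (the function name, the `scale` handling, and the omission of L2 norm are assumptions for illustration; the kernel's exact semantics live in `fused_recurrent.py`):

```python
import numpy as np

def gdn_decode_step(S, q, k, v, g, beta, scale):
    """One gated-delta-rule decode step for a single head.

    S: [K, V] SSM state; q, k: [K]; v: [V]
    g: scalar log-decay (<= 0), beta: scalar write gate in [0, 1]
    """
    S = np.exp(g) * S                      # apply the gated decay to the state
    v_old = S.T @ k                        # what the decayed state predicts for key k
    S = S + beta * np.outer(k, v - v_old)  # delta rule: overwrite the slot for k
    o = scale * (S.T @ q)                  # read out with the query
    return o, S

# Smoke run: from a zero state with beta=1 and g=0,
# one step simply writes outer(k, v) into the state.
K, V = 4, 3
rng = np.random.default_rng(0)
q, k, v = rng.standard_normal(K), rng.standard_normal(K), rng.standard_normal(V)
o, S = gdn_decode_step(np.zeros((K, V)), q, k, v, g=0.0, beta=1.0, scale=1.0)
assert np.allclose(S, np.outer(k, v))
```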

3. Causal Conv1d — Decode (single-token update)

```python
from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import causal_conv1d_update

# x: [batch, dim]
# conv_state: [max_seqs, dim, state_len] — conv state pool (updated in place)
# conv_state_indices: [batch] int32 — pool slot indices
out = causal_conv1d_update(
    x, conv_state, weight,
    bias=bias,
    activation="silu",
    conv_state_indices=slot_ids,
)
# out: [batch, dim]
```

4. Causal Conv1d — Prefill (varlen)

```python
from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import causal_conv1d_fn

# x: [dim, cu_seqlen] — all sequences concatenated (channel-first)
# conv_states: [max_seqs, dim, width-1] — conv state pool (updated in place)
# query_start_loc: [N+1] int32 — start offset of each sequence within x
# has_initial_state: [N] bool
out = causal_conv1d_fn(
    x, weight, bias, conv_states,
    query_start_loc, seq_lens_cpu,
    cache_indices=slot_ids,
    has_initial_state=has_init,
    activation="silu",
)
# out: [dim, cu_seqlen]
```
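As a shape-level sanity check for this varlen layout, a causal depthwise conv over the channel-first `[dim, total_tokens]` buffer can be written in a few lines of NumPy. This is a schematic reference only (it assumes zero initial state, no bias, and no activation; the Triton kernel's exact semantics live in `causal_conv1d.py`):

```python
import numpy as np

def causal_conv1d_ref(x, weight, query_start_loc):
    """Causal depthwise conv1d over concatenated varlen sequences.

    x:      [dim, total_tokens] channel-first concatenation of all sequences
    weight: [dim, width] per-channel filter taps (oldest tap first)
    query_start_loc: [N+1] start offsets of each sequence within x
    """
    dim, width = weight.shape
    out = np.zeros_like(x)
    for s, e in zip(query_start_loc[:-1], query_start_loc[1:]):
        seq = x[:, s:e]
        # left-pad with zeros so each output sees only current + past tokens
        padded = np.concatenate([np.zeros((dim, width - 1)), seq], axis=1)
        for t in range(e - s):
            out[:, s + t] = np.sum(padded[:, t:t + width] * weight, axis=1)
    return out

dim, width = 2, 3
x = np.arange(10, dtype=np.float64).reshape(dim, 5)
w = np.ones((dim, width))
out = causal_conv1d_ref(x, w, [0, 3, 5])  # two sequences of lengths 3 and 2
```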

Run the unit tests:

```bash
cd FastDeploy
python -m pytest tests/model_executor/ops/triton_ops/test_gdn_kernels.py -v
```

Accuracy Tests

The baseline reference implementation (pure Paddle, ported from the PyTorch reference implementation in HuggingFace Transformers) covers all three operation classes: GDN recurrent, chunk, and conv1d.

Test tolerances (bf16 inputs):

| Test | rtol | atol | Status |
| --- | --- | --- | --- |
| fused_recurrent (no state) | 1e-2 | 1e-2 | PASS |
| fused_recurrent (l2norm) | 1e-3 | 1e-3 | PASS |
| fused_recurrent (final state) | 1e-3 | 1e-3 | PASS |
| fused_recurrent (initial state) | 1e-2 | 1e-2 | PASS |
| chunk (no state) | 2e-2 | 2e-2 | PASS |
| chunk (l2norm) | 2e-2 | 2e-2 | PASS |
| chunk vs recurrent consistency | 2e-2 | 2e-2 | PASS |
| causal_conv1d_update (no bias) | 1e-2 | 1e-2 | PASS |
| causal_conv1d_update (with bias) | 1e-2 | 1e-2 | PASS |
| causal_conv1d_update (state inplace) | 1e-3 | 1e-3 | PASS |
| causal_conv1d_fn (no initial state) | 2e-2 | 5e-2 | PASS |
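Each row above is an elementwise tolerance comparison against the baseline. The check pattern is the standard `|out - ref| <= atol + rtol * |ref|` test; the helper below is a sketch of that pattern (the actual helper names in `test_gdn_kernels.py` may differ):

```python
import numpy as np

def assert_close(out, ref, rtol, atol):
    """Elementwise |out - ref| <= atol + rtol * |ref|, as in np.allclose."""
    out = np.asarray(out, dtype=np.float32)  # upcast bf16 outputs for comparison
    ref = np.asarray(ref, dtype=np.float32)
    if not np.allclose(out, ref, rtol=rtol, atol=atol):
        worst = np.max(np.abs(out - ref))
        raise AssertionError(f"max abs diff {worst:.3e} exceeds rtol={rtol}, atol={atol}")

ref = np.random.randn(8, 16).astype(np.float32)
out = ref + np.float32(5e-3)                 # perturbation within tolerance
assert_close(out, ref, rtol=1e-2, atol=1e-2)  # matches the fused_recurrent rows above
```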

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[Feature]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.

Copilot AI review requested due to automatic review settings March 26, 2026 03:52
@paddle-bot

paddle-bot bot commented Mar 26, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment

Pull request overview

This PR introduces a Triton-based FLA (Flash Linear Attention) inference kernel implementation for GatedDeltaNet attention in Qwen3.5, together with matching unit tests, aiming to provide high-performance GDN SSM + causal conv1d computation for both the prefill and decode paths.

Changes:

  • Adds the `fastdeploy/model_executor/ops/triton_ops/fla/` FLA Triton kernel package: core operators plus index/utility functions for both the chunked prefill (6-step WY) and fused recurrent decode paths.
  • Adds `causal_conv1d.py`: Triton causal conv1d implementations for prefill (varlen) and decode (single-token, pool-index).
  • Adds `tests/model_executor/ops/triton_ops/test_gdn_kernels.py`: verifies GDN recurrent/chunk and causal conv1d kernel outputs against a baseline.

Reviewed changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| `tests/model_executor/ops/triton_ops/test_gdn_kernels.py` | New correctness unit tests for the GDN/conv1d Triton kernels with a pure-Paddle baseline |
| `fastdeploy/model_executor/ops/triton_ops/fla/__init__.py` | Exports the FLA kernel package's public API |
| `fastdeploy/model_executor/ops/triton_ops/fla/utils.py` | Provides `input_guard` (contiguity) and a simple `tensor_cache`, plus Triton environment capability probing |
| `fastdeploy/model_executor/ops/triton_ops/fla/op.py` | Triton-side basic math/safety functions and gather capability adaptation |
| `fastdeploy/model_executor/ops/triton_ops/fla/index.py` | varlen chunk index/offset generation utilities |
| `fastdeploy/model_executor/ops/triton_ops/fla/cumsum.py` | Chunk-local cumsum Triton kernels (scalar/vector) |
| `fastdeploy/model_executor/ops/triton_ops/fla/l2norm.py` | L2Norm Triton kernel (inference forward) |
| `fastdeploy/model_executor/ops/triton_ops/fla/chunk_scaled_dot_kkt.py` | Intra-chunk beta·K·Kᵀ Triton kernel |
| `fastdeploy/model_executor/ops/triton_ops/fla/solve_tril.py` | Blocked Triton implementation of the lower-triangular inverse (I + A)⁻¹ |
| `fastdeploy/model_executor/ops/triton_ops/fla/wy_fast.py` | Triton kernel + wrapper recomputing W/U in the WY decomposition |
| `fastdeploy/model_executor/ops/triton_ops/fla/chunk_delta_h.py` | Inter-chunk state propagation Triton kernel + wrapper |
| `fastdeploy/model_executor/ops/triton_ops/fla/chunk_o.py` | Chunk output computation Triton kernel + wrapper |
| `fastdeploy/model_executor/ops/triton_ops/fla/fused_recurrent.py` | Decode-path fused recurrent kernel (standard and pool-index interfaces) |
| `fastdeploy/model_executor/ops/triton_ops/fla/chunk.py` | Prefill-path 6-step chunked WY algorithm orchestration and public API |
| `fastdeploy/model_executor/ops/triton_ops/causal_conv1d.py` | Triton prefill/decode interfaces for causal conv1d |

Comment on lines +509 to +513

```python
    beta: Optional[paddle.Tensor] = None,
    scale: Optional[float] = None,
    ssm_pool: Optional[paddle.Tensor] = None,
    ssm_indices: Optional[paddle.Tensor] = None,
    cu_seqlens: Optional[paddle.Tensor] = None,
```

Copilot AI Mar 26, 2026

`fused_recurrent_gated_delta_rule_update` declares `ssm_pool`/`ssm_indices` as Optional with a default of None, but the implementation passes them straight into the Triton kernel (which does `tl.load(h0_indices)` and pointer writebacks). If a caller omits these two arguments, the result is a runtime crash rather than a readable exception. Consider explicitly checking at this public API entry point that `ssm_pool`/`ssm_indices` are non-None and shape-compatible, raising a ValueError otherwise.
Comment on lines +731 to +735

```python
    Returns:
        x: [dim, total_tokens] (channel-last layout)
        weight: [dim, kernel_width]
        bias: [dim,]
        conv_pool: [max_seqs, dim, state_len]
```

Copilot AI Mar 26, 2026

The docstring of `_make_varlen_inputs` says "channel-last layout", but `x` actually has shape `[dim, total_tokens]` (dim first, i.e. channel-first). Consider making the wording consistent, so readers don't adopt the wrong layout and pass the wrong shape/stride to `causal_conv1d_fn`.
Comment on lines +104 to +105

```python
_initial_state = paddle.zeros([B, H, K, V], dtype=k.dtype)
_initial_state_indices = paddle.arange(B, dtype=paddle.int32)
```

Copilot AI Mar 26, 2026

In varlen mode (`cu_seqlens != None`) with `initial_state=None`, the dummy initial_state/indices are built from `B`: `_initial_state = zeros([B, H, K, V])`, `_initial_state_indices = arange(B)`. However, `chunk_gated_delta_rule_fwd_h` launches its grid with `N = cu_seqlens.shape[0] - 1` and does `tl.load` on `initial_state_indices[i_n]`; when `N > 1` and `B == 1`, this reads/writes out of bounds, producing wrong results or illegal memory accesses. Consider sizing the dummy by `N` (shape `[N, H, K, V]`, indices `arange(N)`), or raising explicitly in varlen mode when `initial_state=None` and requiring the caller to pass state/indices.

Suggested change:

```python
if cu_seqlens is not None:
    # varlen mode: grid size is N = cu_seqlens.shape[0] - 1, so
    # dummy initial_state/indices must be sized by N instead of B.
    N = cu_seqlens.shape[0] - 1
    _initial_state = paddle.zeros([N, H, K, V], dtype=k.dtype)
    _initial_state_indices = paddle.arange(N, dtype=paddle.int32)
else:
    _initial_state = paddle.zeros([B, H, K, V], dtype=k.dtype)
    _initial_state_indices = paddle.arange(B, dtype=paddle.int32)
```
```python
if scale is None:
    scale = k.shape[-1] ** -0.5
if beta is None:
    beta = paddle.ones(q.shape[:-1], dtype=q.dtype)  # [B, T, HV]
```

Copilot AI Mar 26, 2026

When `beta is None`, the default is `paddle.ones(q.shape[:-1])`, whose shape is `[B, T, H]`. But the interface docs and the kernel expect `beta` to be `[B, T, HV]` (`HV` may exceed `H` for GQA/GVA scenarios), and the kernel steps its pointers by `HV`. With `HV != H` this causes out-of-bounds reads or wrong results. Consider deriving the default `beta` from `v`'s head dimension (i.e. `[B, T, HV]`), ideally with a shape assertion (`beta.shape[2] == HV`).

Suggested change:

```python
    # When beta is not provided, create an all-ones tensor of shape [B, T, HV].
    # HV is derived from v to properly support HV != H (e.g. GQA/GVA scenarios).
    beta = paddle.ones(v.shape[:3], dtype=v.dtype)
else:
    # Validate that beta matches [B, T, HV] derived from v to avoid kernel shape mismatch
    if (
        beta.shape[0] != v.shape[0]
        or beta.shape[1] != v.shape[1]
        or beta.shape[2] != v.shape[2]
    ):
        raise ValueError(
            f"beta must have shape [B, T, HV] matching v, but got "
            f"beta.shape={beta.shape}, v.shape[:3]={v.shape[:3]}"
        )
```
```python
if scale is None:
    scale = k.shape[-1] ** -0.5
if beta is None:
    beta = paddle.ones(q.shape[:-1], dtype=q.dtype)
```

Copilot AI Mar 26, 2026

When `beta is None`, the default `paddle.ones(q.shape[:-1])` has shape `[B, T, H]`, but the update kernel expects `beta` as `[B, T, HV]` (its pointer strides use `HV`). With `HV != H` (GVA/GQA), this causes out-of-bounds reads or wrong results. Consider deriving the default `beta` from `v`'s head dimension and asserting `beta.shape[2] == HV`.

Suggested change:

```python
    # Default beta should match v's head dimension: [B, T, HV]
    beta = paddle.ones(v.shape[:-1], dtype=v.dtype)
else:
    # Validate beta shape to prevent out-of-bounds access in the kernel
    assert beta.ndim == 3, "beta must be a 3D tensor of shape [B, T, HV]"
    assert beta.shape == v.shape[:-1], (
        f"beta shape {beta.shape} must match v.shape[:-1] {v.shape[:-1]} "
        "for fused_recurrent_gated_delta_rule_update"
    )
```