diff --git a/fastdeploy/model_executor/ops/triton_ops/causal_conv1d.py b/fastdeploy/model_executor/ops/triton_ops/causal_conv1d.py new file mode 100644 index 00000000000..f9a3c13b93d --- /dev/null +++ b/fastdeploy/model_executor/ops/triton_ops/causal_conv1d.py @@ -0,0 +1,625 @@ +# Adapt from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +# Original: Copyright (c) 2024, Tri Dao (Apache License 2.0) +# Adapted for FastDeploy (PaddlePaddle) by PaddlePaddle Authors, 2025. +""" +Causal Conv1d Triton Kernels — FastDeploy edition (GDN Prefill/Decode path). + +Public API: + causal_conv1d_fn(x, weight, bias, conv_states, query_start_loc, + seq_lens_cpu, cache_indices, has_initial_state, activation) + x: (dim, cu_seqlen) - all sequences concatenated + weight: (dim, width) + bias: (dim,) or None + conv_states: [max_seqs, dim, width-1] (pool, in-place update) + query_start_loc: [N+1] int32 + seq_lens_cpu: List[int] + cache_indices: [N] int32 (slot index) + has_initial_state: [N] bool + activation: "silu" or None + → out: (dim, cu_seqlen) + + causal_conv1d_update(x, conv_state, weight, bias, activation, conv_state_indices) + x: (batch, dim) + conv_state: [max_seqs, dim, state_len] (pool, in-place update) + weight: (dim, width) + bias: (dim,) or None + activation: "silu" or None + conv_state_indices: [batch] int32 (slot index) + → out: (batch, dim) +""" + +from typing import List, Optional, Union + +import paddle +import triton +import triton.language as tl + +PAD_SLOT_ID = -1 + + +# ============================================================ +# Prefill kernel (unchanged from SGLang) +# ============================================================ + + +@triton.jit() +def _causal_conv1d_fwd_kernel( + x_ptr, # (dim, cu_seqlen) + w_ptr, # (dim, width) + bias_ptr, + initial_states_ptr, # conv_states_ptr: [max_seqs, dim, width-1] + cache_indices_ptr, # conv_state_indices_ptr: [N] + has_initial_states_ptr, # [N] bool + 
query_start_loc_ptr, # [N+1] + o_ptr, # (dim, cu_seqlen) + # dimensions + dim: tl.constexpr, + seqlen: tl.int32, + num_cache_lines: tl.constexpr, + # strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_istate_seq: tl.constexpr, + stride_istate_dim: tl.constexpr, + stride_istate_token: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # meta + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + HAS_INITIAL_STATES: tl.constexpr, + HAS_CACHE: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + NP2_STATELEN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + conv_states_ptr = initial_states_ptr + conv_state_indices_ptr = cache_indices_ptr + stride_conv_state_seq = stride_istate_seq + stride_conv_state_dim = stride_istate_dim + stride_conv_state_tok = stride_istate_token + state_len = KERNEL_WIDTH - 1 + + idx_seq = tl.program_id(0) + chunk_offset = tl.program_id(1) + idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N) + + if idx_seq == pad_slot_id: + return + + sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) + sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) + seqlen = sequence_end_index - sequence_start_index + + token_offset = BLOCK_M * chunk_offset + segment_len = min(BLOCK_M, seqlen - token_offset) + + if segment_len <= 0: + return + + x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim + + if IS_CONTINUOUS_BATCHING: + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + return + conv_states_base = ( + conv_states_ptr + (conv_state_batch_coord * 
stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) + + w_base = w_ptr + (idx_feats * stride_w_dim) + + if chunk_offset == 0: + load_init_state = False + if HAS_INITIAL_STATES: + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) + if load_init_state: + prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + else: + if KERNEL_WIDTH >= 2: + col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 3: + col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 4: + col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + + if state_len <= seqlen: + idx_tokens_last = (seqlen - state_len) + tl.arange(0, NP2_STATELEN) + x_ptrs = ( + x_ptr + + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] + + (idx_feats * stride_x_dim)[None, :] + ) + mask_x = (idx_tokens_last >= 0)[:, None] & (idx_tokens_last < seqlen)[:, None] & (idx_feats < dim)[None, :] + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) + conv_states_ptrs_target = conv_states_base[None, :] + (idx_tokens_conv * stride_conv_state_tok)[:, None] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.debug_barrier() + tl.store(conv_states_ptrs_target, new_conv_state, 
mask) + else: + if load_init_state: + idx_tokens_conv = tl.arange(0, NP2_STATELEN) + conv_states_ptrs_source = ( + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] + ) + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens_conv + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) + VAL = state_len - seqlen + x_ptrs = x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + new_conv_state = tl.where(mask, conv_state, loaded_x) + conv_states_ptrs_target = conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + else: + idx_tokens_conv = tl.arange(0, NP2_STATELEN) + VAL = state_len - seqlen + x_ptrs = x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + conv_states_ptrs_target = conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: # chunk_offset > 0 + load_init_state = True + prior_tokens = x_base + (token_offset - 1) * stride_x_token + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 3: + conv_states_ptrs 
= prior_tokens + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 2 * stride_x_token + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + x_base_1d = x_base + token_offset * stride_x_token + + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + mask_x_1d = idx_feats < dim + for idx_token in range(segment_len): + acc = acc_preload + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w 
= w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + acc += matrix_x * matrix_w + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < segment_len) & (idx_feats < dim) + o_ptrs = ( + o_ptr + (sequence_start_index + token_offset + idx_token) * stride_o_token + (idx_feats * stride_o_dim) + ) + tl.store(o_ptrs, acc, mask=mask_1d) + + +# ============================================================ +# Decode kernel (simplified from SGLang: seqlen=1, no spec decoding) +# ============================================================ + + +@triton.jit() +def _causal_conv1d_update_kernel( + x_ptr, # (batch, dim) — seqlen=1 decode token + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, # [max_seqs, dim, state_len] + conv_state_indices_ptr, # [batch] + o_ptr, # (batch, dim) + # dimensions + batch: int, + dim: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, + # strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # meta + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # seqlen == 1 for single-token decode + seqlen = 1 + + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + 
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices).to(tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + return + + conv_states_base = ( + conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) + mask_w = idx_feats < dim + + # STEP 1: READ old conv_state (sliding window history) + prior_tokens = conv_states_base # start at index 0 + if KERNEL_WIDTH >= 2: + col0 = tl.load(prior_tokens, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + col1 = tl.load(prior_tokens + 1 * stride_conv_state_tok, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + col2 = tl.load(prior_tokens + 2 * stride_conv_state_tok, mask_w, 0.0) + + # STEP 2: Shift-left conv_state and append new x (sliding window update) + idx_tokens = tl.arange(0, NP2_STATELEN) + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim) # [BLOCK_N] + + # Load old state shifted by seqlen=1 (elements [1..state_len-1]) + conv_state_ptrs_source = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None] + ) + mask_old = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + old_conv_state = tl.load(conv_state_ptrs_source, mask_old, other=0.0) + + # Load new x (only the last slot, VAL = state_len - 1) + VAL = state_len - seqlen + x_ptrs = ( + x_base[None, :] + ((idx_tokens - VAL) * stride_x_dim)[:, None] + ) # stride_x_dim used for token offset in dim-contiguous layout + mask_x = (idx_tokens - VAL >= 0)[:, None] & (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + + tl.debug_barrier() + new_conv_state = tl.where(mask_old, old_conv_state, loaded_x) + + # Write back new conv_state + conv_state_ptrs_target = 
conv_states_base + (idx_tokens * stride_conv_state_tok)[:, None] + mask_store = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask_store) + + # STEP 3: Load weights and compute convolution output + if HAS_BIAS: + acc = tl.load(bias_ptr + idx_feats, mask=mask_w, other=0.0).to(tl.float32) + else: + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + w_base = w_ptr + (idx_feats * stride_w_dim) + if KERNEL_WIDTH >= 2: + w_col0 = tl.load(w_base + 0 * stride_w_width, mask_w, other=0.0) + w_col1 = tl.load(w_base + 1 * stride_w_width, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_col2 = tl.load(w_base + 2 * stride_w_width, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_col3 = tl.load(w_base + 3 * stride_w_width, mask_w, other=0.0) + + x_now = tl.load(x_base, mask_w, 0.0) + if KERNEL_WIDTH == 2: + acc += col0 * w_col0 + x_now * w_col1 + elif KERNEL_WIDTH == 3: + acc += col0 * w_col0 + col1 * w_col1 + x_now * w_col2 + elif KERNEL_WIDTH == 4: + acc += col0 * w_col0 + col1 * w_col1 + col2 * w_col2 + x_now * w_col3 + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + + o_ptrs = o_ptr + idx_seq * stride_o_seq + idx_feats * stride_o_dim + tl.store(o_ptrs, acc, mask=mask_w) + + +# ============================================================ +# Python Wrappers (paddle edition) +# ============================================================ + + +def causal_conv1d_fn( + x: paddle.Tensor, + weight: paddle.Tensor, + bias: Optional[paddle.Tensor], + conv_states: paddle.Tensor, + query_start_loc: paddle.Tensor, + seq_lens_cpu: List[int], + cache_indices: Optional[paddle.Tensor] = None, + has_initial_state: Optional[paddle.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, +) -> paddle.Tensor: + """ + Causal conv1d forward (Prefill varlen path). 
+ + Args: + x: (dim, cu_seqlen) — all sequences concatenated + weight: (dim, width) — convolution kernel + bias: (dim,) or None + conv_states: [max_seqs, dim, width-1] — conv state pool (in-place update) + query_start_loc: [N+1] int32 — start position of each sequence in x + seq_lens_cpu: List[int] — length of each sequence (host side) + cache_indices: [N] int32 — pool slot index for each sequence + has_initial_state: [N] bool — whether initial state exists (read from pool) + activation: "silu" or None + pad_slot_id: padding slot sentinel (skipped during processing) + + Returns: + out: (dim, cu_seqlen) + """ + if isinstance(activation, bool) and activation: + activation = "silu" + + out = paddle.empty_like(x) + + dim, cu_seqlen = x.shape + _, width = weight.shape + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + stride_x_seq = 0 + stride_x_dim = x.strides[0] + stride_x_token = x.strides[1] + stride_w_dim = weight.strides[0] + stride_w_width = weight.strides[1] + + num_cache_lines = 0 + stride_istate_seq = stride_istate_dim = stride_istate_token = 0 + if conv_states is not None: + num_cache_lines = conv_states.shape[0] + stride_istate_seq = conv_states.strides[0] + stride_istate_dim = conv_states.strides[1] + stride_istate_token = conv_states.strides[2] + + stride_o_seq = 0 + stride_o_dim = out.strides[0] + stride_o_token = out.strides[1] + + def grid(META): + max_seq_len = max(seq_lens_cpu) + return ( + len(seq_lens_cpu), + (max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"], + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _causal_conv1d_fwd_kernel[grid]( + x, + weight, + bias, + conv_states, + cache_indices, + has_initial_state, + query_start_loc, + out, + dim, + cu_seqlen, + num_cache_lines, + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + pad_slot_id, + HAS_BIAS=bias is not None, + 
KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + HAS_INITIAL_STATES=has_initial_state is not None, + HAS_CACHE=conv_states is not None, + IS_CONTINUOUS_BATCHING=cache_indices is not None, + USE_PAD_SLOT=pad_slot_id is not None, + NP2_STATELEN=np2_statelen, + BLOCK_M=8, + BLOCK_N=256, + num_stages=2, + ) + return out + + +def causal_conv1d_update( + x: paddle.Tensor, + conv_state: paddle.Tensor, + weight: paddle.Tensor, + bias: Optional[paddle.Tensor] = None, + activation: Union[bool, str, None] = None, + conv_state_indices: Optional[paddle.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID, +) -> paddle.Tensor: + """ + Causal conv1d single-token update (Decode path). + + Args: + x: (batch, dim) — current token + conv_state: [max_seqs, dim, state_len] — conv state pool (in-place update) + weight: (dim, width) — convolution kernel + bias: (dim,) or None + activation: "silu" or None + conv_state_indices: [batch] int32 — pool slot index + pad_slot_id: padding slot sentinel + + Returns: + out: (batch, dim) + """ + if isinstance(activation, bool): + activation = "silu" if activation else None + elif activation is not None: + assert activation in ["silu", "swish"] + + batch, dim = x.shape + _, width = weight.shape + num_cache_lines, _, state_len = conv_state.shape + + out = paddle.empty_like(x) + + stride_w_dim, stride_w_width = weight.strides[0], weight.strides[1] + stride_x_seq, stride_x_dim = x.strides[0], x.strides[1] + stride_o_seq, stride_o_dim = out.strides[0], out.strides[1] + stride_istate_seq = conv_state.strides[0] + stride_istate_dim = conv_state.strides[1] + stride_istate_token = conv_state.strides[2] + stride_state_indices = conv_state_indices.strides[0] if conv_state_indices is not None else 0 + + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return (batch, triton.cdiv(dim, META["BLOCK_N"])) + + _causal_conv1d_update_kernel[grid]( + x, + weight, + bias, + conv_state, + conv_state_indices, + out, + batch, + dim, 
+ state_len, + num_cache_lines, + stride_x_seq, + stride_x_dim, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_o_seq, + stride_o_dim, + pad_slot_id, + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + ) + return out diff --git a/fastdeploy/model_executor/ops/triton_ops/fla/__init__.py b/fastdeploy/model_executor/ops/triton_ops/fla/__init__.py new file mode 100644 index 00000000000..68c2aa13fd6 --- /dev/null +++ b/fastdeploy/model_executor/ops/triton_ops/fla/__init__.py @@ -0,0 +1,86 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FLA (Flash Linear Attention) Triton Kernel package — FastDeploy edition. + +Vendored from SGLang (which itself adapts from fla-org/flash-linear-attention), +ported to PaddlePaddle. Triton kernel code is unchanged; only Python wrappers +are adapted from torch to paddle. 
+ +Public API: + Prefill path: + chunk_gated_delta_rule — 6-step chunk algorithm (main entry) + + Decode path: + fused_recurrent_gated_delta_rule — standard fused recurrent (with initial/final state) + fused_recurrent_gated_delta_rule_update — pool-index variant (in-place read/write of ssm_pool) + + Utilities: + chunk_local_cumsum — chunk-local prefix cumulative sum + l2norm_fwd — L2 normalization + solve_tril — lower-triangular matrix inversion + fused_gdn_gating — fused GDN gating (softplus + exp + sigmoid) +""" + +from fastdeploy.model_executor.ops.triton_ops.fla.chunk import ( + chunk_gated_delta_rule, + chunk_gated_delta_rule_fwd, +) +from fastdeploy.model_executor.ops.triton_ops.fla.chunk_delta_h import ( + chunk_gated_delta_rule_fwd_h, +) +from fastdeploy.model_executor.ops.triton_ops.fla.chunk_o import chunk_fwd_o +from fastdeploy.model_executor.ops.triton_ops.fla.chunk_scaled_dot_kkt import ( + chunk_scaled_dot_kkt_fwd, +) +from fastdeploy.model_executor.ops.triton_ops.fla.cumsum import chunk_local_cumsum +from fastdeploy.model_executor.ops.triton_ops.fla.fused_gdn_gating import ( + fused_gdn_gating, +) +from fastdeploy.model_executor.ops.triton_ops.fla.fused_recurrent import ( + fused_recurrent_gated_delta_rule, + fused_recurrent_gated_delta_rule_fwd, + fused_recurrent_gated_delta_rule_update, + fused_recurrent_gated_delta_rule_update_fwd, +) +from fastdeploy.model_executor.ops.triton_ops.fla.index import ( + prepare_chunk_indices, + prepare_chunk_offsets, + prepare_lens, +) +from fastdeploy.model_executor.ops.triton_ops.fla.l2norm import l2norm_fwd +from fastdeploy.model_executor.ops.triton_ops.fla.solve_tril import solve_tril +from fastdeploy.model_executor.ops.triton_ops.fla.wy_fast import recompute_w_u_fwd + +__all__ = [ + # Prefill path + "chunk_gated_delta_rule", + "chunk_gated_delta_rule_fwd", + "chunk_gated_delta_rule_fwd_h", + "chunk_fwd_o", + "chunk_scaled_dot_kkt_fwd", + "chunk_local_cumsum", + "solve_tril", + "recompute_w_u_fwd", + # Decode 
path + "fused_recurrent_gated_delta_rule", + "fused_recurrent_gated_delta_rule_fwd", + "fused_recurrent_gated_delta_rule_update", + "fused_recurrent_gated_delta_rule_update_fwd", + # Utilities + "l2norm_fwd", + "fused_gdn_gating", + "prepare_lens", + "prepare_chunk_indices", + "prepare_chunk_offsets", +] diff --git a/fastdeploy/model_executor/ops/triton_ops/fla/chunk.py b/fastdeploy/model_executor/ops/triton_ops/fla/chunk.py new file mode 100644 index 00000000000..af08fdaf0fd --- /dev/null +++ b/fastdeploy/model_executor/ops/triton_ops/fla/chunk.py @@ -0,0 +1,199 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/chunk.py +# Original: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang (MIT License) +# Adapted for FastDeploy (PaddlePaddle) by PaddlePaddle Authors, 2025. +""" +GDN Chunk Algorithm Coordinator — Prefill path core implementation. + +Executes the standard 6-step chunk GDN algorithm: + 1. chunk_local_cumsum(g) — compute local decay cumulative sum + 2. chunk_scaled_dot_kkt_fwd(k,beta) — compute A = beta * K * K^T + 3. solve_tril(A) — compute (I+A)^{-1} + 4. recompute_w_u_fwd(k,v,beta,A) — compute W, U (WY decomposition) + 5. chunk_gated_delta_rule_fwd_h — state propagation + 6. 
chunk_fwd_o — compute output + +Porting notes: + - Removed torch.autograd.Function (no backprop needed for inference) + - Removed @torch.compiler.disable (not applicable to paddle) + - Removed einops rearrange (head_first=False is the only supported layout) + - Removed SUPPRESS_LEVEL / autocast_custom_fwd (not relevant for inference) + - assert q.dtype != torch.float32 → assert q.dtype != paddle.float32 + - .to(q.dtype) → .cast(q.dtype) +""" + +from typing import Optional, Tuple + +import paddle + +from fastdeploy.model_executor.ops.triton_ops.fla.chunk_delta_h import ( + chunk_gated_delta_rule_fwd_h, +) +from fastdeploy.model_executor.ops.triton_ops.fla.chunk_o import chunk_fwd_o +from fastdeploy.model_executor.ops.triton_ops.fla.chunk_scaled_dot_kkt import ( + chunk_scaled_dot_kkt_fwd, +) +from fastdeploy.model_executor.ops.triton_ops.fla.cumsum import chunk_local_cumsum +from fastdeploy.model_executor.ops.triton_ops.fla.l2norm import l2norm_fwd +from fastdeploy.model_executor.ops.triton_ops.fla.solve_tril import solve_tril +from fastdeploy.model_executor.ops.triton_ops.fla.utils import input_guard +from fastdeploy.model_executor.ops.triton_ops.fla.wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + scale: float, + initial_state: paddle.Tensor, + initial_state_indices: paddle.Tensor, + cu_seqlens: Optional[paddle.Tensor] = None, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """ + GDN 6-step chunk algorithm forward (internal implementation). 
+ + Args: + q, k: [B, T, H, K] + v: [B, T, H, V] + g: [B, T, H] log decay (negative values) + beta: [B, T, H] write gate + scale: Q scale factor + initial_state: [N, H, K, V] initial SSM state pool + initial_state_indices: [N] pool slot index per sequence + cu_seqlens: [N+1] varlen mode (optional) + + Returns: + o: [B, T, H, V] + h: [B, NT, H, K, V] initial state at each chunk (for debugging/testing) + """ + # Step 1: compute chunk-local cumulative sum of g (force float32; safe_exp requires fp32/fp64) + g = chunk_local_cumsum(g, chunk_size=64, output_dtype=paddle.float32, cu_seqlens=cu_seqlens) + + # Step 2: compute A = beta * K * K^T (strictly lower-triangular) + A = chunk_scaled_dot_kkt_fwd( + k=k, + beta=beta, + g_cumsum=g, + cu_seqlens=cu_seqlens, + output_dtype=paddle.float32, + ) + + # Step 3: compute (I + A)^{-1} + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + + # Step 4: compute W, U (WY decomposition) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + + # Step 5: state propagation + # The kernel always loads initial_state_indices even when USE_INITIAL_STATE=False, + # so dummy values are needed to avoid NoneType errors when initial_state is None. 
+ B, T, H, K = k.shape + V = u.shape[-1] + _initial_state = initial_state + _initial_state_indices = initial_state_indices + if _initial_state is None: + # dummy: zero state, indices pointing to slot 0 + _initial_state = paddle.zeros([B, H, K, V], dtype=k.dtype) + _initial_state_indices = paddle.arange(B, dtype=paddle.int32) + h, v_new = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=_initial_state, + initial_state_indices=_initial_state_indices, + cu_seqlens=cu_seqlens, + ) + + # Step 6: compute output + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + ) + return o, h + + +@input_guard +def chunk_gated_delta_rule( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + scale: Optional[float] = None, + initial_state: Optional[paddle.Tensor] = None, + initial_state_indices: Optional[paddle.Tensor] = None, + cu_seqlens: Optional[paddle.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]: + """ + GDN Chunk Algorithm public interface (Prefill path). + + Only supports head_first=False (batch-first) layout: [B, T, H, ...]. + + Args: + q, k: [B, T, H, K] + v: [B, T, H, V] + g: [B, T, H] log decay (negative values) + beta: [B, T, H] write gate + scale: Q scale factor; defaults to 1/sqrt(K) when None + initial_state: [N, H, K, V] initial state (from SSM pool) + initial_state_indices: [N] pool slot indices + cu_seqlens: [N+1] varlen mode + use_qk_l2norm_in_kernel: whether to apply L2 norm to Q/K inside the kernel + + Returns: + o: [B, T, H, V] + h: [B, NT, H, K, V] initial state at each chunk (can be used for debugging) + """ + assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype" + assert q.dtype != paddle.float32, "chunk_gated_delta_rule does not support float32; use bfloat16 or float16." 
+ assert len(beta.shape) == 3, "beta must have shape [B, T, H] (head_first=False)" + + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"batch_size must be 1 in varlen mode, but got {q.shape[0]}. " + "Please concatenate variable-length inputs before passing in." + ) + if ( + cu_seqlens is not None + and initial_state_indices is not None + and initial_state_indices.shape[0] != cu_seqlens.shape[0] - 1 + ): + raise ValueError( + f"initial_state_indices length must equal the number of sequences " + f"{cu_seqlens.shape[0] - 1}, but got {initial_state_indices.shape[0]}." + ) + + if scale is None: + scale = k.shape[-1] ** -0.5 + + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + o, h = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + initial_state_indices=initial_state_indices, + cu_seqlens=cu_seqlens, + ) + return o.cast(q.dtype), h diff --git a/fastdeploy/model_executor/ops/triton_ops/fla/chunk_delta_h.py b/fastdeploy/model_executor/ops/triton_ops/fla/chunk_delta_h.py new file mode 100644 index 00000000000..49c1fd88c91 --- /dev/null +++ b/fastdeploy/model_executor/ops/triton_ops/fla/chunk_delta_h.py @@ -0,0 +1,316 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_delta_h.py +# Original: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang (MIT License) +# Adapted for FastDeploy (PaddlePaddle) by PaddlePaddle Authors, 2025. +""" +GDN chunk state propagation Triton Kernel. 
Porting notes:
    - k.new_empty(B, NT, H, K, V) → paddle.empty([B, NT, H, K, V], dtype=k.dtype)
    - torch.empty_like(u) → paddle.empty_like(u)
    - is_nvidia_hopper does not affect the inference path; fixed config used directly
    - Triton kernel code is unchanged
"""

from typing import Optional, Tuple

import paddle
import triton
import triton.language as tl

from fastdeploy.model_executor.ops.triton_ops.fla.index import (
    prepare_chunk_indices,
    prepare_chunk_offsets,
)
from fastdeploy.model_executor.ops.triton_ops.fla.op import exp, safe_exp

# Chunk length used by the chunked GDN prefill kernels in this file.
CHUNK_SIZE = 64


# ============================================================
# Triton Kernel (unchanged from SGLang)
# ============================================================


@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,
    v,
    w,
    v_new,
    g,
    gk,
    h,
    initial_state,
    initial_state_indices,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    INPLACE_UPDATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # One program instance handles one (value-block i_v, request i_n, head i_h)
    # triple and scans all chunks of that request sequentially.
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        # Varlen mode: [bos, eos) token span of this request from cu_seqlens.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # Running state held in fp32, split into up to four 64-row slabs along K
    # (K <= 256 is enforced by the wrapper).
    # [BK, BV]
    b_h1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([64, BV], dtype=tl.float32)

    # calculate offset
    h += ((boh * H + i_h) * K * V).to(tl.int64)
    v += ((bos * H + i_h) * V).to(tl.int64)
    k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
    w += ((bos * H + i_h) * K).to(tl.int64)
    if SAVE_NEW_VALUE:
        v_new += ((bos * H + i_h) * V).to(tl.int64)
    stride_v = H * V
    stride_h = H * K * V
    stride_k = Hg * K
    stride_w = H * K

    # Pool slot of this request; initial and final state share the same slot.
    index = tl.load(initial_state_indices + i_n).to(tl.int32)
    h0 = initial_state + index * stride_h
    ht = initial_state + index * stride_h
    if USE_INITIAL_STATE:
        h0 = h0 + i_h * K * V
    if INPLACE_UPDATE:
        ht = ht + i_h * K * V

    # load initial state
    if USE_INITIAL_STATE:
        p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
        if K > 64:
            p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
        if K > 128:
            p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
        if K > 192:
            p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)

    # main recurrence
    for i_t in range(NT):
        # Store the state at the *start* of chunk i_t before updating it.
        p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))

        # b_v := v - w @ h (delta-rule replacement values for this chunk),
        # accumulated slab by slab along K.
        p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0))
        b_w = tl.load(p_w, boundary_check=(0, 1))
        b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
        if K > 64:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
        if K > 128:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
        if K > 192:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
        p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v

        if SAVE_NEW_VALUE:
            p_v = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))

        last_idx = min((i_t + 1) * BT, T) - 1
        if USE_G:
            # Scalar (per-head) log-decay: rescale values to chunk end and
            # decay the running state by exp(g at chunk end).
            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
            b_g_last = exp(b_g_last)
            b_h1 = b_h1 * b_g_last
            if K > 64:
                b_h2 = b_h2 * b_g_last
            if K > 128:
                b_h3 = b_h3 * b_g_last
            if K > 192:
                b_h4 = b_h4 * b_g_last

        if USE_GK:
            # Key-wise decay: one factor per K channel (exclusive with USE_G).
            o_k1 = tl.arange(0, 64)
            b_gk_last1 = tl.load(
                gk + (bos + last_idx) * H * K + i_h * K + o_k1,
                mask=(o_k1 < K),
                other=0.0,
            )
            b_h1 *= exp(b_gk_last1)[:, None]
            if K > 64:
                o_k2 = 64 + o_k1
                b_gk_last2 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k2,
                    mask=(o_k2 < K),
                    other=0.0,
                )
                b_h2 *= exp(b_gk_last2)[:, None]
            if K > 128:
                o_k3 = 128 + o_k1
                b_gk_last3 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k3,
                    mask=(o_k3 < K),
                    other=0.0,
                )
                b_h3 *= exp(b_gk_last3)[:, None]
            if K > 192:
                o_k4 = 192 + o_k1
                b_gk_last4 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k4,
                    mask=(o_k4 < K),
                    other=0.0,
                )
                b_h4 *= exp(b_gk_last4)[:, None]
        b_v = b_v.to(k.dtype.element_ty)

        # Rank-BT state update: h_slab += k_slab @ b_v.
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_h1 += tl.dot(b_k, b_v)
        if K > 64:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h2 += tl.dot(b_k, b_v)
        if K > 128:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h3 += tl.dot(b_k, b_v)
        if K > 192:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h4 += tl.dot(b_k, b_v)

    # epilogue
    if INPLACE_UPDATE:
        # Write the final state back to the request's pool slot in place.
        p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))


# ============================================================
# Python Wrapper (paddle edition)
# ============================================================


def chunk_gated_delta_rule_fwd_h(
    k: paddle.Tensor,
    w: paddle.Tensor,
    u: paddle.Tensor,
    g: Optional[paddle.Tensor] = None,
    gk: Optional[paddle.Tensor] = None,
    initial_state: Optional[paddle.Tensor] = None,
    initial_state_indices: Optional[paddle.Tensor] = None,
    save_new_value: bool = True,
    cu_seqlens: Optional[paddle.Tensor] = None,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    """
    GDN chunk state propagation forward (Prefill).

    Args:
        k: [B, T, Hg, K]
        w: [B, T, H, K] — from recompute_w_u_fwd
        u: [B, T, H, V] — from recompute_w_u_fwd (new value vectors)
        g: [B, T, H] cumsum of log decay (optional)
        gk: [B, T, H, K] key-wise decay (optional, mutually exclusive with g)
        initial_state: [N, H, K, V] initial state pool (updated in place when
            INPLACE_UPDATE; the kernel always indexes it via
            initial_state_indices, so both must be provided together)
        initial_state_indices: [N] pool slot index for each request
        save_new_value: whether to save the updated v
        cu_seqlens: cumulative sequence lengths [N+1] for varlen mode

    Returns:
        h: [B, NT, H, K, V] — state at the start of each chunk
        v_new: [B, T, H, V] — updated value tensor (if save_new_value=True)
    """
    B, T, Hg, K = k.shape
    V = u.shape[-1]
    H = u.shape[-2]
    BT = CHUNK_SIZE

    chunk_indices = prepare_chunk_indices(cu_seqlens, CHUNK_SIZE) if cu_seqlens is not None else None
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        # Varlen: one logical batch entry per request in cu_seqlens.
        N, NT, chunk_offsets = (
            len(cu_seqlens) - 1,
            len(chunk_indices),
            prepare_chunk_offsets(cu_seqlens, BT),
        )
    assert K <= 256, f"current kernel does not support head dimension larger than 256 (K={K})"

    h = paddle.empty([B, NT, H, K, V], dtype=k.dtype)
    v_new = paddle.empty_like(u) if save_new_value else None

    def grid(meta):
        # One program per (value block, request*head).
        return (triton.cdiv(V, meta["BV"]), N * H)

    chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
        k=k,
        v=u,
        w=w,
        v_new=v_new,
        g=g,
        gk=gk,
        h=h,
        initial_state=initial_state,
        initial_state_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BV=32,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_INITIAL_STATE=initial_state is not None,
        INPLACE_UPDATE=True,
        SAVE_NEW_VALUE=v_new is not None,
        IS_VARLEN=cu_seqlens is not None,
        num_warps=4,
        num_stages=2,
    )
    return h, v_new
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_o.py
# Original: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang (MIT License)
# Adapted for FastDeploy (PaddlePaddle) by PaddlePaddle Authors, 2025.
"""
GDN chunk output computation Triton Kernel.

Porting notes:
    - torch.zeros_like(v) → paddle.zeros_like(v)
    - Triton kernel code is unchanged
"""

from typing import Optional

import paddle
import triton
import triton.language as tl

from fastdeploy.model_executor.ops.triton_ops.fla.index import prepare_chunk_indices
from fastdeploy.model_executor.ops.triton_ops.fla.op import exp, safe_exp

# ============================================================
# Triton Kernel (unchanged from SGLang)
# ============================================================


@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
    q,
    k,
    v,
    h,
    g,
    o,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # One program per (value block i_v, chunk i_t, batch*head i_bh).
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        # i_tg is the global chunk index used to address h; (i_n, i_t) come
        # from the precomputed chunk_indices table.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += (bos * Hg + i_h // (H // Hg)) * K
    k += (bos * Hg + i_h // (H // Hg)) * K
    v += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    h += (i_tg * H + i_h).to(tl.int64) * K * V

    # b_o accumulates the inter-chunk term q @ h; b_A the intra-chunk scores.
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))

        # [BT, BK] @ [BK, BV] -> [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_A += tl.dot(b_q, b_k)

    if USE_G:
        # Apply per-head gated decay to both terms.
        g += bos * H + i_h
        p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        b_o = b_o * exp(b_g)[:, None]
        b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])

    # Causal mask inside the chunk.
    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_A = tl.where(m_A, b_A, 0)

    p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))

    # o = scale * (q @ h + A @ v)
    b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


# ============================================================
# Python Wrapper (paddle edition)
# ============================================================


def chunk_fwd_o(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    h: paddle.Tensor,
    g: Optional[paddle.Tensor] = None,
    scale: Optional[float] = None,
    cu_seqlens: Optional[paddle.Tensor] = None,
    chunk_size: int = 64,
) -> paddle.Tensor:
    """
    Chunk output computation (last step of Prefill).

    Args:
        q: [B, T, Hg, K]
        k: [B, T, Hg, K]
        v: [B, T, H, V] — updated value vectors (from chunk_gated_delta_rule_fwd_h)
        h: [B, NT, H, K, V] — initial state at each chunk
        g: [B, T, H] cumsum of log decay (optional)
        scale: Q scale factor; defaults to 1/sqrt(K) when None
        cu_seqlens: cumulative sequence lengths [N+1] for varlen mode
        chunk_size: chunk size

    Returns:
        o: [B, T, H, V]
    """
    B, T, Hg, K = q.shape
    V = v.shape[-1]
    H = v.shape[-2]
    # Shrink BT for short sequences but keep it a power of two >= 16.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    if scale is None:
        scale = k.shape[-1] ** -0.5

    o = paddle.zeros_like(v)

    def grid(meta):
        return (triton.cdiv(V, meta["BV"]), NT, B * H)

    chunk_fwd_kernel_o[grid](
        q,
        k,
        v,
        h,
        g,
        o,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=128,
        BV=64,
        USE_G=g is not None,
        IS_VARLEN=cu_seqlens is not None,
        num_warps=4,
        num_stages=2,
    )
    return o


# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_scaled_dot_kkt.py
# Original: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang (MIT License)
# Adapted for FastDeploy (PaddlePaddle) by PaddlePaddle Authors, 2025.
"""
Chunk-level beta * K * K^T computation Triton Kernel.
Porting notes:
    - torch.empty(..., device=k.device, dtype=output_dtype)
      → paddle.empty([...], dtype=output_dtype)
    - Triton kernel code is unchanged
"""

from typing import Optional

import paddle
import triton
import triton.language as tl

from fastdeploy.model_executor.ops.triton_ops.fla.index import prepare_chunk_indices
from fastdeploy.model_executor.ops.triton_ops.fla.op import safe_exp

# ============================================================
# Triton Kernel (unchanged from SGLang)
# ============================================================


# @triton.autotune(
#     configs=[
#         triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
#         for BK in [32, 64, 128]
#         for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4]
#     ],
#     key=["H", "K", "BT", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    g_cumsum,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
):
    # One program per (chunk i_t, batch*head i_bh).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_t = tl.arange(0, BT)

    p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    # Accumulate K @ K^T over BK-wide key slices.
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_A += tl.dot(b_k, tl.trans(b_k))

    if USE_G:
        # Weight scores by the decay gap between the two positions.
        p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        b_g_diff = b_g[:, None] - b_g[None, :]
        b_A = b_A * safe_exp(b_g_diff)

    # Row-scale by beta and keep only the strictly lower triangle.
    b_A *= b_beta[:, None]
    b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))


# ============================================================
# Python Wrapper (paddle edition)
# ============================================================


def chunk_scaled_dot_kkt_fwd(
    k: paddle.Tensor,
    beta: paddle.Tensor,
    g_cumsum: Optional[paddle.Tensor] = None,
    cu_seqlens: Optional[paddle.Tensor] = None,
    chunk_size: int = 64,
    output_dtype=None,
) -> paddle.Tensor:
    r"""
    Compute beta * K * K^T (within chunk).

    Args:
        k: [B, T, Hg, K]
        beta: [B, T, H]
        g_cumsum: [B, T, H] cumsum of log decay; no decay applied when None
        cu_seqlens: [N+1] for varlen mode
        chunk_size: chunk size
        output_dtype: output dtype (defaults to float32)

    Returns:
        A: [B, T, H, BT] where BT = chunk_size
    """
    B, T, Hg, K = k.shape
    H = beta.shape[-1]
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    out_dtype = output_dtype if output_dtype is not None else paddle.float32
    A = paddle.empty([B, T, H, BT], dtype=out_dtype)
    # Fixed launch config (autotune intentionally disabled; see commented block).
    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
        k=k,
        beta=beta,
        g_cumsum=g_cumsum,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        BT=BT,
        BK=64,
        IS_VARLEN=cu_seqlens is not None,
        USE_G=g_cumsum is not None,
        num_warps=8,
        num_stages=3,
    )
    return A
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/cumsum.py
# Original: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang (MIT License)
# Adapted for FastDeploy (PaddlePaddle) by PaddlePaddle Authors, 2025.
"""
FLA chunk-local prefix cumulative sum Triton Kernel.

Porting notes:
    - torch.empty_like(g, dtype=...) → paddle.empty_like(g, dtype=...)
    - autotune left commented out (consistent with SGLang)
    - check_shared_mem() and is_nvidia_hopper moved to local utils
    - Triton kernel code is unchanged
"""

from typing import Optional

import paddle
import triton
import triton.language as tl

from fastdeploy.model_executor.ops.triton_ops.fla.index import prepare_chunk_indices
from fastdeploy.model_executor.ops.triton_ops.fla.utils import input_guard

# BS_LIST depends on shared memory size; use conservative defaults for inference
BS_LIST = [32, 64]


# ============================================================
# Triton Kernel — Scalar cumsum (3D tensor)
# ============================================================


# @triton.autotune(
#     configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
#     key=["B", "H", "BT", "IS_VARLEN", "REVERSE"],
# )
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    REVERSE: tl.constexpr,
    HAS_SCALE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # One program per (chunk i_t, batch*head i_bh); cumsum is local to a chunk.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT]
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # Suffix sum derived from the prefix sum: total - prefix + element.
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    if HAS_SCALE:
        b_o *= scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))


# ============================================================
# Triton Kernel — Vector cumsum (4D tensor)
# ============================================================


@triton.autotune(
    configs=[triton.Config({"BS": BS}, num_warps=num_warps) for BS in BS_LIST for num_warps in [2, 4, 8]],
    key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE", "HAS_SCALE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_vector_kernel(
    s,
    o,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    REVERSE: tl.constexpr,
    HAS_SCALE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # One program per (feature block i_s, chunk i_t, batch*head i_bh).
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Cumsum expressed as a triangular matmul: lower-triangular mask for
    # forward, upper-triangular for reverse.
    o_i = tl.arange(0, BT)
    if REVERSE:
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0)
    else:
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0)

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    else:
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    if HAS_SCALE:
        b_o *= scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


# ============================================================
# Python Wrappers (paddle edition)
# ============================================================


def chunk_local_cumsum_scalar(
    g: paddle.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: Optional[paddle.Tensor] = None,
    head_first: bool = False,
    output_dtype=None,
) -> paddle.Tensor:
    """Chunk-local cumsum for a 3D tensor ([B, T, H] or [B, H, T])."""
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    out_dtype = output_dtype or g.dtype
    # Keep the input alive under a new name; `g` becomes the output buffer.
    g_org = g
    g = paddle.empty_like(g, dtype=out_dtype)
    grid = (NT, B * H)
    chunk_local_cumsum_scalar_kernel[grid](
        s=g_org,
        o=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
        HAS_SCALE=scale is not None,
        IS_VARLEN=cu_seqlens is not None,
        num_warps=8,
        num_stages=3,
    )
    return g


def chunk_local_cumsum_vector(
    g: paddle.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: Optional[paddle.Tensor] = None,
    head_first: bool = False,
    output_dtype=None,
) -> paddle.Tensor:
    """Chunk-local cumsum for a 4D tensor ([B, T, H, S] or [B, H, T, S])."""
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
    out_dtype = output_dtype or g.dtype
    g_org = g
    g = paddle.empty_like(g, dtype=out_dtype)

    def grid(meta):
        return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)

    # num_warps / BS are chosen by the kernel's @triton.autotune.
    chunk_local_cumsum_vector_kernel[grid](
        s=g_org,
        o=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
        HAS_SCALE=scale is not None,
        IS_VARLEN=cu_seqlens is not None,
    )
    return g


@input_guard
def chunk_local_cumsum(
    g: paddle.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: Optional[paddle.Tensor] = None,
    head_first: bool = False,
    output_dtype=None,
    **kwargs,
) -> paddle.Tensor:
    """
    Chunk-local prefix cumulative sum (forward).

    Args:
        g: [B, T, H] or [B, T, H, S]
        chunk_size: chunk size (must be a power of 2)
        reverse: whether to compute reverse cumsum
        scale: optional scale factor
        cu_seqlens: cumulative sequence lengths [N+1] for varlen mode
        head_first: whether layout is head-first
        output_dtype: output dtype (defaults to same as input)

    Returns:
        cumsum tensor with the same shape as g
    """
    if cu_seqlens is not None:
        assert g.shape[0] == 1, "batch_size must be 1 in varlen mode"
    # Dispatch on rank: 3D → scalar kernel, 4D → vector kernel.
    if len(g.shape) == 3:
        return chunk_local_cumsum_scalar(
            g=g,
            chunk_size=chunk_size,
            reverse=reverse,
            scale=scale,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            output_dtype=output_dtype,
        )
    elif len(g.shape) == 4:
        return chunk_local_cumsum_vector(
            g=g,
            chunk_size=chunk_size,
            reverse=reverse,
            scale=scale,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            output_dtype=output_dtype,
        )
    else:
        raise ValueError(
            f"Unsupported input shape {g.shape}, "
            f"which should be (B, T, H, D) if `head_first=False` "
            f"or (B, H, T, D) otherwise"
        )


# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+""" +Fused GDN gating Triton kernel. + +Ported from SGLang (sglang/srt/layers/attention/fla/fused_gdn_gating.py). +Computes in a single kernel launch: + g = -exp(A_log) * softplus(a + dt_bias) + beta_output = sigmoid(b) +""" + +from __future__ import annotations + +from typing import Tuple + +import paddle +import triton +import triton.language as tl + + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +# beta_output = b.sigmoid() +@triton.jit +def fused_gdn_gating_kernel( + g, + beta_output, + A_log, + a, + b, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_b = tl.load(b + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) + tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask) + + +def fused_gdn_gating( + A_log: paddle.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + dt_bias: paddle.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Fused GDN gating: g = -exp(A_log)*softplus(a+dt_bias), beta = sigmoid(b). 
+ + Args: + A_log: [num_heads] - log of A matrix + a: [num_tokens, num_heads] - alpha values + b: [num_tokens, num_heads] - beta values + dt_bias: [num_heads] - delta-time bias + + Returns: + g: [num_tokens, num_heads] float32 + beta_output: [num_tokens, num_heads] float32 + """ + num_tokens, num_heads = a.shape + seq_len = 1 + grid = (num_tokens, seq_len, triton.cdiv(num_heads, 8)) + g = paddle.empty([num_tokens, num_heads], dtype=paddle.float32) + beta_output = paddle.empty([num_tokens, num_heads], dtype=paddle.float32) + fused_gdn_gating_kernel[grid]( + g, + beta_output, + A_log, + a, + b, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + 8, + num_warps=1, + ) + return g, beta_output diff --git a/fastdeploy/model_executor/ops/triton_ops/fla/fused_recurrent.py b/fastdeploy/model_executor/ops/triton_ops/fla/fused_recurrent.py new file mode 100644 index 00000000000..8f9766d3f5e --- /dev/null +++ b/fastdeploy/model_executor/ops/triton_ops/fla/fused_recurrent.py @@ -0,0 +1,553 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/fused_recurrent.py +# Original: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang (MIT License) +# Adapted for FastDeploy (PaddlePaddle) by PaddlePaddle Authors, 2025. +""" +GDN Fused Recurrent Kernel — Decode path core implementation. + +Provides two public functions: + 1. fused_recurrent_gated_delta_rule + Standard interface: accepts initial_state / outputs final_state Tensor + Suitable for state saving after Prefill, single-pass inference + State layout: [N, HV, K, V] (K-first) + + 2. 
       Pool-index interface: in-place read/write of state at pool[indices]
       Suitable for serving Decode phase (no external gather/scatter needed)
       Pool layout: [max_seqs, HV, K, V]

Notes:
    - Triton kernel code is identical to SGLang, no modifications needed
    - Python wrapper replaces torch.Tensor → paddle.Tensor, removes torch.autograd.Function
    - FD inference does not require backpropagation; fwd functions are called directly
"""

from typing import Optional, Tuple

import paddle
import triton
import triton.language as tl

from fastdeploy.model_executor.ops.triton_ops.fla.op import exp
from fastdeploy.model_executor.ops.triton_ops.fla.utils import input_guard

# ============================================================
# Triton Kernel — Standard fused recurrent (full state in/out)
# Source: SGLang fused_recurrent.py lines 15-121
# Triton code is unchanged
# ============================================================


@triton.jit(do_not_specialize=["T"])
def _fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    # One program per (key block i_k, value block i_v, request*value-head i_nh);
    # q/k heads (H) may be grouped, i.e. shared across HV value heads.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # Per-token element pointers; advanced by one token stride each step.
    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv
    if not IS_KDA:
        p_g = g + bos * HV + i_hv
    else:
        # KDA: key-wise decay vector instead of a scalar gate.
        p_gk = g + (bos * HV + i_hv) * K + o_k

    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    # Running [BK, BV] state tile in fp32.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
        b_q = b_q * scale

        # Decay the state, then delta-rule update: v' = beta*(v - h^T k),
        # h += k v'^T, o = h^T q.
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            b_gk = tl.load(p_gk).to(tl.float32)
            b_h *= exp(b_gk[:, None])

        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        b_h += b_k[:, None] * b_v[None, :]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # Advance all pointers by one token.
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        if not IS_KDA:
            p_g += HV
        else:
            p_gk += HV * K
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


# ============================================================
# Triton Kernel — Pool-index fused recurrent (in-place state read/write)
# Source: SGLang fused_recurrent.py fused_recurrent_gated_delta_rule_update_fwd_kernel
# Key feature: reads/writes state directly at h0_source[h0_indices[i]], no external gather/scatter
@triton.jit(do_not_specialize=["T"])
def _fused_recurrent_gated_delta_rule_update_kernel(
    q,  # [B, T, H, K]
    k,  # [B, T, H, K]
    v,  # [B, T, HV, V]
    g,  # [B, T, HV] log decay, or [B, T, HV, K] when IS_KDA
    beta,  # [B, T, HV] or [B, T, HV, V] when IS_BETA_HEADWISE
    o,  # [NK, B, T, HV, V] output
    h0_source,  # [max_seqs, HV, K, V] state pool (in-place read/write)
    h0_indices,  # [N] pool slot per sequence; -1 (PAD_SLOT_ID) skips state I/O
    cu_seqlens,  # [N+1] cumulative lengths (varlen mode only)
    scale,  # float scale applied to q
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DISABLE_STATE_UPDATE: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    """
    Pool-index variant: reads initial state from h0_source[h0_indices[i_n]],
    and writes the final state back in-place to the same location after computation.
    Requests with PAD_SLOT_ID=-1 skip state read/write automatically (safe for CUDA Graph padding).
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # GVA: HV // H value heads share one key head.
    i_h = i_hv // (HV // H)

    if IS_VARLEN:
        # Varlen: per-sequence [bos, eos); `all` keeps the total token count
        # for the output stride before T is overwritten.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # Pointers to the first token of this sequence/head.
    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv
    if not IS_KDA:
        p_g = g + bos * HV + i_hv
    else:
        # KDA: per-key-dimension decay vector.
        p_gk = g + (bos * HV + i_hv) * K + o_k

    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    # Recurrent state tile, kept in fp32 for accuracy.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        if idx >= 0:  # skip when PAD_SLOT_ID=-1
            p_h0 = h0_source + idx * HV * K * V + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
            b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
        b_q = b_q * scale

        # Decay the state: scalar gate per head, or per-K-dim gate (KDA).
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            b_gk = tl.load(p_gk).to(tl.float32)
            b_h *= exp(b_gk[:, None])

        # Delta rule: subtract current prediction, gate by beta, write back.
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        b_h += b_k[:, None] * b_v[None, :]
        # Readout: o_t = h_t^T q_t.
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # Advance all pointers by one token.
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        if not IS_KDA:
            p_g += HV
        else:
            p_gk += HV * K
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)

    # In-place write-back to pool
    if not DISABLE_STATE_UPDATE:
        idx = tl.load(h0_indices + i_n)
        if idx >= 0:  # skip write-back when PAD_SLOT_ID=-1
            p_h0 = h0_source + idx * HV * K * V + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
            tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
@input_guard
def fused_recurrent_gated_delta_rule_fwd(
    q: paddle.Tensor,  # [B, T, H, K]
    k: paddle.Tensor,  # [B, T, H, K]
    v: paddle.Tensor,  # [B, T, HV, V]
    g: paddle.Tensor,  # [B, T, HV]
    beta: paddle.Tensor,  # [B, T, HV] or [B, T, HV, V]
    scale: float,
    initial_state: Optional[paddle.Tensor],  # [N, HV, K, V]
    output_final_state: bool,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: Optional[paddle.Tensor] = None,  # [N+1] int64
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    """
    Standard fused recurrent forward.

    Args:
        q, k: [B, T, H, K] (H = num_k_heads)
        v: [B, T, HV, V] (HV = num_v_heads, HV >= H for GVA)
        g: [B, T, HV] log decay (negative values)
        beta: [B, T, HV] write gate [0, 1]
        scale: float Q scale factor (typically 1/sqrt(K))
        initial_state: [N, HV, K, V] initial SSM state (K-first layout)
        output_final_state: whether to output the final state
        use_qk_l2norm_in_kernel: whether to apply L2 norm inside the kernel
        cu_seqlens: [N+1] int64, cumulative sequence lengths for varlen mode

    Returns:
        o: [B, T, HV, V]
        final_state: [N, HV, K, V] if output_final_state else None
    """
    # Problem sizes.
    B, T, H, K = q.shape
    HV, V = v.shape[2], v.shape[3]
    is_varlen = cu_seqlens is not None
    N = cu_seqlens.shape[0] - 1 if is_varlen else B

    # Tiling: one block over K (must cover K entirely), V split into BV tiles.
    BK = triton.next_power_of_2(K)
    BV = min(triton.next_power_of_2(V), 32)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, f"NK > 1 is not supported yet (K={K}, BK={BK})"

    # Output carries the (trivial) NK axis, squeezed before returning.
    o = paddle.empty([NK, B, T, HV, V], dtype=v.dtype)
    final_state = paddle.empty([N, HV, K, V], dtype=paddle.float32) if output_final_state else None

    _fused_recurrent_gated_delta_rule_fwd_kernel[(NK, NV, N * HV)](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        USE_INITIAL_STATE=initial_state is not None,
        STORE_FINAL_STATE=final_state is not None,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        IS_VARLEN=is_varlen,
        IS_KDA=False,
        num_warps=1,
        num_stages=3,
    )
    return o.squeeze(0), final_state  # [B, T, HV, V]
def fused_recurrent_gated_delta_rule(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    g: paddle.Tensor,
    beta: Optional[paddle.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[paddle.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[paddle.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    """
    GDN Fused Recurrent public interface (standard).

    For use in Prefill phase or test comparison scenarios.
    For Decode phase, prefer fused_recurrent_gated_delta_rule_update (pool-index variant).

    Args:
        q, k: [B, T, H, K]
        v: [B, T, HV, V]
        g: [B, T, HV] log decay
        beta: [B, T, HV] write gate; all-ones when None
        scale: Q scale; defaults to 1/sqrt(K) when None
        initial_state: [N, HV, K, V]
        output_final_state: whether to return final state
        cu_seqlens: [N+1] varlen mode

    Returns:
        o: [B, T, HV, V]
        final_state: [N, HV, K, V] or None

    Raises:
        ValueError: in varlen mode when batch_size != 1.
    """
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"batch_size must be 1 in varlen mode, but got {q.shape[0]}. "
            "Please concatenate variable-length inputs before passing in."
        )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if beta is None:
        # FIX: the kernel addresses beta with HV (value-head) strides, so the
        # default must have shape [B, T, HV] = v.shape[:-1]. The previous
        # q.shape[:-1] yields [B, T, H]; with GVA (HV > H) the kernel would
        # read past the end of the tensor. Identical behavior when HV == H.
        beta = paddle.ones(v.shape[:-1], dtype=q.dtype)  # [B, T, HV]

    return fused_recurrent_gated_delta_rule_fwd(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        cu_seqlens=cu_seqlens,
    )
@input_guard
def fused_recurrent_gated_delta_rule_update_fwd(
    q: paddle.Tensor,  # [B, T, H, K]
    k: paddle.Tensor,  # [B, T, H, K]
    v: paddle.Tensor,  # [B, T, HV, V]
    g: paddle.Tensor,  # [B, T, HV]
    beta: paddle.Tensor,  # [B, T, HV]
    scale: float,
    ssm_pool: paddle.Tensor,  # [max_seqs, HV, K, V] in-place read/write
    ssm_indices: paddle.Tensor,  # [N] int32/int64, PAD_SLOT_ID=-1 safe
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: Optional[paddle.Tensor] = None,  # [N+1]
    disable_state_update: bool = False,
) -> paddle.Tensor:
    """
    Pool-index fused recurrent forward (Decode phase core).

    Reads initial state from ssm_pool[ssm_indices[i]] and writes back in-place,
    avoiding external gather/scatter operations, compatible with CUDA Graph.

    Args:
        q, k: [B, T, H, K]
        v: [B, T, HV, V]
        g: [B, T, HV] log decay
        beta: [B, T, HV] write gate
        scale: float
        ssm_pool: [max_seqs, HV, K, V] full SSM state pool (K-first layout)
        ssm_indices: [N] pool slot index per request in this step;
                     requests with PAD_SLOT_ID=-1 skip state read/write
        disable_state_update: when True, only computes output without updating pool state

    Returns:
        o: [B, T, HV, V]
    """
    # Problem sizes.
    B, T, H, K = q.shape
    HV, V = v.shape[2], v.shape[3]
    is_varlen = cu_seqlens is not None
    N = cu_seqlens.shape[0] - 1 if is_varlen else B

    # Tiling: one block over K (must cover K entirely), V split into BV tiles.
    BK = triton.next_power_of_2(K)
    BV = min(triton.next_power_of_2(V), 32)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, f"NK > 1 is not supported yet (K={K}, BK={BK})"

    # Output carries the (trivial) NK axis, squeezed before returning.
    o = paddle.empty([NK, B, T, HV, V], dtype=v.dtype)

    _fused_recurrent_gated_delta_rule_update_kernel[(NK, NV, N * HV)](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0_source=ssm_pool,
        h0_indices=ssm_indices,
        cu_seqlens=cu_seqlens,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        USE_INITIAL_STATE=ssm_pool is not None,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        IS_VARLEN=is_varlen,
        DISABLE_STATE_UPDATE=disable_state_update,
        IS_KDA=False,
        num_warps=1,
        num_stages=3,
    )
    return o.squeeze(0)  # [B, T, HV, V]
def fused_recurrent_gated_delta_rule_update(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    g: paddle.Tensor,
    beta: Optional[paddle.Tensor] = None,
    scale: Optional[float] = None,
    ssm_pool: Optional[paddle.Tensor] = None,
    ssm_indices: Optional[paddle.Tensor] = None,
    cu_seqlens: Optional[paddle.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
    disable_state_update: bool = False,
) -> paddle.Tensor:
    """
    GDN Pool-index Fused Recurrent public interface (Decode core).

    Recommended interface for serving Decode phase.
    Operates directly on the SSM Pool, no external gather/scatter needed.

    Args:
        q, k: [B, T, H, K] (T=1 for Decode)
        v: [B, T, HV, V]
        g: [B, T, HV] log decay
        beta: [B, T, HV] write gate; all-ones when None
        scale: defaults to 1/sqrt(K) when None
        ssm_pool: [max_seqs, HV, K, V] SSM state pool (K-first)
        ssm_indices: [N] int pool slot index per request
        disable_state_update: read-only when True (for debugging)

    Returns:
        o: [B, T, HV, V]
    """
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if beta is None:
        # FIX: the kernel addresses beta with HV (value-head) strides, so the
        # default must have shape [B, T, HV] = v.shape[:-1]. The previous
        # q.shape[:-1] yields [B, T, H]; with GVA (HV > H) the kernel would
        # read past the end of the tensor. Identical behavior when HV == H.
        beta = paddle.ones(v.shape[:-1], dtype=q.dtype)  # [B, T, HV]

    return fused_recurrent_gated_delta_rule_update_fwd(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=scale,
        ssm_pool=ssm_pool,
        ssm_indices=ssm_indices,
        use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        cu_seqlens=cu_seqlens,
        disable_state_update=disable_state_update,
    )
@tensor_cache
def prepare_lens(cu_seqlens: paddle.Tensor) -> paddle.Tensor:
    """Per-sequence lengths, shape [N]: adjacent differences of cu_seqlens."""
    return cu_seqlens[1:] - cu_seqlens[:-1]


@tensor_cache
def prepare_chunk_indices(cu_seqlens: paddle.Tensor, chunk_size: int) -> paddle.Tensor:
    """
    Generate (seq_idx, chunk_in_seq_idx) pairs for each chunk.
    Returns shape: [num_chunks, 2], dtype matches cu_seqlens.
    """
    # Number of chunks each sequence occupies (ceil division).
    chunks_per_seq = triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    # Within-sequence chunk counters: 0..c-1 for every sequence, concatenated.
    local_ids = paddle.concat([paddle.arange(c) for c in chunks_per_seq])
    # A zero marks the first chunk of a sequence; cumsum-1 turns those marks
    # into the owning sequence index.
    seq_ids = (local_ids == 0).cast(paddle.int64).cumsum(axis=0) - 1
    return paddle.stack([seq_ids, local_ids], axis=1).cast(cu_seqlens.dtype)


@tensor_cache
def prepare_chunk_offsets(cu_seqlens: paddle.Tensor, chunk_size: int) -> paddle.Tensor:
    """
    Compute the chunk start offset for each sequence (cumulative chunk count per sequence).
    Returns shape: [N+1], dtype matches cu_seqlens.
    """
    zero = paddle.to_tensor([0], dtype=cu_seqlens.dtype)
    chunk_counts = triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    return paddle.concat([zero, chunk_counts]).cumsum(axis=0)
+ """ + lens_in_chunks = triton.cdiv(prepare_lens(cu_seqlens), chunk_size) + return paddle.concat([paddle.to_tensor([0], dtype=cu_seqlens.dtype), lens_in_chunks]).cumsum(axis=0) diff --git a/fastdeploy/model_executor/ops/triton_ops/fla/l2norm.py b/fastdeploy/model_executor/ops/triton_ops/fla/l2norm.py new file mode 100644 index 00000000000..b10182b0dc2 --- /dev/null +++ b/fastdeploy/model_executor/ops/triton_ops/fla/l2norm.py @@ -0,0 +1,142 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/l2norm.py +# Original: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang (MIT License) +# Adapted for FastDeploy (PaddlePaddle) by PaddlePaddle Authors, 2025. +""" +L2 Norm Triton Kernel. + +Porting notes: + - Removed torch.autograd.Function and nn.Module (no backprop needed for inference) + - torch.empty_like(x) → paddle.empty_like(x) + - Retained both Triton kernels unchanged (pure GPU instructions) + - Exposed l2norm_fwd directly as the main entry point +""" + + +import paddle +import triton +import triton.language as tl + +from fastdeploy.model_executor.ops.triton_ops.fla.utils import input_guard + +# ============================================================ +# Triton Kernels (unchanged from SGLang) +# ============================================================ + +BT_LIST = [8, 16, 32, 64, 128] + + +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +@triton.jit +def l2norm_fwd_kernel( + x, + y, + eps, + NB: tl.constexpr, + T: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = 
@input_guard
def l2norm_fwd(
    x: paddle.Tensor,
    eps: float = 1e-6,
    output_dtype=None,
) -> paddle.Tensor:
    """
    L2 normalization forward (Triton-accelerated).

    Args:
        x: arbitrary shape, last dimension is feature dim D
        eps: numerical stability term
        output_dtype: output dtype; None means same as input

    Returns:
        L2 normalized tensor, same shape as x

    Raises:
        RuntimeError: if the feature dim exceeds the maximum fusable block.
    """
    x_shape_og = x.shape
    # Both kernels operate on a flat [T, D] view.
    x = x.reshape([-1, x.shape[-1]])

    # FIX: removed the vacuous `assert y.strides[-1] == 1 if hasattr(y,
    # "strides") else True` — it asserted the constant True whenever
    # `strides` was absent and disappears entirely under -O. A freshly
    # allocated paddle.empty_like tensor is contiguous, which is what the
    # kernels' flat row-major indexing requires.
    if output_dtype is None:
        y = paddle.empty_like(x)
    else:
        y = paddle.empty_like(x, dtype=output_dtype)

    T, D = x.shape[0], x.shape[-1]

    # Largest power-of-2 feature block that fits in a 64KB row buffer.
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    if D <= 512:
        # Small feature dim: blocked kernel handles BT rows per program.
        NB = triton.cdiv(T, 2048)

        def grid(meta):
            return (triton.cdiv(T, meta["BT"]),)

        l2norm_fwd_kernel[grid](
            x,
            y,
            eps,
            NB=NB,
            T=T,
            D=D,
            BD=BD,
            BT=16,
            num_warps=8,
            num_stages=3,
        )
    else:
        # Large feature dim: one program per row.
        l2norm_fwd_kernel1[(T,)](
            x,
            y,
            eps=eps,
            D=D,
            BD=BD,
            num_warps=8,
            num_stages=3,
        )

    return y.reshape(x_shape_og)
try:
    import triton.language.extra.libdevice as tldevice

    _HAS_LIBDEVICE = True
except ImportError:
    # Older Triton builds do not ship libdevice bindings; fall back to tl.*.
    _HAS_LIBDEVICE = False

from fastdeploy.model_executor.ops.triton_ops.fla.utils import is_gather_supported

# Opt-in fast-math transcendentals: only when FLA_USE_FAST_OPS=1 AND
# libdevice is importable; otherwise use the standard Triton ops.
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1" and _HAS_LIBDEVICE:
    exp = tldevice.fast_expf
    exp2 = tldevice.exp2
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    exp = tl.exp
    exp2 = tl.math.exp2
    log = tl.log
    log2 = tl.log2


@triton.jit
def safe_exp(x):
    # exp(x) for x <= 0; positive inputs are replaced by -inf so the result
    # is exp(-inf) == 0 — the output is clamped to [0, 1] and cannot overflow.
    return exp(tl.where(x <= 0, x, float("-inf")))


if not is_gather_supported:

    @triton.jit
    def gather(src, index, axis, _builder=None):
        """
        Fallback: placeholder when tl.gather is unavailable.
        """
        # NOTE(review): this returns None — a kernel that actually invokes the
        # fallback will not work; it exists only so the module imports on
        # Triton versions without tl.gather. Callers presumably guard on
        # is_gather_supported — verify.
        return None

else:
    gather = tl.gather
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel(
    A,  # [B, T, H, BT] strictly lower-triangular input blocks
    Ad,  # [B, T, H, 16] output: inverses of the 16x16 diagonal sub-blocks
    cu_seqlens,  # [N+1] (varlen mode only)
    chunk_indices,  # [num_chunks, 2] (seq_idx, chunk_in_seq) pairs (varlen only)
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Invert (I + A) for each 16x16 diagonal sub-block by forward substitution."""
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat chunk id to (sequence, chunk-within-sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Rebase pointers to this sequence/head.
    A = A + (bos * H + i_h) * BT
    Ad = Ad + (bos * H + i_h) * 16

    # Column offset of this 16x16 tile inside the BT-wide row.
    offset = (i_t * 16) % BT
    p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0))
    p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    # Keep only the strictly-lower part, negated (start of the inversion).
    b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)

    o_i = tl.arange(0, 16)
    # Forward substitution: build the inverse one row at a time.
    for i in range(1, min(16, T - i_t * 16)):
        b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        mask = o_i == i
        b_A = tl.where(mask[:, None], b_a, b_A)
    # Add the identity to complete (I + A)^{-1}.
    b_A += o_i[:, None] == o_i[None, :]
    tl.store(
        p_Ai,
        b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_32x32_inverse_kernel(
    A,  # [B, T, H, 32] original blocks (off-diagonal 16x16 tiles read)
    Ad,  # [B, T, H, 16] precomputed 16x16 diagonal inverses
    Ai,  # [B, T, H, 32] output: full 32x32 block inverse
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Assemble a 32x32 block-triangular inverse from two 16x16 diagonal
    inverses via the block formula: Ai_21 = -Ai_22 @ A_21 @ Ai_11."""
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat chunk id to (sequence, chunk-within-sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Rebase pointers to this sequence/head.
    A += (bos * H + i_h) * 32
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 32

    p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    # Off-diagonal block of the inverse of [[I1, 0], [A_21, I2]]-style blocks.
    Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee")
    tl.store(
        p_Ai_11,
        Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_22,
        Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_21,
        Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_64x64_inverse_kernel(
    A,  # [B, T, H, 64] original blocks (off-diagonal 16x16 tiles read)
    Ad,  # [B, T, H, 16] precomputed 16x16 diagonal inverses
    Ai,  # [B, T, H, 64] output: full 64x64 block inverse
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Assemble a 64x64 block lower-triangular inverse from four 16x16
    diagonal inverses, filling the six sub-diagonal tiles column by column
    and zeroing the strictly upper tiles."""
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat chunk id to (sequence, chunk-within-sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Rebase pointers to this sequence/head.
    A += (bos * H + i_h) * 64
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 64

    # Sub-diagonal tiles of the 4x4 grid of 16x16 blocks.
    p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
    A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
    A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
    A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)

    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
    Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)

    # First sub-diagonal: Ai_(i+1,i) = -Ai_(i+1,i+1) @ A_(i+1,i) @ Ai_(i,i).
    Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee")
    Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee")
    Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee")

    # Second/third sub-diagonals reuse the tiles computed above.
    Ai_31 = -tl.dot(
        Ai_33,
        tl.dot(A_31, Ai_11, input_precision="ieee") + tl.dot(A_32, Ai_21, input_precision="ieee"),
        input_precision="ieee",
    )
    Ai_42 = -tl.dot(
        Ai_44,
        tl.dot(A_42, Ai_22, input_precision="ieee") + tl.dot(A_43, Ai_32, input_precision="ieee"),
        input_precision="ieee",
    )
    Ai_41 = -tl.dot(
        Ai_44,
        tl.dot(A_41, Ai_11, input_precision="ieee")
        + tl.dot(A_42, Ai_21, input_precision="ieee")
        + tl.dot(A_43, Ai_31, input_precision="ieee"),
        input_precision="ieee",
    )

    # Write the ten lower-triangular tiles of the inverse.
    p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0))
    p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0))
    p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    tl.store(
        p_Ai_11,
        Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_22,
        Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_33,
        Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_44,
        Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_21,
        Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_31,
        Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_32,
        Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_41,
        Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_42,
        Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_43,
        Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )

    # Zero out the six strictly-upper tiles so Ai is a valid dense result.
    fill_zeros = tl.zeros((16, 16), dtype=tl.float32)
    p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0))
    p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0))
    p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0))
    p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0))
    p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0))
    p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0))
    tl.store(
        p_Ai_12,
        fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_13,
        fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_14,
        fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_23,
        fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_24,
        fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_34,
        fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
@input_guard
def solve_tril(
    A: paddle.Tensor,
    cu_seqlens: Optional[paddle.Tensor] = None,
    output_dtype=None,
) -> paddle.Tensor:
    """
    Compute the inverse of a strictly lower-triangular matrix: (I + A)^{-1}.

    Args:
        A: [B, T, H, K], where K is one of 16, 32, 64
        cu_seqlens: cumulative sequence lengths [N+1] for varlen mode
        output_dtype: output dtype (defaults to float32)

    Returns:
        (I + A)^{-1}, same shape as A
    """
    assert A.shape[-1] in [16, 32, 64], f"BT must be 16/32/64, got {A.shape[-1]}"

    B, T, H, BT = A.shape
    result_dtype = paddle.float32 if output_dtype is None else output_dtype
    varlen = cu_seqlens is not None

    def _launch_meta(block):
        # Grid metadata: (chunk_indices, number of programs along T).
        if varlen:
            idx = prepare_chunk_indices(cu_seqlens, block)
            return idx, len(idx)
        return None, triton.cdiv(T, block)

    # Stage 1: invert every 16x16 diagonal sub-block.
    # When BT == 16 these inverses ARE the final result, so they are emitted
    # directly in the requested dtype; otherwise keep fp32 for the merge.
    diag_dtype = result_dtype if BT == 16 else paddle.float32
    Ad = paddle.empty([B, T, H, 16], dtype=diag_dtype)
    chunk_indices, NT = _launch_meta(16)
    solve_tril_16x16_kernel[(NT, B * H)](
        A=A,
        Ad=Ad,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        IS_VARLEN=varlen,
        num_warps=1,
        num_stages=4,
    )
    if BT == 16:
        return Ad

    # Stage 2: merge the 16x16 inverses into the full BT x BT block inverse.
    Ai = paddle.empty([B, T, H, BT], dtype=result_dtype)
    merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
    chunk_indices, NT = _launch_meta(BT)
    merge_fn[(NT, B * H)](
        A=A,
        Ad=Ad,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        IS_VARLEN=varlen,
        num_warps=4,
        num_stages=3,
    )
    return Ai
is_gather_supported: bool = hasattr(tl, "gather")


# ============================================================
# input_guard decorator: ensure all Tensors are contiguous
# ============================================================


def input_guard(fn: Callable) -> Callable:
    """
    Ensure all input Tensors are contiguous and run on the correct CUDA device.
    Ported from SGLang, removed torch dependency, replaced with paddle.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # make all Tensors contiguous
        contiguous_args = tuple(arg.contiguous() if isinstance(arg, paddle.Tensor) else arg for arg in args)
        contiguous_kwargs = {k: (v.contiguous() if isinstance(v, paddle.Tensor) else v) for k, v in kwargs.items()}
        return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


contiguous = input_guard


# ============================================================
# tensor_cache decorator: cache results of the last N calls
# ============================================================


def tensor_cache(fn: Callable) -> Callable:
    """
    Cache results of the last cache_size calls (matched by object identity).
    Suitable for idempotent functions such as shape computations.
    """
    # Cache is held in the closure; entries are (args, kwargs, result) tuples.
    cache_entries: list = []
    cache_size = 4

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries
        for i, entry in enumerate(cache_entries):
            last_args, last_kwargs, last_result = entry
            # Hit only if every positional and keyword argument is the SAME
            # object (identity, not equality) as a cached call.
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
                if all(a is b for a, b in zip(args, last_args)) and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                ):
                    # LRU: move to end
                    cache_entries = cache_entries[:i] + cache_entries[i + 1 :] + [(args, kwargs, last_result)]
                    return last_result

        result = fn(*args, **kwargs)
        # Evict the oldest entry once the cache is full.
        if len(cache_entries) >= cache_size:
            cache_entries = cache_entries[1:]
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper
diff --git a/fastdeploy/model_executor/ops/triton_ops/fla/wy_fast.py b/fastdeploy/model_executor/ops/triton_ops/fla/wy_fast.py
new file mode 100644
index 00000000000..7dc447e37e7
--- /dev/null
+++ b/fastdeploy/model_executor/ops/triton_ops/fla/wy_fast.py
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py
# Original: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang (MIT License)
# Adapted for FastDeploy (PaddlePaddle) by PaddlePaddle Authors, 2025.
"""
WY decomposition W/U matrix recomputation Triton Kernel.

Porting notes:
  - torch.empty_like(v) → paddle.empty_like(v)
  - k.new_empty(B, T, H, K) → paddle.empty([B, T, H, K], dtype=k.dtype)
  - Triton kernel code is unchanged
"""

from typing import Optional, Tuple

import paddle
import triton
import triton.language as tl

from fastdeploy.model_executor.ops.triton_ops.fla.index import prepare_chunk_indices

# ============================================================
# Triton Kernel (unchanged from SGLang)
# ============================================================


@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    g,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Grid: axis 0 = time-chunk index, axis 1 = (batch, head) pair.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Varlen mode: chunk_indices maps the flat chunk id to (sequence,
        # in-sequence chunk); T becomes this sequence's length.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))

    # U = A @ (beta * V), tiled over the value dimension.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    # W = A @ (beta * exp(g) * K), tiled over the key dimension.  K heads may
    # be grouped (Hg <= H): i_h // (H // Hg) picks the shared key head.
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A, b_kb)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


# ============================================================
# Python Wrapper (paddle edition)
# ============================================================


def recompute_w_u_fwd(
    k: paddle.Tensor,
    v: paddle.Tensor,
    beta: paddle.Tensor,
    g_cumsum: paddle.Tensor,
    A: paddle.Tensor,
    cu_seqlens: Optional[paddle.Tensor],
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """
    Recompute W and U matrices from the WY decomposition.

    Args:
        k: [B, T, Hg, K]
        v: [B, T, H, V]
        beta: [B, T, H]
        g_cumsum: [B, T, H] cumsum of log decay
        A: [B, T, H, BT] lower-triangular matrix inverse (from solve_tril)
        cu_seqlens: cumulative sequence lengths [N+1] for varlen mode

    Returns:
        w: [B, T, H, K]
        u: [B, T, H, V] — updated value vectors (new value tensor)
    """
    B, T, Hg, K = k.shape
    V = v.shape[-1]
    H = v.shape[-2]
    BT = A.shape[-1]

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BK = 64
    BV = 64
    u = paddle.empty_like(v)
    w = paddle.empty([B, T, H, K], dtype=k.dtype)
    recompute_w_u_fwd_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        IS_VARLEN=cu_seqlens is not None,
        num_warps=4,
        num_stages=3,
    )
    return w, u


# Alias for SGLang API compatibility
fwd_recompute_w_u = recompute_w_u_fwd
diff --git a/tests/model_executor/ops/triton_ops/test_gdn_kernels.py b/tests/model_executor/ops/triton_ops/test_gdn_kernels.py
new file mode 100644
index 00000000000..63f6f963ebc
--- /dev/null
+++ b/tests/model_executor/ops/triton_ops/test_gdn_kernels.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+""" +GDN (Gated Delta Network) Triton Kernel 单元测试。 + +测试覆盖: + 1. TestFusedRecurrentGDN — fused_recurrent_gated_delta_rule (Decode 路径) + 2. TestChunkGDN — chunk_gated_delta_rule (Prefill 路径) + 3. TestCausalConv1dUpdate — causal_conv1d_update (Decode conv) + 4. TestCausalConv1dFn — causal_conv1d_fn (Prefill conv, varlen) + 5. TestFusedGDNGating — fused_gdn_gating (GDN 门控 Triton kernel) + +参考基准: + - GDN: Transformers 动态图中的 torch_recurrent_gated_delta_rule / + torch_chunk_gated_delta_rule(Pure PyTorch 实现,与 FLA 论文对齐) + 移植为纯 paddle 实现后作为 baseline。 + - Conv1d: torch_causal_conv1d_update(来自 Transformers 的 F.conv1d 参考) + 移植为纯 paddle 实现后作为 baseline。 + +运行方法: + cd /root/.../FastDeploy + python -m pytest tests/model_executor/ops/triton_ops/test_gdn_kernels.py -v + # 或 + python tests/model_executor/ops/triton_ops/test_gdn_kernels.py +""" + +import unittest + +import numpy as np +import paddle +import paddle.nn.functional as F + +# ============================================================ +# Pure-Paddle Reference Implementations (ported from Transformers) +# ============================================================ + + +def _l2norm_paddle(x: paddle.Tensor, dim: int = -1, eps: float = 1e-6) -> paddle.Tensor: + """L2 norm, aligned with FLA's l2norm_fwd.""" + inv_norm = paddle.rsqrt((x * x).sum(axis=dim, keepdim=True) + eps) + return x * inv_norm + + +def paddle_causal_conv1d_update_ref( + hidden_states: paddle.Tensor, + conv_state: paddle.Tensor, + weight: paddle.Tensor, + bias: paddle.Tensor = None, + activation: str = "silu", +) -> paddle.Tensor: + """ + Pure-Paddle reference for single-token causal conv1d update. 
+ + Args: + hidden_states: [batch, dim, 1] (unsqueezed) + conv_state: [batch, dim, state_len] (single-sequence, NOT pool) + weight: [dim, width] + bias: [dim,] or None + activation: "silu" or None + + Returns: + out: [batch, dim] + (conv_state is updated in-place) + """ + _, hidden_size, seq_len = hidden_states.shape + state_len = conv_state.shape[-1] + hidden_states_new = paddle.concat([conv_state, hidden_states], axis=-1).cast(weight.dtype) + # update conv_state in-place (shift left) + conv_state_new = hidden_states_new[:, :, -state_len:] + for i in range(conv_state.shape[0]): + conv_state[i] = conv_state_new[i] + # grouped conv1d: weight [dim, width] → [dim, 1, width] + w = weight.unsqueeze(1) # [dim, 1, width] + out = F.conv1d(hidden_states_new, w, bias, padding=0, groups=hidden_size) + if activation in ["silu", "swish"]: + out = F.silu(out) + out = out[:, :, -seq_len:] # keep last seq_len output + return out.squeeze(-1) # [batch, dim] + + +def paddle_recurrent_gated_delta_rule_ref( + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + initial_state: paddle.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple: + """ + Pure-Paddle reference for fused recurrent GDN (Decode 路径). 

    Args:
        query, key: [B, T, H, K]
        value: [B, T, H, V]
        g: [B, T, H] log decay (negative)
        beta: [B, T, H] write gate
        initial_state: [B, H, K, V] or None
        output_final_state: bool
        use_qk_l2norm_in_kernel: bool

    Returns:
        out: [B, T, H, V]
        last_state: [B, H, K, V] if output_final_state else None
    """
    if use_qk_l2norm_in_kernel:
        query = _l2norm_paddle(query, dim=-1)
        key = _l2norm_paddle(key, dim=-1)

    # Transpose to [B, H, T, D] and cast to float32
    query, key, value, beta, g = [
        x.transpose([0, 2, 1, 3]).cast(paddle.float32) if x.ndim == 4 else x.transpose([0, 2, 1]).cast(paddle.float32)
        for x in (query, key, value, beta, g)
    ]

    B, H, T, K = key.shape
    V = value.shape[-1]
    scale = 1.0 / (K**0.5)
    query = query * scale

    out = paddle.zeros([B, H, T, V], dtype=paddle.float32)
    last_state = (
        paddle.zeros([B, H, K, V], dtype=paddle.float32)
        if initial_state is None
        else initial_state.cast(paddle.float32)
    )

    # Token-by-token delta-rule recurrence: decay state, read memory along K,
    # write the gated correction, then project onto the query.
    for i in range(T):
        q_t = query[:, :, i]  # [B, H, K]
        k_t = key[:, :, i]  # [B, H, K]
        v_t = value[:, :, i]  # [B, H, V]
        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)  # [B, H, 1, 1]
        beta_t = beta[:, :, i].unsqueeze(-1)  # [B, H, 1]

        last_state = last_state * g_t
        kv_mem = (last_state * k_t.unsqueeze(-1)).sum(axis=-2)  # [B, H, V]
        delta = (v_t - kv_mem) * beta_t  # [B, H, V]
        last_state = last_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        out[:, :, i] = (last_state * q_t.unsqueeze(-1)).sum(axis=-2)

    if not output_final_state:
        last_state = None

    out = out.transpose([0, 2, 1, 3])  # [B, T, H, V]
    return out, last_state


def paddle_chunk_gated_delta_rule_ref(
    query: paddle.Tensor,
    key: paddle.Tensor,
    value: paddle.Tensor,
    g: paddle.Tensor,
    beta: paddle.Tensor,
    chunk_size: int = 64,
    initial_state: paddle.Tensor = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple:
    """
    Pure-Paddle reference for chunk GDN (prefill path).

    Closely mirrors Transformers' torch_chunk_gated_delta_rule.
    """
    if use_qk_l2norm_in_kernel:
        query = _l2norm_paddle(query, dim=-1)
        key = _l2norm_paddle(key, dim=-1)

    initial_dtype = query.dtype
    query, key, value, beta, g = [
        x.transpose([0, 2, 1, 3]).cast(paddle.float32) if x.ndim == 4 else x.transpose([0, 2, 1]).cast(paddle.float32)
        for x in (query, key, value, beta, g)
    ]

    B, H, T, K = key.shape
    V = value.shape[-1]
    # Right-pad the time axis so T divides evenly into chunks.
    pad_size = (chunk_size - T % chunk_size) % chunk_size
    query = F.pad(query, [0, 0, 0, pad_size])
    key = F.pad(key, [0, 0, 0, pad_size])
    value = F.pad(value, [0, 0, 0, pad_size])
    beta = F.pad(beta, [0, pad_size])
    g = F.pad(g, [0, pad_size])
    TT = T + pad_size

    scale = 1.0 / (K**0.5)
    query = query * scale

    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)

    # reshape to chunks
    NC = TT // chunk_size
    query = query.reshape([B, H, NC, chunk_size, K])
    key = key.reshape([B, H, NC, chunk_size, K])
    value = value.reshape([B, H, NC, chunk_size, V])
    k_beta = k_beta.reshape([B, H, NC, chunk_size, K])
    v_beta = v_beta.reshape([B, H, NC, chunk_size, V])
    g = g.reshape([B, H, NC, chunk_size])

    mask = paddle.triu(paddle.ones([chunk_size, chunk_size], dtype=paddle.bool), diagonal=0)

    # In-chunk cumulative log decay; decay_mask[i, j] = exp(g_i - g_j) on the
    # lower triangle.
    g = g.cumsum(axis=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().cast(paddle.float32)).tril()

    attn = -((k_beta @ key.transpose([0, 1, 2, 4, 3])) * decay_mask).masked_fill(mask, 0)
    # Forward substitution: row-by-row build of (I + A)^{-1} - I.
    for i in range(1, chunk_size):
        row = attn[:, :, :, i, :i].clone()
        sub = attn[:, :, :, :i, :i].clone()
        attn[:, :, :, i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)

    eye = paddle.eye(chunk_size, dtype=attn.dtype)
    attn = attn + eye

    value_new = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))

    last_state = (
        paddle.zeros([B, H, K, V], dtype=value_new.dtype)
        if initial_state is None
        else initial_state.cast(value_new.dtype)
    )
    core_attn_out = paddle.zeros_like(value_new)
    mask2 = paddle.triu(paddle.ones([chunk_size, chunk_size], dtype=paddle.bool), diagonal=1)

    # Chunk-level scan: intra-chunk attention plus inter-chunk state carry.
    for i in range(NC):
        q_i = query[:, :, i]  # [B, H, cs, K]
        k_i = key[:, :, i]
        v_i = value_new[:, :, i]
        attn_i = (q_i @ k_i.transpose([0, 1, 3, 2]) * decay_mask[:, :, i]).masked_fill(mask2, 0)
        v_prime = k_cumdecay[:, :, i] @ last_state
        v_new_i = v_i - v_prime
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_state
        core_attn_out[:, :, i] = attn_inter + attn_i @ v_new_i
        last_state = (
            last_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp().unsqueeze(-1)).transpose([0, 1, 3, 2]) @ v_new_i
        )

    if not output_final_state:
        last_state = None

    core_attn_out = core_attn_out.reshape([B, H, TT, V])
    core_attn_out = core_attn_out[:, :, :T]
    core_attn_out = core_attn_out.transpose([0, 2, 1, 3]).cast(initial_dtype)
    return core_attn_out, last_state


# ============================================================
# Test Cases
# ============================================================


class TestFusedRecurrentGDN(unittest.TestCase):
    """Tests for fused_recurrent_gated_delta_rule (decode-path SSM kernel)."""

    def setUp(self):
        paddle.seed(42)
        self.dtype = paddle.bfloat16
        self.B, self.T = 2, 8
        self.H, self.K, self.V = 4, 64, 64

    def _make_inputs(self, T=None):
        T = T or self.T
        B, H, K, V = self.B, self.H, self.K, self.V
        q = paddle.randn([B, T, H, K], dtype=paddle.float32).cast(self.dtype)
        k = paddle.randn([B, T, H, K], dtype=paddle.float32).cast(self.dtype)
        v = paddle.randn([B, T, H, V], dtype=paddle.float32).cast(self.dtype)
        # g: negative log decay
        g = -F.softplus(paddle.randn([B, T, H], dtype=paddle.float32)).cast(self.dtype)
        beta = paddle.sigmoid(paddle.randn([B, T, H], dtype=paddle.float32)).cast(self.dtype)
        return q, k, v, g, beta

    def test_fused_recurrent_no_state(self):
        """Without an initial state, kernel output should match the baseline."""
        from
fastdeploy.model_executor.ops.triton_ops.fla import (
            fused_recurrent_gated_delta_rule,
        )

        q, k, v, g, beta = self._make_inputs()

        ref_out, _ = paddle_recurrent_gated_delta_rule_ref(
            q.cast(paddle.float32),
            k.cast(paddle.float32),
            v.cast(paddle.float32),
            g.cast(paddle.float32),
            beta.cast(paddle.float32),
            output_final_state=False,
            use_qk_l2norm_in_kernel=False,
        )

        kernel_out, _ = fused_recurrent_gated_delta_rule(
            q,
            k,
            v,
            g,
            beta,
            output_final_state=False,
            use_qk_l2norm_in_kernel=False,
        )

        np.testing.assert_allclose(
            kernel_out.cast(paddle.float32).numpy(),
            ref_out.numpy(),
            rtol=1e-2,
            atol=1e-2,
            err_msg="fused_recurrent_gated_delta_rule (no state) mismatch",
        )

    def test_fused_recurrent_with_l2norm(self):
        """With L2 norm enabled, kernel output should match the baseline."""
        from fastdeploy.model_executor.ops.triton_ops.fla import (
            fused_recurrent_gated_delta_rule,
        )

        q, k, v, g, beta = self._make_inputs()

        ref_out, _ = paddle_recurrent_gated_delta_rule_ref(
            q.cast(paddle.float32),
            k.cast(paddle.float32),
            v.cast(paddle.float32),
            g.cast(paddle.float32),
            beta.cast(paddle.float32),
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
        )

        kernel_out, _ = fused_recurrent_gated_delta_rule(
            q,
            k,
            v,
            g,
            beta,
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
        )

        np.testing.assert_allclose(
            kernel_out.cast(paddle.float32).numpy(),
            ref_out.numpy(),
            rtol=1e-3,
            atol=1e-3,
            err_msg="fused_recurrent_gated_delta_rule (l2norm) mismatch",
        )

    def test_fused_recurrent_output_final_state(self):
        """With output_final_state=True, verify the final state's shape and values."""
        from fastdeploy.model_executor.ops.triton_ops.fla import (
            fused_recurrent_gated_delta_rule,
        )

        q, k, v, g, beta = self._make_inputs()

        ref_out, ref_state = paddle_recurrent_gated_delta_rule_ref(
            q.cast(paddle.float32),
            k.cast(paddle.float32),
            v.cast(paddle.float32),
            g.cast(paddle.float32),
            beta.cast(paddle.float32),
            output_final_state=True,
        )

        kernel_out, kernel_state = fused_recurrent_gated_delta_rule(
            q,
            k,
            v,
            g,
            beta,
            output_final_state=True,
        )

        self.assertIsNotNone(kernel_state)
        self.assertEqual(kernel_state.shape, [self.B, self.H, self.K, self.V])

        np.testing.assert_allclose(
            kernel_state.cast(paddle.float32).numpy(),
            ref_state.numpy(),
            rtol=1e-3,
            atol=1e-3,
            err_msg="fused_recurrent final state mismatch",
        )

    def test_fused_recurrent_with_initial_state(self):
        """With an initial SSM state, verify state propagation is correct."""
        from fastdeploy.model_executor.ops.triton_ops.fla import (
            fused_recurrent_gated_delta_rule,
        )

        q, k, v, g, beta = self._make_inputs()
        init_state = paddle.randn([self.B, self.H, self.K, self.V], dtype=paddle.float32)

        ref_out, ref_state = paddle_recurrent_gated_delta_rule_ref(
            q.cast(paddle.float32),
            k.cast(paddle.float32),
            v.cast(paddle.float32),
            g.cast(paddle.float32),
            beta.cast(paddle.float32),
            initial_state=init_state.clone(),
            output_final_state=True,
        )

        kernel_out, kernel_state = fused_recurrent_gated_delta_rule(
            q,
            k,
            v,
            g,
            beta,
            initial_state=init_state.clone(),
            output_final_state=True,
        )
        np.testing.assert_allclose(
            kernel_out.cast(paddle.float32).numpy(),
            ref_out.numpy(),
            rtol=1e-2,
            atol=1e-2,
            err_msg="fused_recurrent (with init state) output mismatch",
        )
        np.testing.assert_allclose(
            kernel_state.cast(paddle.float32).numpy(),
            ref_state.numpy(),
            rtol=1e-2,
            atol=1e-2,
            err_msg="fused_recurrent (with init state) final state mismatch",
        )


class TestChunkGDN(unittest.TestCase):
    """Tests for chunk_gated_delta_rule (prefill-path SSM kernel)."""

    def setUp(self):
        paddle.seed(42)
        self.dtype = paddle.bfloat16
        self.B = 1
        self.H_k, self.H_v = 4, 4  # num_k_heads, num_v_heads (no GVA for simplicity)
        self.K, self.V = 64, 64
        self.T = 128  # must be multiple of chunk_size=64

    def _make_inputs(self, T=None):
        T = T or self.T
        B, Hk, Hv, K, V = self.B, self.H_k,
self.H_v, self.K, self.V
        q = paddle.randn([B, T, Hk, K], dtype=paddle.float32).cast(self.dtype)
        k = paddle.randn([B, T, Hk, K], dtype=paddle.float32).cast(self.dtype)
        v = paddle.randn([B, T, Hv, V], dtype=paddle.float32).cast(self.dtype)
        g = -F.softplus(paddle.randn([B, T, Hk], dtype=paddle.float32)).cast(self.dtype)
        beta = paddle.sigmoid(paddle.randn([B, T, Hk], dtype=paddle.float32)).cast(self.dtype)
        return q, k, v, g, beta

    def test_chunk_gdn_no_state(self):
        """Without an initial state, chunk kernel output should match the baseline
        (l2norm is used to keep the numerics stable)."""
        from fastdeploy.model_executor.ops.triton_ops.fla import chunk_gated_delta_rule

        q, k, v, g, beta = self._make_inputs()

        ref_out, _ = paddle_chunk_gated_delta_rule_ref(
            q.cast(paddle.float32),
            k.cast(paddle.float32),
            v.cast(paddle.float32),
            g.cast(paddle.float32),
            beta.cast(paddle.float32),
            chunk_size=64,
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,  # l2norm keeps values from overflowing bf16
        )

        kernel_out, _ = chunk_gated_delta_rule(
            q,
            k,
            v,
            g,
            beta,
            use_qk_l2norm_in_kernel=True,
        )

        np.testing.assert_allclose(
            kernel_out.cast(paddle.float32).numpy(),
            ref_out.numpy(),
            rtol=2e-2,
            atol=2e-2,
            err_msg="chunk_gated_delta_rule (no state) mismatch",
        )

    def test_chunk_gdn_with_l2norm(self):
        """With L2 norm, chunk kernel output should match the baseline."""
        from fastdeploy.model_executor.ops.triton_ops.fla import chunk_gated_delta_rule

        q, k, v, g, beta = self._make_inputs()

        ref_out, _ = paddle_chunk_gated_delta_rule_ref(
            q.cast(paddle.float32),
            k.cast(paddle.float32),
            v.cast(paddle.float32),
            g.cast(paddle.float32),
            beta.cast(paddle.float32),
            chunk_size=64,
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
        )

        kernel_out, _ = chunk_gated_delta_rule(
            q,
            k,
            v,
            g,
            beta,
            use_qk_l2norm_in_kernel=True,
        )

        np.testing.assert_allclose(
            kernel_out.cast(paddle.float32).numpy(),
            ref_out.numpy(),
            rtol=2e-2,
            atol=2e-2,
            err_msg="chunk_gated_delta_rule (l2norm) mismatch",
        )

    def test_chunk_recurrent_consistency(self):
        """chunk and recurrent should produce close outputs on identical inputs
        (numerical-equivalence check)."""
        from fastdeploy.model_executor.ops.triton_ops.fla import (
            chunk_gated_delta_rule,
            fused_recurrent_gated_delta_rule,
        )

        # Use short T=64 for recurrent to be affordable
        q, k, v, g, beta = self._make_inputs(T=64)

        chunk_out, _ = chunk_gated_delta_rule(
            q,
            k,
            v,
            g,
            beta,
            use_qk_l2norm_in_kernel=True,
        )
        recurrent_out, _ = fused_recurrent_gated_delta_rule(
            q,
            k,
            v,
            g,
            beta,
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
        )

        np.testing.assert_allclose(
            chunk_out.cast(paddle.float32).numpy(),
            recurrent_out.cast(paddle.float32).numpy(),
            rtol=2e-2,
            atol=2e-2,
            err_msg="chunk vs recurrent output mismatch",
        )


class TestCausalConv1dUpdate(unittest.TestCase):
    """Tests for causal_conv1d_update (decode single-token conv)."""

    def setUp(self):
        paddle.seed(42)
        self.dtype = paddle.bfloat16
        self.batch = 4
        self.dim = 512  # conv_dim = key_dim * 2 + value_dim
        self.kernel_width = 4  # conv_kernel_size
        self.state_len = self.kernel_width - 1  # = 3

    def _make_inputs(self):
        batch, dim, width = self.batch, self.dim, self.kernel_width
        state_len = self.state_len
        x = paddle.randn([batch, dim], dtype=paddle.float32).cast(self.dtype)
        # conv_state pool: [max_seqs, dim, state_len]
        max_seqs = batch + 2
        conv_pool = paddle.randn([max_seqs, dim, state_len], dtype=paddle.float32).cast(self.dtype)
        weight = paddle.randn([dim, width], dtype=paddle.float32).cast(self.dtype)
        bias = paddle.randn([dim], dtype=paddle.float32).cast(self.dtype)
        # slot ids: each batch item maps to a pool slot
        slot_ids = paddle.arange(batch, dtype=paddle.int32)
        return x, conv_pool, weight, bias, slot_ids

    def _paddle_ref(self, x, conv_state_per_seq, weight, bias, activation):
        """Pure-Paddle reference (per-sequence, no pool)."""
        batch, dim = x.shape
        state_len = conv_state_per_seq.shape[-1]

        outs = []
        for i in
range(batch):
            h = conv_state_per_seq[i : i + 1]  # [1, dim, state_len]
            xi = x[i : i + 1].unsqueeze(-1)  # [1, dim, 1]
            # concat and update
            combined = paddle.concat([h, xi], axis=-1)  # [1, dim, state_len+1]
            h_new = combined[:, :, -state_len:]
            conv_state_per_seq[i] = h_new[0]
            # conv1d
            w = weight.unsqueeze(1)  # [dim, 1, width]
            out = F.conv1d(combined, w, bias, padding=0, groups=dim)
            if activation in ["silu", "swish"]:
                out = F.silu(out)
            outs.append(out[:, :, -1])  # last token
        return paddle.concat(outs, axis=0)  # [batch, dim]

    def test_causal_conv1d_update_no_bias(self):
        """Without bias, causal_conv1d_update matches the pure-Paddle baseline."""
        from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import (
            causal_conv1d_update,
        )

        x, conv_pool, weight, bias, slot_ids = self._make_inputs()

        # Extract per-seq states for reference (using slot_ids)
        ref_conv_state = conv_pool[slot_ids].clone()  # [batch, dim, state_len]

        ref_out = self._paddle_ref(
            x.cast(paddle.float32),
            ref_conv_state.cast(paddle.float32),
            weight.cast(paddle.float32),
            None,
            activation="silu",
        )

        pool_for_kernel = conv_pool.clone()
        kernel_out = causal_conv1d_update(
            x,
            pool_for_kernel,
            weight,
            bias=None,
            activation="silu",
            conv_state_indices=slot_ids,
        )

        np.testing.assert_allclose(
            kernel_out.cast(paddle.float32).numpy(),
            ref_out.numpy(),
            rtol=1e-2,
            atol=1e-2,
            err_msg="causal_conv1d_update (no bias) mismatch",
        )

    def test_causal_conv1d_update_with_bias(self):
        """With bias, causal_conv1d_update matches the pure-Paddle baseline."""
        from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import (
            causal_conv1d_update,
        )

        x, conv_pool, weight, bias, slot_ids = self._make_inputs()

        ref_conv_state = conv_pool[slot_ids].clone()
        ref_out = self._paddle_ref(
            x.cast(paddle.float32),
            ref_conv_state.cast(paddle.float32),
            weight.cast(paddle.float32),
            bias.cast(paddle.float32),
            activation="silu",
        )

        pool_for_kernel = conv_pool.clone()
        kernel_out = causal_conv1d_update(
            x,
            pool_for_kernel,
            weight,
            bias=bias,
            activation="silu",
            conv_state_indices=slot_ids,
        )

        np.testing.assert_allclose(
            kernel_out.cast(paddle.float32).numpy(),
            ref_out.numpy(),
            rtol=1e-2,
            atol=1e-2,
            err_msg="causal_conv1d_update (with bias) mismatch",
        )

    def test_causal_conv1d_update_state_inplace(self):
        """Verify the conv_state pool is updated in place correctly (sliding-window shift)."""
        from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import (
            causal_conv1d_update,
        )

        x, conv_pool, weight, bias, slot_ids = self._make_inputs()

        ref_conv_state = conv_pool[slot_ids].clone()

        # Build expected new states via reference
        for i in range(self.batch):
            h = ref_conv_state[i : i + 1]
            xi = x[i : i + 1].cast(paddle.float32).unsqueeze(-1)
            combined = paddle.concat([h.cast(paddle.float32), xi], axis=-1)
            ref_conv_state[i] = combined[:, :, -self.state_len :].cast(self.dtype)[0]

        pool_for_kernel = conv_pool.clone()
        _ = causal_conv1d_update(
            x,
            pool_for_kernel,
            weight,
            activation="silu",
            conv_state_indices=slot_ids,
        )

        # Check pool slots updated correctly
        for i in range(self.batch):
            slot = slot_ids[i].item()
            np.testing.assert_allclose(
                pool_for_kernel[slot].cast(paddle.float32).numpy(),
                ref_conv_state[i].cast(paddle.float32).numpy(),
                rtol=1e-3,
                atol=1e-3,
                err_msg=f"conv_state pool slot {slot} not updated correctly",
            )


class TestCausalConv1dFn(unittest.TestCase):
    """Tests for causal_conv1d_fn (prefill varlen conv)."""

    def setUp(self):
        paddle.seed(42)
        self.dtype = paddle.bfloat16
        self.dim = 256
        self.kernel_width = 4
        self.state_len = self.kernel_width - 1

    def _make_varlen_inputs(self, seq_lens):
        """
        Build varlen inputs.

        Returns:
            x: [dim, total_tokens] (channel-last layout)
            weight: [dim, kernel_width]
            bias: [dim,]
            conv_pool: [max_seqs, dim, state_len]
            slot_ids: [N]
            has_initial_state: [N] bool
            query_start_loc: [N+1]
            seq_lens_cpu: List[int]
        """
        dim,
width, state_len = self.dim, self.kernel_width, self.state_len + N = len(seq_lens) + total = sum(seq_lens) + # channel-last: (dim, total_tokens) + x = paddle.randn([dim, total], dtype=paddle.float32).cast(self.dtype) + weight = paddle.randn([dim, width], dtype=paddle.float32).cast(self.dtype) + bias = paddle.randn([dim], dtype=paddle.float32).cast(self.dtype) + max_seqs = N + 2 + conv_pool = paddle.zeros([max_seqs, dim, state_len], dtype=self.dtype) + slot_ids = paddle.arange(N, dtype=paddle.int32) + has_initial_state = paddle.zeros([N], dtype=paddle.bool) + offsets = [0] + for l in seq_lens: + offsets.append(offsets[-1] + l) + query_start_loc = paddle.to_tensor(offsets, dtype=paddle.int32) + return x, weight, bias, conv_pool, slot_ids, has_initial_state, query_start_loc, seq_lens + + def _paddle_ref_prefill(self, x, weight, bias, seq_lens, activation): + """ + Pure-Paddle reference: process each sequence independently. + + x: [dim, total_tokens] (channel-last) + Returns: [dim, total_tokens] + """ + dim, width = weight.shape + state_len = width - 1 + out_parts = [] + offset = 0 + for seqlen in seq_lens: + x_seq = x[:, offset : offset + seqlen] # [dim, seqlen] + # pad left with zeros (no initial state) + padded = F.pad(x_seq.unsqueeze(0), [state_len, 0]) # [1, dim, seqlen+state_len] + w = weight.unsqueeze(1) # [dim, 1, width] + out = F.conv1d( + padded.cast(paddle.float32), + w.cast(paddle.float32), + bias.cast(paddle.float32) if bias is not None else None, + padding=0, + groups=dim, + ) + if activation in ["silu", "swish"]: + out = F.silu(out) + out_parts.append(out.squeeze(0)) # [dim, seqlen] + offset += seqlen + return paddle.concat(out_parts, axis=-1) # [dim, total_tokens] + + def test_causal_conv1d_fn_no_initial_state(self): + """无初始状态(全零)的 prefill varlen conv。""" + from fastdeploy.model_executor.ops.triton_ops.causal_conv1d import ( + causal_conv1d_fn, + ) + + seq_lens = [16, 32, 8] + x, weight, bias, conv_pool, slot_ids, has_init, query_start_loc, 
seq_lens_cpu = self._make_varlen_inputs( + seq_lens + ) + + ref_out = self._paddle_ref_prefill( + x.cast(paddle.float32), + weight.cast(paddle.float32), + bias.cast(paddle.float32), + seq_lens, + activation="silu", + ) + + kernel_out = causal_conv1d_fn( + x, + weight, + bias, + conv_pool, + query_start_loc, + seq_lens_cpu, + cache_indices=slot_ids, + has_initial_state=has_init, + activation="silu", + ) + + np.testing.assert_allclose( + kernel_out.cast(paddle.float32).numpy(), + ref_out.numpy(), + rtol=2e-2, + atol=5e-2, + err_msg="causal_conv1d_fn (no initial state) mismatch", + ) + + +class TestFusedGDNGating(unittest.TestCase): + """测试 fused_gdn_gating Triton kernel (GDN 门控融合算子).""" + + def setUp(self): + paddle.seed(42) + self.dtype = paddle.bfloat16 + + def _paddle_ref_gating(self, A_log, a, b, dt_bias): + """Pure-Paddle reference for GDN gating.""" + x = a.cast(paddle.float32) + dt_bias.cast(paddle.float32) + softplus_x = F.softplus(x) + g = -paddle.exp(A_log.cast(paddle.float32)) * softplus_x + beta = F.sigmoid(b.cast(paddle.float32)) + return g, beta + + def test_fused_gdn_gating_basic(self): + """基本功能: Triton kernel 输出应与纯 Paddle 基准一致。""" + from fastdeploy.model_executor.ops.triton_ops.fla import fused_gdn_gating + + num_tokens, num_heads = 32, 16 + A_log = -paddle.abs(paddle.randn([num_heads], dtype=paddle.float32)).cast(self.dtype) + a = paddle.randn([num_tokens, num_heads], dtype=paddle.float32).cast(self.dtype) + b = paddle.randn([num_tokens, num_heads], dtype=paddle.float32).cast(self.dtype) + dt_bias = paddle.randn([num_heads], dtype=paddle.float32).cast(self.dtype) + + ref_g, ref_beta = self._paddle_ref_gating(A_log, a, b, dt_bias) + kernel_g, kernel_beta = fused_gdn_gating(A_log, a, b, dt_bias) + + np.testing.assert_allclose( + kernel_g.numpy(), + ref_g.numpy(), + rtol=1e-3, + atol=1e-3, + err_msg="fused_gdn_gating g mismatch", + ) + # beta: Triton kernel stores to b.dtype (bf16) then reads back as fp32, + # so bf16 rounding is expected. 
        # Use relaxed tolerance.
        np.testing.assert_allclose(
            kernel_beta.numpy(),
            ref_beta.numpy(),
            rtol=5e-3,
            atol=5e-3,
            err_msg="fused_gdn_gating beta mismatch",
        )

    def test_fused_gdn_gating_output_shape(self):
        """Output shapes should be [num_tokens, num_heads]."""
        from fastdeploy.model_executor.ops.triton_ops.fla import fused_gdn_gating

        num_tokens, num_heads = 64, 8
        A_log = paddle.randn([num_heads], dtype=self.dtype)
        a = paddle.randn([num_tokens, num_heads], dtype=self.dtype)
        b = paddle.randn([num_tokens, num_heads], dtype=self.dtype)
        dt_bias = paddle.randn([num_heads], dtype=self.dtype)

        g, beta = fused_gdn_gating(A_log, a, b, dt_bias)
        self.assertEqual(g.shape, [num_tokens, num_heads])
        self.assertEqual(beta.shape, [num_tokens, num_heads])
        # Outputs are expected in float32 even though inputs are bf16.
        self.assertEqual(g.dtype, paddle.float32)
        self.assertEqual(beta.dtype, paddle.float32)

    def test_fused_gdn_gating_single_token(self):
        """Single-token (decode) case: num_tokens = 1."""
        from fastdeploy.model_executor.ops.triton_ops.fla import fused_gdn_gating

        num_tokens, num_heads = 1, 16
        # A_log forced negative so exp(A_log) <= 1 (decay-style gate).
        A_log = -paddle.abs(paddle.randn([num_heads], dtype=paddle.float32)).cast(self.dtype)
        a = paddle.randn([num_tokens, num_heads], dtype=paddle.float32).cast(self.dtype)
        b = paddle.randn([num_tokens, num_heads], dtype=paddle.float32).cast(self.dtype)
        dt_bias = paddle.randn([num_heads], dtype=paddle.float32).cast(self.dtype)

        ref_g, ref_beta = self._paddle_ref_gating(A_log, a, b, dt_bias)
        kernel_g, kernel_beta = fused_gdn_gating(A_log, a, b, dt_bias)

        np.testing.assert_allclose(
            kernel_g.numpy(),
            ref_g.numpy(),
            rtol=1e-3,
            atol=1e-3,
            err_msg="fused_gdn_gating (single token) g mismatch",
        )
        np.testing.assert_allclose(
            kernel_beta.numpy(),
            ref_beta.numpy(),
            rtol=5e-3,
            atol=5e-3,
            err_msg="fused_gdn_gating (single token) beta mismatch",
        )


# ============================================================
# Entry Point
# ============================================================

if __name__ == "__main__":
    unittest.main()