diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index fca9b80fcf..f02e2294cc 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -152,6 +152,35 @@ def experts( per_expert_scale=self.per_expert_scale, ) + def fused_shared_experts( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + shared_expert_weight: torch.Tensor, + is_prefill: Optional[bool] = None, + ) -> torch.Tensor: + return self.fuse_moe_impl.fused_shared_experts( + input_tensor=input_tensor, + router_logits=router_logits, + w13=self.w13, + w2=self.w2, + correction_bias=self.e_score_correction_bias, + scoring_func=self.scoring_func, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + is_prefill=is_prefill, + per_expert_scale=self.per_expert_scale, + shared_expert_weight=shared_expert_weight, + ) + def low_latency_dispatch( self, hidden_states: torch.Tensor, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index a0d30547a3..2fa715ef65 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -62,27 +62,6 @@ def _select_experts( topk_weights.mul_(self.routed_scaling_factor) if per_expert_scale is not None: topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts, - end=self.n_routed_experts + self.num_fused_shared_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) return topk_weights, topk_ids def _fused_experts( @@ -111,6 +90,7 @@ def _fused_experts( use_fp8_w8a8=use_fp8_w8a8, w1_scale=w13_scale, w2_scale=w2_scale, + shared_expert_id=self.n_routed_experts if self.num_fused_shared_experts > 0 else -1, ) return input_tensor @@ -152,3 +132,56 @@ def __call__( is_prefill=is_prefill, ) return output + + def fused_shared_experts( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + correction_bias: Optional[torch.Tensor], + scoring_func: str, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + shared_expert_weight: torch.Tensor, + is_prefill: Optional[bool] = None, + per_expert_scale: Optional[torch.Tensor] = None, + ): + assert ( + type(self) is FuseMoeTriton + ), "fused shared expert as MoE is only supported by the Triton fused MoE implementation" + topk_weights, topk_ids = self._select_experts( + input_tensor=input_tensor, + router_logits=router_logits, + correction_bias=correction_bias, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=scoring_func, + per_expert_scale=per_expert_scale, + ) + w13_weight, w13_scale = w13.weight, w13.weight_scale + w2_weight, w2_scale = w2.weight, w2.weight_scale + use_fp8_w8a8 = w13_weight.dtype == torch.float8_e4m3fn + + from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import fused_shared_experts + + fused_shared_experts( + hidden_states=input_tensor, + w1=w13_weight, + w2=w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + shared_expert_weight=shared_expert_weight, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w13_scale, + w2_scale=w2_scale, + shared_expert_id=self.n_routed_experts, + ) + return input_tensor diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 76acea25a7..c18606e31c 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -217,22 +217,51 @@ def moe_align1( def moe_align_fused_kernel( topk_ids_ptr, # [token_num, topk] topk_weights_ptr, # [token_num, topk] + shared_expert_weight_ptr, # [token_num, 1] expert_to_token_index_ptr, # [expert_num, token_num * topk] expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] token_num, + routed_topk_num: tl.constexpr, + expert_num: tl.constexpr, topk_num: tl.constexpr, + shared_expert_id: tl.constexpr, BLOCK_SIZE: tl.constexpr, + ZERO_EXPERT_TOKEN_NUM: tl.constexpr, + BLOCK_EXPERT: tl.constexpr, + HAS_SHARED_EXPERT_WEIGHT: tl.constexpr, ): token_block = tl.program_id(0) + if ZERO_EXPERT_TOKEN_NUM: + expert_offs = tl.arange(0, BLOCK_EXPERT) + tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) + tl.debug_barrier() + if shared_expert_id >= 0: + tl.store(expert_token_num_ptr + shared_expert_id, token_num, mask=token_block == 0) + offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offs < token_num * topk_num - expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0) - weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0) + token_ids = offs // topk_num + topk_offsets = offs - token_ids * topk_num + routed_offsets = token_ids * routed_topk_num + topk_offsets + is_shared_expert = topk_offsets >= routed_topk_num + + expert_ids = tl.load(topk_ids_ptr + routed_offsets, mask=mask & (is_shared_expert == 0), other=0) + expert_ids = tl.where(is_shared_expert, shared_expert_id, expert_ids) + weights = tl.load(topk_weights_ptr + routed_offsets, mask=mask & (is_shared_expert == 0), other=0.0) + if HAS_SHARED_EXPERT_WEIGHT: + shared_weights = tl.load(shared_expert_weight_ptr + token_ids, mask=mask & is_shared_expert, other=0.0).to( + tl.float32 + ) + shared_weights = tl.sigmoid(shared_weights) + else: + shared_weights = tl.full((BLOCK_SIZE,), 1.0, dtype=tl.float32) + weights = tl.where(is_shared_expert, shared_weights, weights) - # 用 atomic_add 给 expert 分配写位置 - write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask) + # Shared expert appears exactly once per token, so its position is deterministic. + routed_write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask & (is_shared_expert == 0)) + write_pos = tl.where(is_shared_expert, token_ids, routed_write_pos) # 按 token 顺序写 index 和 weight tl.store( @@ -249,8 +278,11 @@ def moe_align_fused_kernel( def _get_moe_align_fused_static_key( topk_weights: torch.Tensor, + shared_expert_id: int = -1, ) -> dict: topk_num = topk_weights.shape[1] + if shared_expert_id >= 0: + topk_num += 1 return { "topk_num": topk_num, } @@ -275,24 +307,43 @@ def _get_moe_align_fused_configs(): mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], ) def moe_align_fused( - expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + shared_expert_id: int = -1, + shared_expert_weight: Optional[torch.Tensor] = None, + run_config: Optional[dict] = None, ): - token_num, topk_num = topk_ids.shape + token_num, routed_topk_num = topk_ids.shape + topk_num = routed_topk_num + (1 if shared_expert_id >= 0 else 0) if run_config is None: run_config = {} BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256) num_warps = run_config.get("num_warps", 4) + expert_num = expert_token_num.shape[0] + zero_expert_token_num = token_num * topk_num <= BLOCK_SIZE + if shared_expert_weight is not None: + shared_expert_weight = shared_expert_weight.view(token_num, 1) grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),) moe_align_fused_kernel[grid]( topk_ids, topk_weights, + shared_expert_weight if shared_expert_weight is not None else topk_weights, expert_to_token_index, expert_to_weight, expert_token_num, token_num, + routed_topk_num, + expert_num, topk_num, + shared_expert_id, BLOCK_SIZE=BLOCK_SIZE, + ZERO_EXPERT_TOKEN_NUM=zero_expert_token_num, + BLOCK_EXPERT=triton.next_power_of_2(expert_num), + HAS_SHARED_EXPERT_WEIGHT=shared_expert_weight is not None, num_warps=num_warps, ) return expert_to_token_index, expert_to_weight, expert_token_num @@ -911,6 +962,8 @@ def fused_experts_impl( layout="blocked", limit=None, alpha=None, + shared_expert_id: int = -1, + shared_expert_weight: Optional[torch.Tensor] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -922,7 +975,8 @@ def fused_experts_impl( num_tokens, _ = hidden_states.shape E, N, _ = w1.shape CHUNK_SIZE = FFN_MOE_CHUNK_SIZE - topk_num = topk_ids.shape[1] + routed_topk_num = topk_ids.shape[1] + topk_num = routed_topk_num + (1 if shared_expert_id >= 0 else 0) M = min(num_tokens, CHUNK_SIZE) intermediate_cache13_shared = alloc_tensor_func( @@ -954,20 +1008,30 @@ def fused_experts_impl( curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + curr_shared_expert_weight = ( + shared_expert_weight[begin_chunk_idx:end_chunk_idx] if shared_expert_weight is not None else None + ) expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda") expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda") - expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda") + expert_token_count_in_align_kernel = topk_num * tokens_in_chunk <= 128 + expert_to_token_num = ( + torch.empty((E,), dtype=torch.int32, device="cuda") + if expert_token_count_in_align_kernel + else torch.zeros((E,), dtype=torch.int32, device="cuda") + ) moe_align_fused( expert_to_token_index=expert_to_tokens, expert_to_weight=expert_to_weights, expert_token_num=expert_to_token_num, topk_ids=curr_topk_ids, topk_weights=curr_topk_weights, + shared_expert_id=shared_expert_id, + shared_expert_weight=curr_shared_expert_weight, ) reused_mblock_infos = grouped_matmul( - curr_topk_ids.numel(), + tokens_in_chunk * topk_num, curr_hidden_states, a1_scale, expert_to_token_num, @@ -993,7 +1057,7 @@ def fused_experts_impl( ) grouped_matmul( - curr_topk_ids.numel(), + tokens_in_chunk * topk_num, intermediate_cache2.view(-1, N // 2), a2_scale, expert_to_token_num, @@ -1012,7 +1076,8 @@ def fused_experts_impl( ) moe_sum_reduce( - intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx] + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], ) return out_hidden_states @@ -1035,6 +1100,7 @@ def inplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_id: int = -1, ) -> None: fused_experts_impl( hidden_states, @@ -1054,6 +1120,7 @@ def inplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_id=shared_expert_id, ) @@ -1075,6 +1142,7 @@ def inplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_id: int = -1, ) -> None: pass @@ -1105,7 +1173,8 @@ def outplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, -) -> None: + shared_expert_id: int = -1, +) -> torch.Tensor: return fused_experts_impl( hidden_states, w1, @@ -1124,6 +1193,7 @@ def outplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_id=shared_expert_id, ) @@ -1145,7 +1215,8 @@ def outplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, -) -> None: + shared_expert_id: int = -1, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1176,6 +1247,7 @@ def fused_experts( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_id: int = -1, ): if inplace: torch.ops.lightllm.inplace_fused_experts_impl( @@ -1195,6 +1267,7 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_id=shared_expert_id, ) return hidden_states else: @@ -1215,4 +1288,49 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_id=shared_expert_id, ) + + +def fused_shared_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_expert_weight: torch.Tensor, + inplace: bool = True, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + layout: str = "blocked", + alpha: Optional[float] = None, + limit: Optional[float] = None, + shared_expert_id: int = -1, +): + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + use_fp8_w8a8, + use_int8_w8a16, + w1_bias, + w2_bias, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + layout=layout, + alpha=alpha, + limit=limit, + shared_expert_id=shared_expert_id, + shared_expert_weight=shared_expert_weight, + ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py index 45c7ea73c6..a63d92692e 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py @@ -122,7 +122,7 @@ def silu_and_mul_fwd( alpha=None, run_config=None, ): - assert input.is_contiguous() + assert input.stride(-1) == 1 assert output.is_contiguous() assert (limit is None and alpha is None) or (limit is not None and alpha is not None) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py index e16351eec8..28221344b8 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py @@ -47,7 +47,11 @@ def _moe_sum_reduce_kernel( def _get_moe_sum_reduce_static_key(input: torch.Tensor, output: torch.Tensor): - return {"topk_num": input.shape[1], "hidden_dim": input.shape[2], "out_dtype": str(output.dtype)} + return { + "topk_num": input.shape[1], + "hidden_dim": input.shape[2], + "out_dtype": str(output.dtype), + } def _get_moe_sum_reduce_configs(): diff --git a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py index 89db5e00cb..c62c5eb5d2 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py @@ -16,7 +16,6 @@ def gated_rmsnorm_forward_kernel( W, # pointer to the weights B, # pointer to the biases Z, # pointer to the other branch (required, not optional) - Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_z_row, @@ -33,7 +32,6 @@ def gated_rmsnorm_forward_kernel( X += row * stride_x_row + group * N Y += row * stride_y_row + group * N Z += row * stride_z_row + group * N - Rstd += group * M W += group * N if HAS_BIAS: B += group * N @@ -47,7 +45,6 @@ def gated_rmsnorm_forward_kernel( xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) @@ -128,9 +125,6 @@ def gated_rmsnorm_forward( else: out = torch.empty_like(x) assert out.stride(-1) == 1 - # For RMS norm, we still need rstd for the kernel - rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) - # Default heuristic when autotune is disabled or no config provided if not run_config: # Less than 64KB per feature: enqueue fused kernel @@ -160,7 +154,6 @@ def gated_rmsnorm_forward( weight, bias, z, - rstd, x.stride(0), out.stride(0), z.stride(0), diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 3eb09f9176..c5e88cf2cf 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -37,9 +37,9 @@ def _parse_config(self): self.num_attention_heads = self.network_config_["num_attention_heads"] self.kv_lora_rank = self.network_config_["kv_lora_rank"] self.num_fused_shared_experts = 0 - if get_env_start_args().enable_fused_shared_experts and self.is_moe: - # enable_fused_shared_experts can only work with tensor parallelism - assert not get_env_start_args().enable_ep_moe, "enable_fused_shared_experts can only work with tp mode." + if not get_env_start_args().disable_fused_shared_experts and self.is_moe: + # fused shared experts can only work with tensor parallelism + assert not get_env_start_args().enable_ep_moe, "fused shared experts can only work with tp mode." self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) self.n_embed = self.network_config_["hidden_size"] self.n_inter = self.network_config_["intermediate_size"] diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py index afbd02a482..d9ac369960 100644 --- a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -28,14 +28,24 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkvo_gate_proj = getattr(layer_weight, "qkvo_gate_proj", None) + if qkvo_gate_proj is None: + qkv_out = layer_weight.qkv_proj.mm(input) + o_gate = layer_weight._o_gate_proj.mm(input) + else: + qkv_gate_out = qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 ) - o_gate = layer_weight._o_gate_proj.mm(input) - # In-place sigmoid for gate - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index e4d80e6ff9..b5ae5c394e 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -10,8 +10,10 @@ from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor from lightllm.common.kv_cache_mem_manager import Qwen3NextMemManager from typing import Tuple -from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.gdn_decode_pack import conv_pack_gdn_decode_inputs +from lightllm.models.qwen3next.triton_kernel.shared_expert_gate import sigmoid_mul_ from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule from lightllm.distributed import all_reduce @@ -114,19 +116,33 @@ def _compute_shared_expert( ): input = input.view(-1, self.embed_dim_) shared_expert_out = LlamaTransformerLayerInfer._ffn_tp(self, input, infer_state, layer_weight) - gate = layer_weight.ffn_gate.mm(input).sigmoid_() - shared_expert_out.mul_(gate) + gate = layer_weight.ffn_gate.mm(input) + sigmoid_mul_(shared_expert_out, gate) return shared_expert_out def _moe_ffn_tp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) - hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) + if getattr(layer_weight, "num_fused_shared_experts", 0) > 0: + shared_expert_gate = layer_weight.ffn_gate.mm(hidden_states) + layer_weight.experts.fused_shared_experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + shared_expert_weight=shared_expert_gate, + ) + hidden_states = hidden_states.view(num_tokens, hidden_dim) + return hidden_states + + shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) layer_weight.experts.experts( hidden_states, router_logits=router_logits, @@ -169,13 +185,25 @@ def _get_qkv( ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkvo_gate_proj = getattr(layer_weight, "qkvo_gate_proj", None) + if qkvo_gate_proj is None: + qkv_out = layer_weight.qkv_proj.mm(input) + o_gate = layer_weight._o_gate_proj.mm(input) + else: + qkv_gate_out = qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ * 2 + + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_ * 2, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1, ) - o_gate = layer_weight._o_gate_proj.mm(input) - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -199,15 +227,24 @@ def _get_o( input, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, + ) -> torch.Tensor: + o_tensor = self._get_o_local(input=input, infer_state=infer_state, layer_weight=layer_weight) + o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) + return o_tensor + + def _get_o_local( + self, + input, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, ) -> torch.Tensor: """Output projection with gating (in-place multiply to save one allocation).""" if infer_state.need_dp_prefill_balance: input = infer_state._all_to_all_balance_get(data=input) input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - input.mul_(infer_state.gate_value) + sigmoid_mul_(input, infer_state.gate_value) infer_state.gate_value = None o_tensor = layer_weight.o_proj.mm(input) - o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) return o_tensor # ==================== GDN Helper Methods ==================== @@ -257,8 +294,9 @@ def gdn_forward( else: mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) - core_attn_out = self._gdn_decode_kernel( + core_attn_out, z = self._gdn_decode_kernel( mixed_qkv, + z, conv_states, ssm_states, a, @@ -406,6 +444,7 @@ def _gdn_prefill_kernel( def _gdn_decode_kernel( self, mixed_qkv: torch.Tensor, + z: torch.Tensor, conv_states: torch.Tensor, ssm_states: torch.Tensor, a: torch.Tensor, @@ -413,18 +452,24 @@ def _gdn_decode_kernel( infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ): - mixed_qkv = causal_conv1d_update( + # Recurrent processing with fused gating. Decode uses a specialized + # conv+pack kernel to avoid materializing the post-conv qkv tensor + # before immediately splitting it into q/k/v. + query, key, value, z, a, b = conv_pack_gdn_decode_inputs( mixed_qkv, + z, + a, + b, conv_states, layer_weight.linear_conv1d.mm_param.weight, - bias=layer_weight.linear_conv1d.bias, - activation=self.activation, - conv_state_indices=infer_state.b_buffer_idx, + layer_weight.linear_conv1d.bias, + infer_state.b_buffer_idx, + self.activation, + self.tp_num_k_heads, + self.head_k_dim, + self.tp_num_v_heads, + self.head_v_dim, ) - - # Recurrent processing with fused gating; the kernel reads the - # q/k/v/a/b column views directly via per-token strides (no copies) - query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) core_attn_out, _ = fused_recurrent_gated_delta_rule( q=query, k=key, @@ -438,4 +483,4 @@ def _gdn_decode_kernel( a_raw=a, b_raw=b, ) - return core_attn_out + return core_attn_out, z diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index 0d415ca0e8..cc3747111c 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -10,7 +10,85 @@ TpParameterWeight, QKVROWNMMWeight, QKGEMMANormWeight, + FusedMoeWeight, ) +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import get_row_slice_mixin +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size + + +class QKVGatedROWNMMWeight(MMWeightTpl): + def __init__( + self, + in_dim, + q_head_num, + kv_head_num, + head_dim, + weight_names, + data_type, + bias_names=None, + quant_method=None, + tp_rank=None, + tp_world_size=None, + ): + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + self.q_repeat_times = 1 + self.kv_repeat_times = 1 + assert ( + q_head_num % self.tp_world_size_ == 0 + ), f"q_head_num must be divisible by tp_world_size_, found {q_head_num} % {self.tp_world_size_}" + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or vice versa, " + f"found {kv_head_num} % {self.tp_world_size_}" + ) + q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim + super().__init__( + in_dim=in_dim, + out_dims=[q_hidden_size, kv_hidden_size, kv_hidden_size, q_hidden_size], + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + ) + self.q_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.q_repeat_times, + ) + self.kv_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.kv_repeat_times, + ) + + def _get_param_slicer(self, sub_child_index): + if sub_child_index == 0 or sub_child_index == 3: + return self.q_param_slicer + return self.kv_param_slicer + + def load_hf_weights(self, weights): + super().load_hf_weights(weights) + if self.bias_names is not None: + for sub_child_index, bias_name in enumerate(self.bias_names): + if bias_name is None: + self.bias_list[sub_child_index].zero_() + self.bias_list[sub_child_index].load_ok = True + + def _get_tp_padded_head_num(self, head_num): + if head_num % self.tp_world_size_ == 0: + return head_num // self.tp_world_size_ + if self.tp_world_size_ % head_num == 0: + self.kv_repeat_times = self.tp_world_size_ // head_num + return self.kv_repeat_times * head_num // self.tp_world_size_ + raise ValueError( + f"head_num must be divisible by tp_world_size_ or vice versa, found {head_num} % {self.tp_world_size_}" + ) class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): @@ -23,25 +101,39 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): def _init_qkv(self): in_dim = self.n_embed q_out_dim = self.q_head_num_ * self.head_dim - self.qkv_proj = QKVROWNMMWeight( - in_dim=in_dim, - q_head_num=self.q_head_num_, - kv_head_num=self.k_head_num_, - head_dim=self.head_dim, - weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], - data_type=self.data_type_, - bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], - quant_method=self.get_quant_method("qkv_proj"), - ) self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" - self._o_gate_proj = ROWMMWeight( - in_dim=in_dim, - out_dims=[q_out_dim], - weight_names=[self._o_gate_weight_name], - data_type=self.data_type_, - bias_names=None, - quant_method=self.get_quant_method("o_gate_proj"), - ) + qkv_quant = self.get_quant_method("qkv_proj") + gate_quant = self.get_quant_method("o_gate_proj") + if qkv_quant.method_name == "none" and gate_quant.method_name == "none": + self.qkvo_gate_proj = QKVGatedROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name, self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name, None], + quant_method=qkv_quant, + ) + else: + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=qkv_quant, + ) + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=gate_quant, + ) def _init_weight(self): if self.is_linear_attention_layer: @@ -57,10 +149,54 @@ def _init_weight(self): self._init_norm() def _init_moe(self): - super()._init_moe() + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.num_fused_shared_experts = 1 if self._can_fuse_shared_expert_as_moe() else 0 + self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], + weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + if self.num_fused_shared_experts > 0: + self.ffn_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.num_fused_shared_experts], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + num_fused_shared_experts=self.num_fused_shared_experts, + layer_num=self.layer_num_, + network_config=self.network_config_, + ) self._init_gated_ffn() return + def _can_fuse_shared_expert_as_moe(self): + start_args = get_env_start_args() + if not self.is_moe or start_args.enable_ep_moe or start_args.disable_fused_shared_experts: + return False + return ( + self.network_config_.get("shared_expert_intermediate_size") == self.network_config_["moe_intermediate_size"] + ) + def _init_norm(self): hidden_size = self.network_config_["hidden_size"] self.att_norm_weight_ = NoTpGEMMANormWeight( @@ -85,6 +221,8 @@ def _init_gated_ffn(self): hidden_size = self.network_config_["hidden_size"] if "shared_expert_intermediate_size" not in self.network_config_: return + if getattr(self, "num_fused_shared_experts", 0) > 0: + return prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" inter_size = self.network_config_["shared_expert_intermediate_size"] if get_env_start_args().enable_ep_moe: @@ -143,6 +281,24 @@ def _split_q_with_gate(self, weights): weights[self._q_weight_name] = _q_proj weights[self._o_gate_weight_name] = _gate_proj + def _rename_shared_expert_to_moe_expert(self, weights): + if getattr(self, "num_fused_shared_experts", 0) == 0: + return + old_prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + new_prefix = f"model.layers.{self.layer_num_}.mlp.experts.{self.n_routed_experts}" + suffixes = [ + self.experts.quant_method.weight_suffix, + self.experts.quant_method.weight_scale_suffix, + self.experts.quant_method.weight_zero_point_suffix, + ] + for proj_name in ("gate_proj", "up_proj", "down_proj"): + for suffix in suffixes: + if suffix is None: + continue + old_name = f"{old_prefix}.{proj_name}.{suffix}" + if old_name in weights: + weights[f"{new_prefix}.{proj_name}.{suffix}"] = weights[old_name] + def _parse_config(self): super()._parse_config() self.linear_num_v_heads = self.network_config_["linear_num_value_heads"] @@ -288,6 +444,7 @@ def _parse_linear_conv1d(self, weight): def load_hf_weights(self, weights): self._split_q_with_gate(weights) + self._rename_shared_expert_to_moe_expert(weights) if self.is_linear_attention_layer: self._preprocess_weight(weights) super().load_hf_weights(weights) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py index b0dc41a3c1..22a93a2c99 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -54,11 +54,6 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - stride_q_tok: tl.constexpr, - stride_k_tok: tl.constexpr, - stride_v_tok: tl.constexpr, - stride_a_tok: tl.constexpr, - stride_b_tok: tl.constexpr, stride_init_state_token: tl.constexpr, stride_final_state_token: tl.constexpr, stride_indices_seq: tl.constexpr, @@ -99,15 +94,15 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( o_k = i_k * BK + tl.arange(0, BK) o_v = i_v * BV + tl.arange(0, BV) - p_q = q + bos * stride_q_tok + i_h * K + o_k - p_k = k + bos * stride_k_tok + i_h * K + o_k - p_v = v + bos * stride_v_tok + i_hv * V + o_v + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v if FUSE_GATING: # Fused gating: load per-head constants once, compute g/beta inline per token b_A_log = tl.load(A_log + i_hv).to(tl.float32) b_dt_bias = tl.load(dt_bias + i_hv).to(tl.float32) - p_a_raw = a_raw + bos * stride_a_tok + i_hv - p_b_raw = b_raw + bos * stride_b_tok + i_hv + p_a_raw = a_raw + bos * HV + i_hv + p_b_raw = b_raw + bos * HV + i_hv else: if IS_BETA_HEADWISE: p_beta = beta + (bos * HV + i_hv) * V + o_v @@ -198,13 +193,13 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) - p_q += stride_q_tok - p_k += stride_k_tok + p_q += H * K + p_k += H * K p_o += HV * V - p_v += stride_v_tok + p_v += HV * V if FUSE_GATING: - p_a_raw += stride_a_tok - p_b_raw += stride_b_tok + p_a_raw += HV + p_b_raw += HV else: if not IS_KDA: p_g += HV @@ -213,34 +208,6 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_beta += HV * (V if IS_BETA_HEADWISE else 1) -def _ensure_qkv_token_strided(x: torch.Tensor, inner_numel: int): - """Return q/k/v and token stride, copying only when needed.""" - if x is None: - return None, 0 - - # Decode layout must be [tokens, 1, head, dim]. - assert x.shape[1] == 1, "q/k/v must use decode layout [tokens, 1, head, dim]" - - # Packed tail [head, dim] means the last two strides are [dim, 1]. - tail_contiguous = x.stride()[-2:] == (x.shape[-1], 1) - if not tail_contiguous: - x = x.contiguous() - return x, inner_numel - else: - return x, x.stride(0) - - -def _ensure_gate_token_strided(x: torch.Tensor, inner_numel: int): - """Return a_raw/b_raw and token stride, copying only when needed.""" - if x is None: - return None, 0 - # a_raw/b_raw are 2D [tokens, HV]; the tail HV dimension must be packed. - if x.stride(1) != 1: - x = x.contiguous() - return x, inner_numel - return x, x.stride(0) - - def fused_recurrent_gated_delta_rule_fwd( q: torch.Tensor, k: torch.Tensor, @@ -264,16 +231,7 @@ def fused_recurrent_gated_delta_rule_fwd( ) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K, V = *k.shape, v.shape[-1] HV = v.shape[2] - # In LightLLM's Qwen3Next inference path this fused recurrent kernel is - # used only for decode. Prefill/varlen requests are handled by - # chunk_gated_delta_rule, so keep cu_seqlens out of this strided-view path. - assert cu_seqlens is None, "cu_seqlens is not supported by the decode-only fused recurrent kernel" - N = B - q, stride_q_tok = _ensure_qkv_token_strided(q, H * K) - k, stride_k_tok = _ensure_qkv_token_strided(k, H * K) - v, stride_v_tok = _ensure_qkv_token_strided(v, HV * V) - a_raw, stride_a_tok = _ensure_gate_token_strided(a_raw, HV) - b_raw, stride_b_tok = _ensure_gate_token_strided(b_raw, HV) + N = B if cu_seqlens is None else len(cu_seqlens) - 1 BK = triton.next_power_of_2(K) if T == 1: # Decode path: use larger BV to reduce kernel instances (4 blocks instead of 16) @@ -303,23 +261,20 @@ def fused_recurrent_gated_delta_rule_fwd( stride_init_state_token = initial_state.stride(0) stride_final_state_token = final_state.stride(0) - # Strides for read indices. The kernel advances along a row with `+ i_t` - # (token stride 1), so 2D index tensors must have contiguous rows. + # Strides for read indices if ssm_state_indices is None: stride_indices_seq, stride_indices_tok = 1, 1 elif ssm_state_indices.ndim == 1: stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 else: - assert ssm_state_indices.stride(-1) == 1, "2D ssm_state_indices must have contiguous rows" stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() - # Strides for write indices (if provided); same contiguous-row requirement + # Strides for write indices (if provided) if ssm_state_write_indices is None: stride_write_indices_seq, stride_write_indices_tok = 1, 1 elif ssm_state_write_indices.ndim == 1: stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1 else: - assert ssm_state_write_indices.stride(-1) == 1, "2D ssm_state_write_indices must have contiguous rows" stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride() grid = (NK, NV, N * HV) @@ -350,11 +305,6 @@ def fused_recurrent_gated_delta_rule_fwd( V=V, BK=BK, BV=BV, - stride_q_tok=stride_q_tok, - stride_k_tok=stride_k_tok, - stride_v_tok=stride_v_tok, - stride_a_tok=stride_a_tok, - stride_b_tok=stride_b_tok, stride_init_state_token=stride_init_state_token, stride_final_state_token=stride_final_state_token, stride_indices_seq=stride_indices_seq, @@ -398,12 +348,10 @@ def forward( b_raw: torch.Tensor | None = None, out: torch.Tensor | None = None, ): - # q/k/v/a_raw/b_raw may be non-contiguous column views of one projection - # output; the kernel handles them via per-token strides (no copies). o, final_state = fused_recurrent_gated_delta_rule_fwd( - q=q, - k=k, - v=v, + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), g=g.contiguous() if g is not None else None, beta=beta.contiguous() if beta is not None else None, scale=scale, @@ -416,8 +364,8 @@ def forward( use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, A_log=A_log, dt_bias=dt_bias, - a_raw=a_raw, - b_raw=b_raw, + a_raw=a_raw.contiguous() if a_raw is not None else None, + b_raw=b_raw.contiguous() if b_raw is not None else None, out=out, ) @@ -469,9 +417,8 @@ def fused_recurrent_gated_delta_rule( Whether to store the final state in-place to save memory. Default: `True`. cu_seqlens (torch.LongTensor): - Must be `None`. In LightLLM this fused recurrent kernel is used only - by the Qwen3Next decode path; prefill/varlen requests use - `chunk_gated_delta_rule`. + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. ssm_state_indices (Optional[torch.Tensor]): Indices to map the input sequences to the initial/final states. num_accepted_tokens (Optional[torch.Tensor]): @@ -486,9 +433,10 @@ def fused_recurrent_gated_delta_rule( Examples:: >>> import torch >>> import torch.nn.functional as F + >>> from einops import rearrange >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule - # decode inputs - >>> B, T, H, HV, K, V = 4, 1, 4, 8, 512, 512 + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 >>> q = torch.randn(B, T, H, K, device='cuda') >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) >>> v = torch.randn(B, T, HV, V, device='cuda') @@ -499,10 +447,21 @@ def fused_recurrent_gated_delta_rule( q, k, v, g, beta, initial_state=h0, ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) """ - # This wrapper is only used for Qwen3Next decode inference in LightLLM. - # Keep varlen/prefill inputs on chunk_gated_delta_rule instead. - assert cu_seqlens is None, "cu_seqlens is not supported by the decode-only fused recurrent kernel" + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if scale is None: scale = k.shape[-1] ** -0.5 else: diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py new file mode 100644 index 0000000000..1c5f088ef1 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py @@ -0,0 +1,165 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _conv_pack_gdn_decode_kernel( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q_out, + k_out, + v_out, + z_out, + a_out, + b_out, + stride_m_b: tl.constexpr, + stride_m_d: tl.constexpr, + stride_z_b: tl.constexpr, + stride_z_h: tl.constexpr, + stride_z_d: tl.constexpr, + stride_a_b: tl.constexpr, + stride_a_d: tl.constexpr, + stride_b_b: tl.constexpr, + stride_b_d: tl.constexpr, + stride_s_b: tl.constexpr, + stride_s_d: tl.constexpr, + stride_s_w: tl.constexpr, + stride_w_d: tl.constexpr, + stride_w_w: tl.constexpr, + q_dim: tl.constexpr, + k_dim: tl.constexpr, + v_dim: tl.constexpr, + gate_dim: tl.constexpr, + conv_dim: tl.constexpr, + HAS_BIAS: tl.constexpr, + APPLY_SILU: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offs = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < conv_dim + state_idx = tl.load(conv_state_indices + row) + + x = tl.load(mixed_qkv + row * stride_m_b + offs * stride_m_d, mask=mask, other=0.0).to(tl.float32) + s0 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 0 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + s1 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 1 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + s2 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 2 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + w0 = tl.load(conv_weight + offs * stride_w_d + 0 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w1 = tl.load(conv_weight + offs * stride_w_d + 1 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w2 = tl.load(conv_weight + offs * stride_w_d + 2 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w3 = tl.load(conv_weight + offs * stride_w_d + 3 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + y = s0 * w0 + s1 * w1 + s2 * w2 + x * w3 + if HAS_BIAS: + bias = tl.load(conv_bias + offs, mask=mask, other=0.0).to(tl.float32) + y += bias + if APPLY_SILU: + y = y * tl.sigmoid(y) + + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 0 * stride_s_w, s1, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 1 * stride_s_w, s2, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 2 * stride_s_w, x, mask=mask) + + q_mask = offs < q_dim + k_mask = (offs >= q_dim) & (offs < q_dim + k_dim) + v_mask = (offs >= q_dim + k_dim) & (offs < conv_dim) + tl.store(q_out + row * q_dim + offs, y, mask=q_mask) + tl.store(k_out + row * k_dim + (offs - q_dim), y, mask=k_mask) + tl.store(v_out + row * v_dim + (offs - q_dim - k_dim), y, mask=v_mask) + + z_mask = offs < v_dim + z_vals = tl.load(z_raw + row * stride_z_b + offs, mask=z_mask, other=0.0) + tl.store(z_out + row * v_dim + offs, z_vals, mask=z_mask) + + gate_mask = offs < gate_dim + a_vals = tl.load(a_raw + row * stride_a_b + offs * stride_a_d, mask=gate_mask, other=0.0) + b_vals = tl.load(b_raw + row * stride_b_b + offs * stride_b_d, mask=gate_mask, other=0.0) + tl.store(a_out + row * gate_dim + offs, a_vals, mask=gate_mask) + tl.store(b_out + row * gate_dim + offs, b_vals, mask=gate_mask) + + +@torch.no_grad() +def conv_pack_gdn_decode_inputs( + mixed_qkv: torch.Tensor, + z_raw: torch.Tensor, + a_raw: torch.Tensor, + b_raw: torch.Tensor, + conv_state: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + conv_state_indices: torch.Tensor, + activation: str, + num_k_heads: int, + head_k_dim: int, + num_v_heads: int, + head_v_dim: int, +): + batch = mixed_qkv.shape[0] + q_dim = num_k_heads * head_k_dim + k_dim = q_dim + v_dim = num_v_heads * head_v_dim + gate_dim = num_v_heads + conv_dim = q_dim + k_dim + v_dim + + q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + k = torch.empty_like(q) + v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) + a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) + b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) + + block_size = 256 + grid = (batch, triton.cdiv(conv_dim, block_size)) + _conv_pack_gdn_decode_kernel[grid]( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q, + k, + v, + z, + a, + b, + mixed_qkv.stride(0), + mixed_qkv.stride(1), + z_raw.stride(0), + z_raw.stride(1), + z_raw.stride(2), + a_raw.stride(0), + a_raw.stride(1), + b_raw.stride(0), + b_raw.stride(1), + conv_state.stride(0), + conv_state.stride(1), + conv_state.stride(2), + conv_weight.stride(0), + conv_weight.stride(1), + q_dim, + k_dim, + v_dim, + gate_dim, + conv_dim, + HAS_BIAS=conv_bias is not None, + APPLY_SILU=activation in ["silu", "swish"], + BLOCK_SIZE=block_size, + num_warps=8, + ) + return q, k, v, z, a, b diff --git a/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py new file mode 100644 index 0000000000..e5b0e282ab --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py @@ -0,0 +1,50 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _sigmoid_mul_kernel( + x, + gate, + stride_x_m: tl.constexpr, + stride_x_n: tl.constexpr, + stride_g_m: tl.constexpr, + stride_g_n: tl.constexpr, + N: tl.constexpr, + GATE_N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_N) + mask = offs < N + x_ptrs = x + row * stride_x_m + offs * stride_x_n + x_vals = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + if GATE_N == 1: + gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) + else: + gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) + gate_vals = tl.sigmoid(gate_vals) + tl.store(x_ptrs, (x_vals * gate_vals).to(x.dtype.element_ty), mask=mask) + + +@torch.no_grad() +def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + x_arg = x.view(-1, x.shape[-1]) + gate_arg = gate.view(-1, gate.shape[-1]) + assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1]) + _, n = x_arg.shape + block_n = triton.next_power_of_2(n) + _sigmoid_mul_kernel[(x_arg.shape[0],)]( + x_arg, + gate_arg, + x_arg.stride(0), + x_arg.stride(1), + gate_arg.stride(0), + gate_arg.stride(1), + n, + gate_arg.shape[1], + BLOCK_N=block_n, + num_warps=8, + ) + return x diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 1bdf8f3427..9de8993de8 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -681,9 +681,12 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""Whether to update the redundant expert for deepseekv3 model by online expert used counter.""", ) parser.add_argument( - "--enable_fused_shared_experts", + "--disable_fused_shared_experts", action="store_true", - help="""Whether to enable fused shared experts for deepseekv3 model. only work when tensor parallelism""", + help=( + "Disable fused shared experts for supported MoE models. " + "It is enabled by default and only works with tensor parallelism." + ), ) parser.add_argument( "--mtp_mode", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 40c8028158..46959827f4 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -157,6 +157,7 @@ class StartArgs: enable_ep_moe: bool = field(default=False) ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) + disable_fused_shared_experts: bool = field(default=False) mtp_mode: Optional[str] = field( default=None, metadata={ diff --git a/unit_tests/models/qwen3next/test_fused_recurrent_strided.py b/unit_tests/models/qwen3next/test_fused_recurrent_strided.py deleted file mode 100644 index cf9d06ec98..0000000000 --- a/unit_tests/models/qwen3next/test_fused_recurrent_strided.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -import torch - -from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( - fused_recurrent_gated_delta_rule, -) - -if not torch.cuda.is_available(): - pytest.skip("CUDA required", allow_module_level=True) - - -@pytest.mark.parametrize("batch", [1, 2, 16]) -def test_decode_strided_views_match_contiguous(batch): - """q/k/v/a/b passed as column views of one projection output (the decode - path layout) must produce the same result as contiguous copies.""" - torch.manual_seed(0) - H, HV, K, V = 2, 8, 128, 128 - key_dim, value_dim = H * K, HV * V - qkv_dim = 2 * key_dim + value_dim - total_dim = qkv_dim + value_dim + 2 * HV # qkv + z + b + a - cache_slots = 64 - - mixed = torch.randn(batch, total_dim, device="cuda", dtype=torch.bfloat16) - mixed_qkv = mixed[:, :qkv_dim] - b_raw = mixed[:, qkv_dim + value_dim : qkv_dim + value_dim + HV] - a_raw = mixed[:, qkv_dim + value_dim + HV :] - - query, key, value = torch.split(mixed_qkv, [key_dim, key_dim, value_dim], dim=-1) - q = query.view(batch, 1, H, K) - k = key.view(batch, 1, H, K) - v = value.view(batch, 1, HV, V) - - A_log = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1 - dt_bias = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1 - ssm_state = torch.randn(cache_slots, HV, K, V, device="cuda", dtype=torch.bfloat16) - idx = torch.randperm(cache_slots, device="cuda")[:batch].to(torch.int32) - - def run(q_, k_, v_, a_, b_, state): - out, _ = fused_recurrent_gated_delta_rule( - q=q_, - k=k_, - v=v_, - initial_state=state, - inplace_final_state=True, - ssm_state_indices=idx, - use_qk_l2norm_in_kernel=True, - A_log=A_log, - dt_bias=dt_bias, - a_raw=a_, - b_raw=b_, - ) - return out - - state_ref = ssm_state.clone() - out_ref = run(q.contiguous(), k.contiguous(), v.contiguous(), a_raw.contiguous(), b_raw.contiguous(), state_ref) - state_strided = ssm_state.clone() - out_strided = run(q, k, v, a_raw, b_raw, state_strided) - - assert torch.equal(out_ref, out_strided) - assert torch.equal(state_ref, state_strided) - - -def test_cu_seqlens_is_not_supported(): - """The fused recurrent kernel is decode-only in LightLLM's Qwen3Next path.""" - H, HV, K, V = 2, 2, 4, 4 - q = torch.randn(1, 2, H, K, device="cuda", dtype=torch.bfloat16) - k = torch.randn(1, 2, H, K, device="cuda", dtype=torch.bfloat16) - v = torch.randn(1, 2, HV, V, device="cuda", dtype=torch.bfloat16) - initial_state = torch.randn(1, HV, K, V, device="cuda", dtype=torch.bfloat16) - cu_seqlens = torch.tensor([0, 2], device="cuda", dtype=torch.long) - - with pytest.raises(AssertionError, match="decode-only fused recurrent kernel"): - fused_recurrent_gated_delta_rule( - q=q, - k=k, - v=v, - initial_state=initial_state, - cu_seqlens=cu_seqlens, - ) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"])