From 064b745d3162ed4c41d0cc8a5ae541c7f0d447ef Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Tue, 16 Jun 2026 19:56:52 +0800 Subject: [PATCH 1/4] feat: add fused moe shared-expert support --- .../fused_moe/fused_moe_weight.py | 4 + .../meta_weights/fused_moe/impl/base_impl.py | 2 + .../fused_moe/impl/deepgemm_impl.py | 2 + .../fused_moe/impl/marlin_impl.py | 2 + .../fused_moe/impl/triton_impl.py | 8 + .../fused_moe/grouped_fused_moe.py | 45 ++- .../triton_kernel/fused_moe/moe_sum_reduce.py | 56 +++- .../triton_kernel/norm/gated_rmsnorm.py | 7 - .../layer_infer/transformer_layer_infer.py | 18 +- .../layer_infer/transformer_layer_infer.py | 77 +++-- .../layer_weights/transformer_layer_weight.py | 127 ++++++-- .../triton_kernel/fla/ops/fused_recurrent.py | 115 +++---- .../triton_kernel/gdn_decode_pack.py | 284 ++++++++++++++++++ .../triton_kernel/shared_expert_gate.py | 108 +++++++ .../qwen3next/test_fused_recurrent_strided.py | 83 ----- 15 files changed, 720 insertions(+), 218 deletions(-) create mode 100644 lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py create mode 100644 lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py delete mode 100644 unit_tests/models/qwen3next/test_fused_recurrent_strided.py diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index fca9b80fcf..d9a77b39a5 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -134,6 +134,8 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -150,6 +152,8 @@ def experts( num_expert_group=num_expert_group, is_prefill=is_prefill, per_expert_scale=self.per_expert_scale, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) def low_latency_dispatch( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index dd6f9a6880..b54b03ee05 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -63,5 +63,7 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index 4d4614c007..bc0e86d7eb 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -76,6 +76,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): output = fused_experts( hidden_states=input_tensor, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 0094b09b1c..417d001c72 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -30,6 +30,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index a0d30547a3..fdda2b2139 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -94,6 +94,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: bool = False, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale @@ -111,6 +113,8 @@ def _fused_experts( use_fp8_w8a8=use_fp8_w8a8, w1_scale=w13_scale, w2_scale=w2_scale, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return input_tensor @@ -129,6 +133,8 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -150,5 +156,7 @@ def __call__( topk_ids=topk_ids, router_logits=router_logits, is_prefill=is_prefill, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return output diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 76acea25a7..cec9c53e52 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -221,10 +221,17 @@ def moe_align_fused_kernel( expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] token_num, + expert_num: tl.constexpr, topk_num: tl.constexpr, BLOCK_SIZE: tl.constexpr, + ZERO_EXPERT_TOKEN_NUM: tl.constexpr, + BLOCK_EXPERT: tl.constexpr, ): token_block = tl.program_id(0) + if ZERO_EXPERT_TOKEN_NUM: + expert_offs = tl.arange(0, BLOCK_EXPERT) + tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) + offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offs < token_num * topk_num @@ -282,6 +289,8 @@ def moe_align_fused( run_config = {} BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256) num_warps = run_config.get("num_warps", 4) + expert_num = expert_token_num.shape[0] + zero_expert_token_num = token_num * topk_num <= BLOCK_SIZE grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),) moe_align_fused_kernel[grid]( @@ -291,8 +300,11 @@ def moe_align_fused( expert_to_weight, expert_token_num, token_num, + expert_num, topk_num, BLOCK_SIZE=BLOCK_SIZE, + ZERO_EXPERT_TOKEN_NUM=zero_expert_token_num, + BLOCK_EXPERT=triton.next_power_of_2(expert_num), num_warps=num_warps, ) return expert_to_token_index, expert_to_weight, expert_token_num @@ -911,6 +923,8 @@ def fused_experts_impl( layout="blocked", limit=None, alpha=None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -957,7 +971,12 @@ def fused_experts_impl( expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda") expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda") - expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda") + expert_token_count_in_align_kernel = topk_num * tokens_in_chunk <= 128 + expert_to_token_num = ( + torch.empty((E,), dtype=torch.int32, device="cuda") + if expert_token_count_in_align_kernel + else torch.zeros((E,), dtype=torch.int32, device="cuda") + ) moe_align_fused( expert_to_token_index=expert_to_tokens, expert_to_weight=expert_to_weights, @@ -1011,8 +1030,12 @@ def fused_experts_impl( bias=w2_bias, ) + has_shared_gate = shared_expert_out is not None moe_sum_reduce( - intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx] + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + shared=None if not has_shared_gate else shared_expert_out[begin_chunk_idx:end_chunk_idx], + gate=None if not has_shared_gate else shared_expert_gate[begin_chunk_idx:end_chunk_idx], ) return out_hidden_states @@ -1035,6 +1058,8 @@ def inplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: fused_experts_impl( hidden_states, @@ -1054,6 +1079,8 @@ def inplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) @@ -1075,6 +1102,8 @@ def inplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: pass @@ -1105,6 +1134,8 @@ def outplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: return fused_experts_impl( hidden_states, @@ -1124,6 +1155,8 @@ def outplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) @@ -1145,6 +1178,8 @@ def outplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ) -> None: return torch.empty_like(hidden_states) @@ -1176,6 +1211,8 @@ def fused_experts( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + shared_expert_out: Optional[torch.Tensor] = None, + shared_expert_gate: Optional[torch.Tensor] = None, ): if inplace: torch.ops.lightllm.inplace_fused_experts_impl( @@ -1195,6 +1232,8 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) return hidden_states else: @@ -1215,4 +1254,6 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + shared_expert_out=shared_expert_out, + shared_expert_gate=shared_expert_gate, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py index e16351eec8..4f95cca7c6 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py @@ -14,12 +14,20 @@ def _moe_sum_reduce_kernel( output_ptr, output_stride_0, output_stride_1, + shared_ptr, + shared_stride_0, + shared_stride_1, + gate_ptr, + gate_stride_0, + gate_stride_1, token_num: int, topk_num: int, hidden_dim: int, BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr, NUM_STAGE: tl.constexpr, + HAS_SHARED_GATE: tl.constexpr, + GATE_DIM: tl.constexpr, ): input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) @@ -42,12 +50,38 @@ def _moe_sum_reduce_kernel( for i in tl.range(0, topk_num, num_stages=NUM_STAGE): tmp = tl.load(input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0) accumulator += tmp + if HAS_SHARED_GATE: + shared = tl.load( + shared_ptr + token_index * shared_stride_0 + offs_dim * shared_stride_1, + mask=offs_dim < dim_end, + other=0.0, + ).to(tl.float32) + if GATE_DIM == 1: + gate = tl.load(gate_ptr + token_index * gate_stride_0).to(tl.float32) + tl.zeros( + (BLOCK_DIM,), dtype=tl.float32 + ) + else: + gate = tl.load( + gate_ptr + token_index * gate_stride_0 + offs_dim * gate_stride_1, + mask=offs_dim < dim_end, + other=0.0, + ).to(tl.float32) + gate = 1.0 / (1.0 + tl.exp(-gate)) + accumulator += shared * gate store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim tl.store(store_t_ptr, accumulator.to(input_ptr.dtype.element_ty), mask=offs_dim < dim_end) -def _get_moe_sum_reduce_static_key(input: torch.Tensor, output: torch.Tensor): - return {"topk_num": input.shape[1], "hidden_dim": input.shape[2], "out_dtype": str(output.dtype)} +def _get_moe_sum_reduce_static_key( + input: torch.Tensor, output: torch.Tensor, shared: torch.Tensor = None, gate: torch.Tensor = None +): + return { + "topk_num": input.shape[1], + "hidden_dim": input.shape[2], + "out_dtype": str(output.dtype), + "has_shared_gate": shared is not None, + "gate_dim": 0 if gate is None else gate.shape[-1], + } def _get_moe_sum_reduce_configs(): @@ -67,12 +101,20 @@ def _get_moe_sum_reduce_configs(): run_key_func=lambda input: input.shape[0], mutates_args=["output"], ) -def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = None): +def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, shared=None, gate=None, run_config: Dict = None): assert input.is_contiguous() assert output.is_contiguous() token_num, topk_num, hidden_dim = input.shape assert output.shape[0] == token_num and output.shape[1] == hidden_dim + has_shared_gate = shared is not None + if has_shared_gate: + assert gate is not None + shared = shared.view(token_num, hidden_dim) + gate = gate.view(token_num, gate.shape[-1]) + assert shared.is_contiguous() + assert gate.is_contiguous() + assert gate.shape[1] in (1, hidden_dim) if not run_config: run_config = { @@ -97,12 +139,20 @@ def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = *input.stride(), output, *output.stride(), + shared if has_shared_gate else output, + shared.stride(0) if has_shared_gate else 0, + shared.stride(1) if has_shared_gate else 0, + gate if has_shared_gate else output, + gate.stride(0) if has_shared_gate else 0, + gate.stride(1) if has_shared_gate else 0, token_num=token_num, topk_num=topk_num, hidden_dim=hidden_dim, BLOCK_M=BLOCK_M, BLOCK_DIM=BLOCK_DIM, NUM_STAGE=NUM_STAGE, + HAS_SHARED_GATE=has_shared_gate, + GATE_DIM=gate.shape[1] if has_shared_gate else 0, num_warps=num_warps, ) return diff --git a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py index 89db5e00cb..c62c5eb5d2 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/gated_rmsnorm.py @@ -16,7 +16,6 @@ def gated_rmsnorm_forward_kernel( W, # pointer to the weights B, # pointer to the biases Z, # pointer to the other branch (required, not optional) - Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_z_row, @@ -33,7 +32,6 @@ def gated_rmsnorm_forward_kernel( X += row * stride_x_row + group * N Y += row * stride_y_row + group * N Z += row * stride_z_row + group * N - Rstd += group * M W += group * N if HAS_BIAS: B += group * N @@ -47,7 +45,6 @@ def gated_rmsnorm_forward_kernel( xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) @@ -128,9 +125,6 @@ def gated_rmsnorm_forward( else: out = torch.empty_like(x) assert out.stride(-1) == 1 - # For RMS norm, we still need rstd for the kernel - rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) - # Default heuristic when autotune is disabled or no config provided if not run_config: # Less than 64KB per feature: enqueue fused kernel @@ -160,7 +154,6 @@ def gated_rmsnorm_forward( weight, bias, z, - rstd, x.stride(0), out.stride(0), z.stride(0), diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py index afbd02a482..d9ac369960 100644 --- a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -28,14 +28,24 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkvo_gate_proj = getattr(layer_weight, "qkvo_gate_proj", None) + if qkvo_gate_proj is None: + qkv_out = layer_weight.qkv_proj.mm(input) + o_gate = layer_weight._o_gate_proj.mm(input) + else: + qkv_gate_out = qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 ) - o_gate = layer_weight._o_gate_proj.mm(input) - # In-place sigmoid for gate - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index e4d80e6ff9..3492041813 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -10,8 +10,10 @@ from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor from lightllm.common.kv_cache_mem_manager import Qwen3NextMemManager from typing import Tuple -from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.gdn_decode_pack import conv_pack_gdn_decode_inputs +from lightllm.models.qwen3next.triton_kernel.shared_expert_gate import add_shared_expert_gate_, sigmoid_mul_ from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule from lightllm.distributed import all_reduce @@ -114,15 +116,14 @@ def _compute_shared_expert( ): input = input.view(-1, self.embed_dim_) shared_expert_out = LlamaTransformerLayerInfer._ffn_tp(self, input, infer_state, layer_weight) - gate = layer_weight.ffn_gate.mm(input).sigmoid_() - shared_expert_out.mul_(gate) - return shared_expert_out + gate = layer_weight.ffn_gate.mm(input) + return shared_expert_out, gate def _moe_ffn_tp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + shared_expert_out, gate = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape @@ -135,15 +136,16 @@ def _moe_ffn_tp( use_grouped_topk=False, topk_group=None, num_expert_group=None, + shared_expert_out=shared_expert_out, + shared_expert_gate=gate, ) hidden_states = hidden_states.view(num_tokens, hidden_dim) - hidden_states.add_(shared_expert_out) return hidden_states def _moe_ffn_edp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + shared_expert_out, gate = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input token_num, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -158,7 +160,7 @@ def _moe_ffn_edp( is_prefill=infer_state.is_prefill, ) ep_output = ep_output.view(token_num, hidden_dim) - ep_output.add_(shared_expert_out) + add_shared_expert_gate_(ep_output, shared_expert_out, gate) return ep_output def _get_qkv( @@ -169,13 +171,25 @@ def _get_qkv( ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - qkv_out = layer_weight.qkv_proj.mm(input) + qkvo_gate_proj = getattr(layer_weight, "qkvo_gate_proj", None) + if qkvo_gate_proj is None: + qkv_out = layer_weight.qkv_proj.mm(input) + o_gate = layer_weight._o_gate_proj.mm(input) + else: + qkv_gate_out = qkvo_gate_proj.mm(input) + qkv_out, o_gate = qkv_gate_out.split( + [ + self.tp_q_head_num_ * self.head_dim_ * 2 + + (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_, + self.tp_q_head_num_ * self.head_dim_, + ], + dim=-1, + ) q, cache_kv = qkv_out.split( [self.tp_q_head_num_ * self.head_dim_ * 2, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1, ) - o_gate = layer_weight._o_gate_proj.mm(input) - infer_state.gate_value = o_gate.sigmoid_() + infer_state.gate_value = o_gate layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], @@ -199,15 +213,24 @@ def _get_o( input, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, + ) -> torch.Tensor: + o_tensor = self._get_o_local(input=input, infer_state=infer_state, layer_weight=layer_weight) + o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) + return o_tensor + + def _get_o_local( + self, + input, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, ) -> torch.Tensor: """Output projection with gating (in-place multiply to save one allocation).""" if infer_state.need_dp_prefill_balance: input = infer_state._all_to_all_balance_get(data=input) input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - input.mul_(infer_state.gate_value) + sigmoid_mul_(input, infer_state.gate_value) infer_state.gate_value = None o_tensor = layer_weight.o_proj.mm(input) - o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) return o_tensor # ==================== GDN Helper Methods ==================== @@ -257,8 +280,9 @@ def gdn_forward( else: mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba) conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_) - core_attn_out = self._gdn_decode_kernel( + core_attn_out, z = self._gdn_decode_kernel( mixed_qkv, + z, conv_states, ssm_states, a, @@ -406,6 +430,7 @@ def _gdn_prefill_kernel( def _gdn_decode_kernel( self, mixed_qkv: torch.Tensor, + z: torch.Tensor, conv_states: torch.Tensor, ssm_states: torch.Tensor, a: torch.Tensor, @@ -413,18 +438,24 @@ def _gdn_decode_kernel( infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ): - mixed_qkv = causal_conv1d_update( + # Recurrent processing with fused gating. Decode uses a specialized + # conv+pack kernel to avoid materializing the post-conv qkv tensor + # before immediately splitting it into q/k/v. + query, key, value, z, a, b = conv_pack_gdn_decode_inputs( mixed_qkv, + z, + a, + b, conv_states, layer_weight.linear_conv1d.mm_param.weight, - bias=layer_weight.linear_conv1d.bias, - activation=self.activation, - conv_state_indices=infer_state.b_buffer_idx, + layer_weight.linear_conv1d.bias, + infer_state.b_buffer_idx, + self.activation, + self.tp_num_k_heads, + self.head_k_dim, + self.tp_num_v_heads, + self.head_v_dim, ) - - # Recurrent processing with fused gating; the kernel reads the - # q/k/v/a/b column views directly via per-token strides (no copies) - query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) core_attn_out, _ = fused_recurrent_gated_delta_rule( q=query, k=key, @@ -438,4 +469,4 @@ def _gdn_decode_kernel( a_raw=a, b_raw=b, ) - return core_attn_out + return core_attn_out, z diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index 0d415ca0e8..51b702039b 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -11,6 +11,83 @@ QKVROWNMMWeight, QKGEMMANormWeight, ) +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import get_row_slice_mixin +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size + + +class QKVGatedROWNMMWeight(MMWeightTpl): + def __init__( + self, + in_dim, + q_head_num, + kv_head_num, + head_dim, + weight_names, + data_type, + bias_names=None, + quant_method=None, + tp_rank=None, + tp_world_size=None, + ): + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + self.q_repeat_times = 1 + self.kv_repeat_times = 1 + assert ( + q_head_num % self.tp_world_size_ == 0 + ), f"q_head_num must be divisible by tp_world_size_, found {q_head_num} % {self.tp_world_size_}" + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or vice versa, " + f"found {kv_head_num} % {self.tp_world_size_}" + ) + q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim + super().__init__( + in_dim=in_dim, + out_dims=[q_hidden_size, kv_hidden_size, kv_hidden_size, q_hidden_size], + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + ) + self.q_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.q_repeat_times, + ) + self.kv_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.kv_repeat_times, + ) + + def _get_param_slicer(self, sub_child_index): + if sub_child_index == 0 or sub_child_index == 3: + return self.q_param_slicer + return self.kv_param_slicer + + def load_hf_weights(self, weights): + super().load_hf_weights(weights) + if self.bias_names is not None: + for sub_child_index, bias_name in enumerate(self.bias_names): + if bias_name is None: + self.bias_list[sub_child_index].zero_() + self.bias_list[sub_child_index].load_ok = True + + def _get_tp_padded_head_num(self, head_num): + if head_num % self.tp_world_size_ == 0: + return head_num // self.tp_world_size_ + if self.tp_world_size_ % head_num == 0: + self.kv_repeat_times = self.tp_world_size_ // head_num + return self.kv_repeat_times * head_num // self.tp_world_size_ + raise ValueError( + f"head_num must be divisible by tp_world_size_ or vice versa, found {head_num} % {self.tp_world_size_}" + ) class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): @@ -23,25 +100,39 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): def _init_qkv(self): in_dim = self.n_embed q_out_dim = self.q_head_num_ * self.head_dim - self.qkv_proj = QKVROWNMMWeight( - in_dim=in_dim, - q_head_num=self.q_head_num_, - kv_head_num=self.k_head_num_, - head_dim=self.head_dim, - weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], - data_type=self.data_type_, - bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], - quant_method=self.get_quant_method("qkv_proj"), - ) self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" - self._o_gate_proj = ROWMMWeight( - in_dim=in_dim, - out_dims=[q_out_dim], - weight_names=[self._o_gate_weight_name], - data_type=self.data_type_, - bias_names=None, - quant_method=self.get_quant_method("o_gate_proj"), - ) + qkv_quant = self.get_quant_method("qkv_proj") + gate_quant = self.get_quant_method("o_gate_proj") + if qkv_quant.method_name == "none" and gate_quant.method_name == "none": + self.qkvo_gate_proj = QKVGatedROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name, self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name, None], + quant_method=qkv_quant, + ) + else: + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=qkv_quant, + ) + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=gate_quant, + ) def _init_weight(self): if self.is_linear_attention_layer: diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py index b0dc41a3c1..22a93a2c99 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -54,11 +54,6 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - stride_q_tok: tl.constexpr, - stride_k_tok: tl.constexpr, - stride_v_tok: tl.constexpr, - stride_a_tok: tl.constexpr, - stride_b_tok: tl.constexpr, stride_init_state_token: tl.constexpr, stride_final_state_token: tl.constexpr, stride_indices_seq: tl.constexpr, @@ -99,15 +94,15 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( o_k = i_k * BK + tl.arange(0, BK) o_v = i_v * BV + tl.arange(0, BV) - p_q = q + bos * stride_q_tok + i_h * K + o_k - p_k = k + bos * stride_k_tok + i_h * K + o_k - p_v = v + bos * stride_v_tok + i_hv * V + o_v + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v if FUSE_GATING: # Fused gating: load per-head constants once, compute g/beta inline per token b_A_log = tl.load(A_log + i_hv).to(tl.float32) b_dt_bias = tl.load(dt_bias + i_hv).to(tl.float32) - p_a_raw = a_raw + bos * stride_a_tok + i_hv - p_b_raw = b_raw + bos * stride_b_tok + i_hv + p_a_raw = a_raw + bos * HV + i_hv + p_b_raw = b_raw + bos * HV + i_hv else: if IS_BETA_HEADWISE: p_beta = beta + (bos * HV + i_hv) * V + o_v @@ -198,13 +193,13 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) - p_q += stride_q_tok - p_k += stride_k_tok + p_q += H * K + p_k += H * K p_o += HV * V - p_v += stride_v_tok + p_v += HV * V if FUSE_GATING: - p_a_raw += stride_a_tok - p_b_raw += stride_b_tok + p_a_raw += HV + p_b_raw += HV else: if not IS_KDA: p_g += HV @@ -213,34 +208,6 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_beta += HV * (V if IS_BETA_HEADWISE else 1) -def _ensure_qkv_token_strided(x: torch.Tensor, inner_numel: int): - """Return q/k/v and token stride, copying only when needed.""" - if x is None: - return None, 0 - - # Decode layout must be [tokens, 1, head, dim]. - assert x.shape[1] == 1, "q/k/v must use decode layout [tokens, 1, head, dim]" - - # Packed tail [head, dim] means the last two strides are [dim, 1]. - tail_contiguous = x.stride()[-2:] == (x.shape[-1], 1) - if not tail_contiguous: - x = x.contiguous() - return x, inner_numel - else: - return x, x.stride(0) - - -def _ensure_gate_token_strided(x: torch.Tensor, inner_numel: int): - """Return a_raw/b_raw and token stride, copying only when needed.""" - if x is None: - return None, 0 - # a_raw/b_raw are 2D [tokens, HV]; the tail HV dimension must be packed. - if x.stride(1) != 1: - x = x.contiguous() - return x, inner_numel - return x, x.stride(0) - - def fused_recurrent_gated_delta_rule_fwd( q: torch.Tensor, k: torch.Tensor, @@ -264,16 +231,7 @@ def fused_recurrent_gated_delta_rule_fwd( ) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K, V = *k.shape, v.shape[-1] HV = v.shape[2] - # In LightLLM's Qwen3Next inference path this fused recurrent kernel is - # used only for decode. Prefill/varlen requests are handled by - # chunk_gated_delta_rule, so keep cu_seqlens out of this strided-view path. - assert cu_seqlens is None, "cu_seqlens is not supported by the decode-only fused recurrent kernel" - N = B - q, stride_q_tok = _ensure_qkv_token_strided(q, H * K) - k, stride_k_tok = _ensure_qkv_token_strided(k, H * K) - v, stride_v_tok = _ensure_qkv_token_strided(v, HV * V) - a_raw, stride_a_tok = _ensure_gate_token_strided(a_raw, HV) - b_raw, stride_b_tok = _ensure_gate_token_strided(b_raw, HV) + N = B if cu_seqlens is None else len(cu_seqlens) - 1 BK = triton.next_power_of_2(K) if T == 1: # Decode path: use larger BV to reduce kernel instances (4 blocks instead of 16) @@ -303,23 +261,20 @@ def fused_recurrent_gated_delta_rule_fwd( stride_init_state_token = initial_state.stride(0) stride_final_state_token = final_state.stride(0) - # Strides for read indices. The kernel advances along a row with `+ i_t` - # (token stride 1), so 2D index tensors must have contiguous rows. + # Strides for read indices if ssm_state_indices is None: stride_indices_seq, stride_indices_tok = 1, 1 elif ssm_state_indices.ndim == 1: stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 else: - assert ssm_state_indices.stride(-1) == 1, "2D ssm_state_indices must have contiguous rows" stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() - # Strides for write indices (if provided); same contiguous-row requirement + # Strides for write indices (if provided) if ssm_state_write_indices is None: stride_write_indices_seq, stride_write_indices_tok = 1, 1 elif ssm_state_write_indices.ndim == 1: stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1 else: - assert ssm_state_write_indices.stride(-1) == 1, "2D ssm_state_write_indices must have contiguous rows" stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride() grid = (NK, NV, N * HV) @@ -350,11 +305,6 @@ def fused_recurrent_gated_delta_rule_fwd( V=V, BK=BK, BV=BV, - stride_q_tok=stride_q_tok, - stride_k_tok=stride_k_tok, - stride_v_tok=stride_v_tok, - stride_a_tok=stride_a_tok, - stride_b_tok=stride_b_tok, stride_init_state_token=stride_init_state_token, stride_final_state_token=stride_final_state_token, stride_indices_seq=stride_indices_seq, @@ -398,12 +348,10 @@ def forward( b_raw: torch.Tensor | None = None, out: torch.Tensor | None = None, ): - # q/k/v/a_raw/b_raw may be non-contiguous column views of one projection - # output; the kernel handles them via per-token strides (no copies). o, final_state = fused_recurrent_gated_delta_rule_fwd( - q=q, - k=k, - v=v, + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), g=g.contiguous() if g is not None else None, beta=beta.contiguous() if beta is not None else None, scale=scale, @@ -416,8 +364,8 @@ def forward( use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, A_log=A_log, dt_bias=dt_bias, - a_raw=a_raw, - b_raw=b_raw, + a_raw=a_raw.contiguous() if a_raw is not None else None, + b_raw=b_raw.contiguous() if b_raw is not None else None, out=out, ) @@ -469,9 +417,8 @@ def fused_recurrent_gated_delta_rule( Whether to store the final state in-place to save memory. Default: `True`. cu_seqlens (torch.LongTensor): - Must be `None`. In LightLLM this fused recurrent kernel is used only - by the Qwen3Next decode path; prefill/varlen requests use - `chunk_gated_delta_rule`. + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. ssm_state_indices (Optional[torch.Tensor]): Indices to map the input sequences to the initial/final states. num_accepted_tokens (Optional[torch.Tensor]): @@ -486,9 +433,10 @@ def fused_recurrent_gated_delta_rule( Examples:: >>> import torch >>> import torch.nn.functional as F + >>> from einops import rearrange >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule - # decode inputs - >>> B, T, H, HV, K, V = 4, 1, 4, 8, 512, 512 + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 >>> q = torch.randn(B, T, H, K, device='cuda') >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) >>> v = torch.randn(B, T, HV, V, device='cuda') @@ -499,10 +447,21 @@ def fused_recurrent_gated_delta_rule( q, k, v, g, beta, initial_state=h0, ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) """ - # This wrapper is only used for Qwen3Next decode inference in LightLLM. - # Keep varlen/prefill inputs on chunk_gated_delta_rule instead. - assert cu_seqlens is None, "cu_seqlens is not supported by the decode-only fused recurrent kernel" + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) if scale is None: scale = k.shape[-1] ** -0.5 else: diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py new file mode 100644 index 0000000000..a025e35c64 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py @@ -0,0 +1,284 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _pack_gdn_decode_kernel( + mixed_qkv, + z_raw, + a_raw, + b_raw, + q_out, + k_out, + v_out, + z_out, + a_out, + b_out, + stride_m_b: tl.constexpr, + stride_m_d: tl.constexpr, + stride_z_b: tl.constexpr, + stride_z_h: tl.constexpr, + stride_z_d: tl.constexpr, + stride_a_b: tl.constexpr, + stride_a_d: tl.constexpr, + stride_b_b: tl.constexpr, + stride_b_d: tl.constexpr, + q_dim: tl.constexpr, + k_dim: tl.constexpr, + v_dim: tl.constexpr, + gate_dim: tl.constexpr, + BLOCK_QKV: tl.constexpr, + BLOCK_GATE: tl.constexpr, +): + row = tl.program_id(0) + qkv_offsets = tl.arange(0, BLOCK_QKV) + + q_mask = qkv_offsets < q_dim + q_vals = tl.load(mixed_qkv + row * stride_m_b + qkv_offsets * stride_m_d, mask=q_mask, other=0.0) + tl.store(q_out + row * q_dim + qkv_offsets, q_vals, mask=q_mask) + + k_mask = qkv_offsets < k_dim + k_vals = tl.load( + mixed_qkv + row * stride_m_b + (q_dim + qkv_offsets) * stride_m_d, + mask=k_mask, + other=0.0, + ) + tl.store(k_out + row * k_dim + qkv_offsets, k_vals, mask=k_mask) + + v_mask = qkv_offsets < v_dim + v_vals = tl.load( + mixed_qkv + row * stride_m_b + (q_dim + k_dim + qkv_offsets) * stride_m_d, + mask=v_mask, + other=0.0, + ) + tl.store(v_out + row * v_dim + qkv_offsets, v_vals, mask=v_mask) + + z_vals = tl.load(z_raw + row * stride_z_b + qkv_offsets, mask=v_mask, other=0.0) + tl.store(z_out + row * v_dim + qkv_offsets, z_vals, mask=v_mask) + + gate_offsets = tl.arange(0, BLOCK_GATE) + gate_mask = gate_offsets < gate_dim + a_vals = tl.load(a_raw + row * stride_a_b + gate_offsets * stride_a_d, mask=gate_mask, other=0.0) + b_vals = tl.load(b_raw + row * stride_b_b + gate_offsets * stride_b_d, mask=gate_mask, other=0.0) + tl.store(a_out + row * gate_dim + gate_offsets, a_vals, mask=gate_mask) + tl.store(b_out + row * gate_dim + gate_offsets, b_vals, mask=gate_mask) + + +@torch.no_grad() +def pack_gdn_decode_inputs( + mixed_qkv: torch.Tensor, + z_raw: torch.Tensor, + a_raw: torch.Tensor, + b_raw: torch.Tensor, + num_k_heads: int, + head_k_dim: int, + num_v_heads: int, + head_v_dim: int, +): + batch = mixed_qkv.shape[0] + q_dim = num_k_heads * head_k_dim + k_dim = q_dim + v_dim = num_v_heads * head_v_dim + gate_dim = num_v_heads + + q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + k = torch.empty_like(q) + v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) + a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) + b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) + + block_qkv = triton.next_power_of_2(max(q_dim, k_dim, v_dim)) + block_gate = triton.next_power_of_2(gate_dim) + _pack_gdn_decode_kernel[(batch,)]( + mixed_qkv, + z_raw, + a_raw, + b_raw, + q, + k, + v, + z, + a, + b, + mixed_qkv.stride(0), + mixed_qkv.stride(1), + z_raw.stride(0), + z_raw.stride(1), + z_raw.stride(2), + a_raw.stride(0), + a_raw.stride(1), + b_raw.stride(0), + b_raw.stride(1), + q_dim, + k_dim, + v_dim, + gate_dim, + BLOCK_QKV=block_qkv, + BLOCK_GATE=block_gate, + num_warps=4, + ) + return q, k, v, z, a, b + + +@triton.jit +def _conv_pack_gdn_decode_kernel( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q_out, + k_out, + v_out, + z_out, + a_out, + b_out, + stride_m_b: tl.constexpr, + stride_m_d: tl.constexpr, + stride_z_b: tl.constexpr, + stride_z_h: tl.constexpr, + stride_z_d: tl.constexpr, + stride_a_b: tl.constexpr, + stride_a_d: tl.constexpr, + stride_b_b: tl.constexpr, + stride_b_d: tl.constexpr, + stride_s_b: tl.constexpr, + stride_s_d: tl.constexpr, + stride_s_w: tl.constexpr, + stride_w_d: tl.constexpr, + stride_w_w: tl.constexpr, + q_dim: tl.constexpr, + k_dim: tl.constexpr, + v_dim: tl.constexpr, + gate_dim: tl.constexpr, + conv_dim: tl.constexpr, + HAS_BIAS: tl.constexpr, + APPLY_SILU: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offs = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < conv_dim + state_idx = tl.load(conv_state_indices + row) + + x = tl.load(mixed_qkv + row * stride_m_b + offs * stride_m_d, mask=mask, other=0.0).to(tl.float32) + s0 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 0 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + s1 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 1 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + s2 = tl.load(conv_state + state_idx * stride_s_b + offs * stride_s_d + 2 * stride_s_w, mask=mask, other=0.0).to( + tl.float32 + ) + w0 = tl.load(conv_weight + offs * stride_w_d + 0 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w1 = tl.load(conv_weight + offs * stride_w_d + 1 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w2 = tl.load(conv_weight + offs * stride_w_d + 2 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + w3 = tl.load(conv_weight + offs * stride_w_d + 3 * stride_w_w, mask=mask, other=0.0).to(tl.float32) + y = s0 * w0 + s1 * w1 + s2 * w2 + x * w3 + if HAS_BIAS: + bias = tl.load(conv_bias + offs, mask=mask, other=0.0).to(tl.float32) + y += bias + if APPLY_SILU: + y = y * tl.sigmoid(y) + + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 0 * stride_s_w, s1, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 1 * stride_s_w, s2, mask=mask) + tl.store(conv_state + state_idx * stride_s_b + offs * stride_s_d + 2 * stride_s_w, x, mask=mask) + + q_mask = offs < q_dim + k_mask = (offs >= q_dim) & (offs < q_dim + k_dim) + v_mask = (offs >= q_dim + k_dim) & (offs < conv_dim) + tl.store(q_out + row * q_dim + offs, y, mask=q_mask) + tl.store(k_out + row * k_dim + (offs - q_dim), y, mask=k_mask) + tl.store(v_out + row * v_dim + (offs - q_dim - k_dim), y, mask=v_mask) + + z_mask = offs < v_dim + z_vals = tl.load(z_raw + row * stride_z_b + offs, mask=z_mask, other=0.0) + tl.store(z_out + row * v_dim + offs, z_vals, mask=z_mask) + + gate_mask = offs < gate_dim + a_vals = tl.load(a_raw + row * stride_a_b + offs * stride_a_d, mask=gate_mask, other=0.0) + b_vals = tl.load(b_raw + row * stride_b_b + offs * stride_b_d, mask=gate_mask, other=0.0) + tl.store(a_out + row * gate_dim + offs, a_vals, mask=gate_mask) + tl.store(b_out + row * gate_dim + offs, b_vals, mask=gate_mask) + + +@torch.no_grad() +def conv_pack_gdn_decode_inputs( + mixed_qkv: torch.Tensor, + z_raw: torch.Tensor, + a_raw: torch.Tensor, + b_raw: torch.Tensor, + conv_state: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + conv_state_indices: torch.Tensor, + activation: str, + num_k_heads: int, + head_k_dim: int, + num_v_heads: int, + head_v_dim: int, +): + batch = mixed_qkv.shape[0] + q_dim = num_k_heads * head_k_dim + k_dim = q_dim + v_dim = num_v_heads * head_v_dim + gate_dim = num_v_heads + conv_dim = q_dim + k_dim + v_dim + + q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + k = torch.empty_like(q) + v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) + z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) + a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) + b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) + + block_size = 256 + grid = (batch, triton.cdiv(conv_dim, block_size)) + _conv_pack_gdn_decode_kernel[grid]( + mixed_qkv, + z_raw, + a_raw, + b_raw, + conv_state, + conv_weight, + conv_bias, + conv_state_indices, + q, + k, + v, + z, + a, + b, + mixed_qkv.stride(0), + mixed_qkv.stride(1), + z_raw.stride(0), + z_raw.stride(1), + z_raw.stride(2), + a_raw.stride(0), + a_raw.stride(1), + b_raw.stride(0), + b_raw.stride(1), + conv_state.stride(0), + conv_state.stride(1), + conv_state.stride(2), + conv_weight.stride(0), + conv_weight.stride(1), + q_dim, + k_dim, + v_dim, + gate_dim, + conv_dim, + HAS_BIAS=conv_bias is not None, + APPLY_SILU=activation in ["silu", "swish"], + BLOCK_SIZE=block_size, + num_warps=8, + ) + return q, k, v, z, a, b diff --git a/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py new file mode 100644 index 0000000000..c2b110def6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py @@ -0,0 +1,108 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _add_shared_expert_gate_kernel( + hidden, + shared, + gate, + stride_h_m: tl.constexpr, + stride_h_n: tl.constexpr, + stride_s_m: tl.constexpr, + stride_s_n: tl.constexpr, + stride_g_m: tl.constexpr, + stride_g_n: tl.constexpr, + N: tl.constexpr, + GATE_N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_N) + mask = offs < N + + hidden_ptrs = hidden + row * stride_h_m + offs * stride_h_n + shared_vals = tl.load(shared + row * stride_s_m + offs * stride_s_n, mask=mask, other=0.0).to(tl.float32) + if GATE_N == 1: + gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) + else: + gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) + hidden_vals = tl.load(hidden_ptrs, mask=mask, other=0.0).to(tl.float32) + gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) + out = hidden_vals + shared_vals * gate_vals + tl.store(hidden_ptrs, out.to(hidden.dtype.element_ty), mask=mask) + + +@triton.jit +def _sigmoid_mul_kernel( + x, + gate, + stride_x_m: tl.constexpr, + stride_x_n: tl.constexpr, + stride_g_m: tl.constexpr, + stride_g_n: tl.constexpr, + N: tl.constexpr, + GATE_N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_N) + mask = offs < N + x_ptrs = x + row * stride_x_m + offs * stride_x_n + x_vals = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + if GATE_N == 1: + gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) + else: + gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) + gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) + tl.store(x_ptrs, (x_vals * gate_vals).to(x.dtype.element_ty), mask=mask) + + +@torch.no_grad() +def add_shared_expert_gate_(hidden: torch.Tensor, shared: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + hidden_arg = hidden.view(-1, hidden.shape[-1]) + shared_arg = shared.view(-1, hidden.shape[-1]) + gate_arg = gate.view(-1, gate.shape[-1]) + assert hidden_arg.shape == shared_arg.shape + assert gate_arg.shape[0] == hidden_arg.shape[0] and gate_arg.shape[1] in (1, hidden_arg.shape[1]) + _, n = hidden_arg.shape + block_n = triton.next_power_of_2(n) + _add_shared_expert_gate_kernel[(hidden_arg.shape[0],)]( + hidden_arg, + shared_arg, + gate_arg, + hidden_arg.stride(0), + hidden_arg.stride(1), + shared_arg.stride(0), + shared_arg.stride(1), + gate_arg.stride(0), + gate_arg.stride(1), + n, + gate_arg.shape[1], + BLOCK_N=block_n, + num_warps=8, + ) + return hidden + + +@torch.no_grad() +def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + x_arg = x.view(-1, x.shape[-1]) + gate_arg = gate.view(-1, gate.shape[-1]) + assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1]) + _, n = x_arg.shape + block_n = triton.next_power_of_2(n) + _sigmoid_mul_kernel[(x_arg.shape[0],)]( + x_arg, + gate_arg, + x_arg.stride(0), + x_arg.stride(1), + gate_arg.stride(0), + gate_arg.stride(1), + n, + gate_arg.shape[1], + BLOCK_N=block_n, + num_warps=8, + ) + return x diff --git a/unit_tests/models/qwen3next/test_fused_recurrent_strided.py b/unit_tests/models/qwen3next/test_fused_recurrent_strided.py deleted file mode 100644 index cf9d06ec98..0000000000 --- a/unit_tests/models/qwen3next/test_fused_recurrent_strided.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -import torch - -from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( - fused_recurrent_gated_delta_rule, -) - -if not torch.cuda.is_available(): - pytest.skip("CUDA required", allow_module_level=True) - - -@pytest.mark.parametrize("batch", [1, 2, 16]) -def test_decode_strided_views_match_contiguous(batch): - """q/k/v/a/b passed as column views of one projection output (the decode - path layout) must produce the same result as contiguous copies.""" - torch.manual_seed(0) - H, HV, K, V = 2, 8, 128, 128 - key_dim, value_dim = H * K, HV * V - qkv_dim = 2 * key_dim + value_dim - total_dim = qkv_dim + value_dim + 2 * HV # qkv + z + b + a - cache_slots = 64 - - mixed = torch.randn(batch, total_dim, device="cuda", dtype=torch.bfloat16) - mixed_qkv = mixed[:, :qkv_dim] - b_raw = mixed[:, qkv_dim + value_dim : qkv_dim + value_dim + HV] - a_raw = mixed[:, qkv_dim + value_dim + HV :] - - query, key, value = torch.split(mixed_qkv, [key_dim, key_dim, value_dim], dim=-1) - q = query.view(batch, 1, H, K) - k = key.view(batch, 1, H, K) - v = value.view(batch, 1, HV, V) - - A_log = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1 - dt_bias = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1 - ssm_state = torch.randn(cache_slots, HV, K, V, device="cuda", dtype=torch.bfloat16) - idx = torch.randperm(cache_slots, device="cuda")[:batch].to(torch.int32) - - def run(q_, k_, v_, a_, b_, state): - out, _ = fused_recurrent_gated_delta_rule( - q=q_, - k=k_, - v=v_, - initial_state=state, - inplace_final_state=True, - ssm_state_indices=idx, - use_qk_l2norm_in_kernel=True, - A_log=A_log, - dt_bias=dt_bias, - a_raw=a_, - b_raw=b_, - ) - return out - - state_ref = ssm_state.clone() - out_ref = run(q.contiguous(), k.contiguous(), v.contiguous(), a_raw.contiguous(), b_raw.contiguous(), state_ref) - state_strided = ssm_state.clone() - out_strided = run(q, k, v, a_raw, b_raw, state_strided) - - assert torch.equal(out_ref, out_strided) - assert torch.equal(state_ref, state_strided) - - -def test_cu_seqlens_is_not_supported(): - """The fused recurrent kernel is decode-only in LightLLM's Qwen3Next path.""" - H, HV, K, V = 2, 2, 4, 4 - q = torch.randn(1, 2, H, K, device="cuda", dtype=torch.bfloat16) - k = torch.randn(1, 2, H, K, device="cuda", dtype=torch.bfloat16) - v = torch.randn(1, 2, HV, V, device="cuda", dtype=torch.bfloat16) - initial_state = torch.randn(1, HV, K, V, device="cuda", dtype=torch.bfloat16) - cu_seqlens = torch.tensor([0, 2], device="cuda", dtype=torch.long) - - with pytest.raises(AssertionError, match="decode-only fused recurrent kernel"): - fused_recurrent_gated_delta_rule( - q=q, - k=k, - v=v, - initial_state=initial_state, - cu_seqlens=cu_seqlens, - ) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) From aaab438ba0ef7186a318993c427d54b69e9891eb Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Wed, 17 Jun 2026 09:40:24 +0800 Subject: [PATCH 2/4] feat: reducing more running time for fused moe --- .../fused_moe/grouped_fused_moe.py | 1 + .../fused_moe/moe_silu_and_mul.py | 2 +- .../triton_kernel/fused_moe/moe_sum_reduce.py | 4 +- .../layer_infer/transformer_layer_infer.py | 14 ++- .../layer_weights/transformer_layer_weight.py | 101 +++++++++++++++--- .../triton_kernel/gdn_decode_pack.py | 57 ---------- .../triton_kernel/shared_expert_gate.py | 4 +- 7 files changed, 102 insertions(+), 81 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index cec9c53e52..ecdf9f2f7b 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -231,6 +231,7 @@ def moe_align_fused_kernel( if ZERO_EXPERT_TOKEN_NUM: expert_offs = tl.arange(0, BLOCK_EXPERT) tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) + tl.debug_barrier() offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offs < token_num * topk_num diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py index 45c7ea73c6..a63d92692e 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py @@ -122,7 +122,7 @@ def silu_and_mul_fwd( alpha=None, run_config=None, ): - assert input.is_contiguous() + assert input.stride(-1) == 1 assert output.is_contiguous() assert (limit is None and alpha is None) or (limit is not None and alpha is not None) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py index 4f95cca7c6..97cda5cb37 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py @@ -66,7 +66,7 @@ def _moe_sum_reduce_kernel( mask=offs_dim < dim_end, other=0.0, ).to(tl.float32) - gate = 1.0 / (1.0 + tl.exp(-gate)) + gate = tl.sigmoid(gate) accumulator += shared * gate store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim tl.store(store_t_ptr, accumulator.to(input_ptr.dtype.element_ty), mask=offs_dim < dim_end) @@ -113,7 +113,7 @@ def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, shared=None, gate= shared = shared.view(token_num, hidden_dim) gate = gate.view(token_num, gate.shape[-1]) assert shared.is_contiguous() - assert gate.is_contiguous() + assert gate.stride(1) == 1 assert gate.shape[1] in (1, hidden_dim) if not run_config: diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 3492041813..86c3c65ba0 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -14,6 +14,7 @@ from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating from lightllm.models.qwen3next.triton_kernel.gdn_decode_pack import conv_pack_gdn_decode_inputs from lightllm.models.qwen3next.triton_kernel.shared_expert_gate import add_shared_expert_gate_, sigmoid_mul_ +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule from lightllm.distributed import all_reduce @@ -115,8 +116,17 @@ def _compute_shared_expert( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): input = input.view(-1, self.embed_dim_) - shared_expert_out = LlamaTransformerLayerInfer._ffn_tp(self, input, infer_state, layer_weight) - gate = layer_weight.ffn_gate.mm(input) + if getattr(layer_weight, "fused_shared_expert_gate", False): + up_gate_and_gate = layer_weight.gate_up_proj.mm(input) + up_gate_dim = layer_weight.gate_up_proj.out_dims[0] + layer_weight.gate_up_proj.out_dims[1] + gate = up_gate_and_gate[:, up_gate_dim : up_gate_dim + 1] + up_gate_out = up_gate_and_gate[:, :up_gate_dim] + ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) + silu_and_mul_fwd(up_gate_out, ffn1_out) + shared_expert_out = layer_weight.down_proj.mm(ffn1_out) + else: + shared_expert_out = LlamaTransformerLayerInfer._ffn_tp(self, input, infer_state, layer_weight) + gate = layer_weight.ffn_gate.mm(input) return shared_expert_out, gate def _moe_ffn_tp( diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index 51b702039b..42e18c7f02 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -90,6 +90,56 @@ def _get_tp_padded_head_num(self, head_num): ) +class SharedGateUpGateROWMMWeight(MMWeightTpl): + gate_pad_dim = 16 + + def __init__( + self, + in_dim, + inter_size, + weight_names, + data_type, + quant_method=None, + tp_rank=None, + tp_world_size=None, + ): + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + assert ( + inter_size % self.tp_world_size_ == 0 + ), f"inter_size must be divisible by tp_world_size_, found {inter_size} % {self.tp_world_size_}" + super().__init__( + in_dim=in_dim, + out_dims=[inter_size // self.tp_world_size_, inter_size // self.tp_world_size_, self.gate_pad_dim], + weight_names=weight_names, + bias_names=None, + data_type=data_type, + quant_method=quant_method, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + ) + self.tp_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ + ) + self.gate_param_slicer = get_row_slice_mixin("none", tp_rank=0, tp_world_size=1) + + def _get_param_slicer(self, sub_child_index): + if sub_child_index == 2: + return self.gate_param_slicer + return self.tp_param_slicer + + def load_hf_weights(self, weights): + for sub_child_index, param_name in enumerate(self.weight_names): + if sub_child_index != 2: + self._load_weight(param_name=param_name, weights=weights, sub_child_index=sub_child_index) + continue + if param_name in weights: + weight_pack = self.mm_param_list[sub_child_index] + weight_pack.weight.zero_() + weight_pack.weight[:1].copy_(weights[param_name]) + weight_pack.load_ok[0] = True + + class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): num_full_attention_layers = network_config["full_attention_interval"] @@ -198,13 +248,29 @@ def _init_gated_ffn(self): tp_world_size=1, ) else: - self.gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[inter_size, inter_size], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=self.get_quant_method("gate_up_proj"), - ) + gate_up_quant = self.get_quant_method("gate_up_proj") + if gate_up_quant.method_name == "none": + self.gate_up_proj = SharedGateUpGateROWMMWeight( + in_dim=hidden_size, + inter_size=inter_size, + weight_names=[ + f"{prefix}.gate_proj.weight", + f"{prefix}.up_proj.weight", + f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + ], + data_type=self.data_type_, + quant_method=gate_up_quant, + ) + self.fused_shared_expert_gate = True + else: + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=gate_up_quant, + ) + self.fused_shared_expert_gate = False self.down_proj = COLMMWeight( in_dim=inter_size, out_dims=[hidden_size], @@ -213,16 +279,17 @@ def _init_gated_ffn(self): quant_method=self.get_quant_method("down_proj"), ) - self.ffn_gate = ROWMMWeight( - in_dim=hidden_size, - out_dims=[1], - weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", - data_type=self.data_type_, - bias_names=None, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) + if not getattr(self, "fused_shared_expert_gate", False): + self.ffn_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) def _split_q_with_gate(self, weights): if self._q_weight_name in weights: diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py index a025e35c64..180fdcfab0 100644 --- a/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py @@ -65,63 +65,6 @@ def _pack_gdn_decode_kernel( tl.store(b_out + row * gate_dim + gate_offsets, b_vals, mask=gate_mask) -@torch.no_grad() -def pack_gdn_decode_inputs( - mixed_qkv: torch.Tensor, - z_raw: torch.Tensor, - a_raw: torch.Tensor, - b_raw: torch.Tensor, - num_k_heads: int, - head_k_dim: int, - num_v_heads: int, - head_v_dim: int, -): - batch = mixed_qkv.shape[0] - q_dim = num_k_heads * head_k_dim - k_dim = q_dim - v_dim = num_v_heads * head_v_dim - gate_dim = num_v_heads - - q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) - k = torch.empty_like(q) - v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) - z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) - a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) - b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) - - block_qkv = triton.next_power_of_2(max(q_dim, k_dim, v_dim)) - block_gate = triton.next_power_of_2(gate_dim) - _pack_gdn_decode_kernel[(batch,)]( - mixed_qkv, - z_raw, - a_raw, - b_raw, - q, - k, - v, - z, - a, - b, - mixed_qkv.stride(0), - mixed_qkv.stride(1), - z_raw.stride(0), - z_raw.stride(1), - z_raw.stride(2), - a_raw.stride(0), - a_raw.stride(1), - b_raw.stride(0), - b_raw.stride(1), - q_dim, - k_dim, - v_dim, - gate_dim, - BLOCK_QKV=block_qkv, - BLOCK_GATE=block_gate, - num_warps=4, - ) - return q, k, v, z, a, b - - @triton.jit def _conv_pack_gdn_decode_kernel( mixed_qkv, diff --git a/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py index c2b110def6..fd89edd2f1 100644 --- a/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py +++ b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py @@ -29,7 +29,7 @@ def _add_shared_expert_gate_kernel( else: gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) hidden_vals = tl.load(hidden_ptrs, mask=mask, other=0.0).to(tl.float32) - gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) + gate_vals = tl.sigmoid(gate_vals) out = hidden_vals + shared_vals * gate_vals tl.store(hidden_ptrs, out.to(hidden.dtype.element_ty), mask=mask) @@ -55,7 +55,7 @@ def _sigmoid_mul_kernel( gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) else: gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) - gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) + gate_vals = tl.sigmoid(gate_vals) tl.store(x_ptrs, (x_vals * gate_vals).to(x.dtype.element_ty), mask=mask) From 0ce9da6e3e19ae960090111e762b469a2981c673 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Wed, 17 Jun 2026 10:58:02 +0800 Subject: [PATCH 3/4] fix --- .../triton_kernel/gdn_decode_pack.py | 62 ------------------- 1 file changed, 62 deletions(-) diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py index 180fdcfab0..1c5f088ef1 100644 --- a/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_pack.py @@ -3,68 +3,6 @@ import triton.language as tl -@triton.jit -def _pack_gdn_decode_kernel( - mixed_qkv, - z_raw, - a_raw, - b_raw, - q_out, - k_out, - v_out, - z_out, - a_out, - b_out, - stride_m_b: tl.constexpr, - stride_m_d: tl.constexpr, - stride_z_b: tl.constexpr, - stride_z_h: tl.constexpr, - stride_z_d: tl.constexpr, - stride_a_b: tl.constexpr, - stride_a_d: tl.constexpr, - stride_b_b: tl.constexpr, - stride_b_d: tl.constexpr, - q_dim: tl.constexpr, - k_dim: tl.constexpr, - v_dim: tl.constexpr, - gate_dim: tl.constexpr, - BLOCK_QKV: tl.constexpr, - BLOCK_GATE: tl.constexpr, -): - row = tl.program_id(0) - qkv_offsets = tl.arange(0, BLOCK_QKV) - - q_mask = qkv_offsets < q_dim - q_vals = tl.load(mixed_qkv + row * stride_m_b + qkv_offsets * stride_m_d, mask=q_mask, other=0.0) - tl.store(q_out + row * q_dim + qkv_offsets, q_vals, mask=q_mask) - - k_mask = qkv_offsets < k_dim - k_vals = tl.load( - mixed_qkv + row * stride_m_b + (q_dim + qkv_offsets) * stride_m_d, - mask=k_mask, - other=0.0, - ) - tl.store(k_out + row * k_dim + qkv_offsets, k_vals, mask=k_mask) - - v_mask = qkv_offsets < v_dim - v_vals = tl.load( - mixed_qkv + row * stride_m_b + (q_dim + k_dim + qkv_offsets) * stride_m_d, - mask=v_mask, - other=0.0, - ) - tl.store(v_out + row * v_dim + qkv_offsets, v_vals, mask=v_mask) - - z_vals = tl.load(z_raw + row * stride_z_b + qkv_offsets, mask=v_mask, other=0.0) - tl.store(z_out + row * v_dim + qkv_offsets, z_vals, mask=v_mask) - - gate_offsets = tl.arange(0, BLOCK_GATE) - gate_mask = gate_offsets < gate_dim - a_vals = tl.load(a_raw + row * stride_a_b + gate_offsets * stride_a_d, mask=gate_mask, other=0.0) - b_vals = tl.load(b_raw + row * stride_b_b + gate_offsets * stride_b_d, mask=gate_mask, other=0.0) - tl.store(a_out + row * gate_dim + gate_offsets, a_vals, mask=gate_mask) - tl.store(b_out + row * gate_dim + gate_offsets, b_vals, mask=gate_mask) - - @triton.jit def _conv_pack_gdn_decode_kernel( mixed_qkv, From cd0e48d10afde47290b5e02cbe78f1f3aa2fa331 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Wed, 17 Jun 2026 17:24:50 +0800 Subject: [PATCH 4/4] feat: qwen + 257ep + topk=9 --- .../fused_moe/fused_moe_weight.py | 33 +++- .../meta_weights/fused_moe/impl/base_impl.py | 2 - .../fused_moe/impl/deepgemm_impl.py | 2 - .../fused_moe/impl/marlin_impl.py | 2 - .../fused_moe/impl/triton_impl.py | 83 ++++++--- .../fused_moe/grouped_fused_moe.py | 144 +++++++++++---- .../triton_kernel/fused_moe/moe_sum_reduce.py | 50 +----- .../layer_weights/transformer_layer_weight.py | 6 +- .../layer_infer/transformer_layer_infer.py | 44 ++--- .../layer_weights/transformer_layer_weight.py | 169 +++++++++--------- .../triton_kernel/shared_expert_gate.py | 58 ------ lightllm/server/api_cli.py | 7 +- lightllm/server/core/objs/start_args_type.py | 1 + 13 files changed, 312 insertions(+), 289 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index d9a77b39a5..f02e2294cc 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -134,8 +134,6 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -152,8 +150,35 @@ def experts( num_expert_group=num_expert_group, is_prefill=is_prefill, per_expert_scale=self.per_expert_scale, - shared_expert_out=shared_expert_out, - shared_expert_gate=shared_expert_gate, + ) + + def fused_shared_experts( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + shared_expert_weight: torch.Tensor, + is_prefill: Optional[bool] = None, + ) -> torch.Tensor: + return self.fuse_moe_impl.fused_shared_experts( + input_tensor=input_tensor, + router_logits=router_logits, + w13=self.w13, + w2=self.w2, + correction_bias=self.e_score_correction_bias, + scoring_func=self.scoring_func, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + is_prefill=is_prefill, + per_expert_scale=self.per_expert_scale, + shared_expert_weight=shared_expert_weight, ) def low_latency_dispatch( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index b54b03ee05..dd6f9a6880 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -63,7 +63,5 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index bc0e86d7eb..4d4614c007 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -76,8 +76,6 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, ): output = fused_experts( hidden_states=input_tensor, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 417d001c72..0094b09b1c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -30,8 +30,6 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, ): w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index fdda2b2139..2fa715ef65 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -62,27 +62,6 @@ def _select_experts( topk_weights.mul_(self.routed_scaling_factor) if per_expert_scale is not None: topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts, - end=self.n_routed_experts + self.num_fused_shared_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) return topk_weights, topk_ids def _fused_experts( @@ -94,8 +73,6 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: bool = False, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale @@ -113,8 +90,7 @@ def _fused_experts( use_fp8_w8a8=use_fp8_w8a8, w1_scale=w13_scale, w2_scale=w2_scale, - shared_expert_out=shared_expert_out, - shared_expert_gate=shared_expert_gate, + shared_expert_id=self.n_routed_experts if self.num_fused_shared_experts > 0 else -1, ) return input_tensor @@ -133,8 +109,6 @@ def __call__( num_expert_group: int, is_prefill: Optional[bool] = None, per_expert_scale: Optional[torch.Tensor] = None, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -156,7 +130,58 @@ def __call__( topk_ids=topk_ids, router_logits=router_logits, is_prefill=is_prefill, - shared_expert_out=shared_expert_out, - shared_expert_gate=shared_expert_gate, ) return output + + def fused_shared_experts( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + correction_bias: Optional[torch.Tensor], + scoring_func: str, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + shared_expert_weight: torch.Tensor, + is_prefill: Optional[bool] = None, + per_expert_scale: Optional[torch.Tensor] = None, + ): + assert ( + type(self) is FuseMoeTriton + ), "fused shared expert as MoE is only supported by the Triton fused MoE implementation" + topk_weights, topk_ids = self._select_experts( + input_tensor=input_tensor, + router_logits=router_logits, + correction_bias=correction_bias, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=scoring_func, + per_expert_scale=per_expert_scale, + ) + w13_weight, w13_scale = w13.weight, w13.weight_scale + w2_weight, w2_scale = w2.weight, w2.weight_scale + use_fp8_w8a8 = w13_weight.dtype == torch.float8_e4m3fn + + from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import fused_shared_experts + + fused_shared_experts( + hidden_states=input_tensor, + w1=w13_weight, + w2=w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + shared_expert_weight=shared_expert_weight, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w13_scale, + w2_scale=w2_scale, + shared_expert_id=self.n_routed_experts, + ) + return input_tensor diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index ecdf9f2f7b..c18606e31c 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -217,30 +217,51 @@ def moe_align1( def moe_align_fused_kernel( topk_ids_ptr, # [token_num, topk] topk_weights_ptr, # [token_num, topk] + shared_expert_weight_ptr, # [token_num, 1] expert_to_token_index_ptr, # [expert_num, token_num * topk] expert_to_weight_ptr, # [expert_num, token_num * topk] expert_token_num_ptr, # [expert_num] token_num, + routed_topk_num: tl.constexpr, expert_num: tl.constexpr, topk_num: tl.constexpr, + shared_expert_id: tl.constexpr, BLOCK_SIZE: tl.constexpr, ZERO_EXPERT_TOKEN_NUM: tl.constexpr, BLOCK_EXPERT: tl.constexpr, + HAS_SHARED_EXPERT_WEIGHT: tl.constexpr, ): token_block = tl.program_id(0) if ZERO_EXPERT_TOKEN_NUM: expert_offs = tl.arange(0, BLOCK_EXPERT) tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) tl.debug_barrier() + if shared_expert_id >= 0: + tl.store(expert_token_num_ptr + shared_expert_id, token_num, mask=token_block == 0) offs = token_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offs < token_num * topk_num - expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0) - weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0) + token_ids = offs // topk_num + topk_offsets = offs - token_ids * topk_num + routed_offsets = token_ids * routed_topk_num + topk_offsets + is_shared_expert = topk_offsets >= routed_topk_num + + expert_ids = tl.load(topk_ids_ptr + routed_offsets, mask=mask & (is_shared_expert == 0), other=0) + expert_ids = tl.where(is_shared_expert, shared_expert_id, expert_ids) + weights = tl.load(topk_weights_ptr + routed_offsets, mask=mask & (is_shared_expert == 0), other=0.0) + if HAS_SHARED_EXPERT_WEIGHT: + shared_weights = tl.load(shared_expert_weight_ptr + token_ids, mask=mask & is_shared_expert, other=0.0).to( + tl.float32 + ) + shared_weights = tl.sigmoid(shared_weights) + else: + shared_weights = tl.full((BLOCK_SIZE,), 1.0, dtype=tl.float32) + weights = tl.where(is_shared_expert, shared_weights, weights) - # 用 atomic_add 给 expert 分配写位置 - write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask) + # Shared expert appears exactly once per token, so its position is deterministic. + routed_write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask & (is_shared_expert == 0)) + write_pos = tl.where(is_shared_expert, token_ids, routed_write_pos) # 按 token 顺序写 index 和 weight tl.store( @@ -257,8 +278,11 @@ def moe_align_fused_kernel( def _get_moe_align_fused_static_key( topk_weights: torch.Tensor, + shared_expert_id: int = -1, ) -> dict: topk_num = topk_weights.shape[1] + if shared_expert_id >= 0: + topk_num += 1 return { "topk_num": topk_num, } @@ -283,29 +307,43 @@ def _get_moe_align_fused_configs(): mutates_args=["expert_to_token_index", "expert_to_weight", "expert_token_num"], ) def moe_align_fused( - expert_to_token_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None + expert_to_token_index, + expert_to_weight, + expert_token_num, + topk_ids, + topk_weights, + shared_expert_id: int = -1, + shared_expert_weight: Optional[torch.Tensor] = None, + run_config: Optional[dict] = None, ): - token_num, topk_num = topk_ids.shape + token_num, routed_topk_num = topk_ids.shape + topk_num = routed_topk_num + (1 if shared_expert_id >= 0 else 0) if run_config is None: run_config = {} BLOCK_SIZE = run_config.get("BLOCK_SIZE", 256) num_warps = run_config.get("num_warps", 4) expert_num = expert_token_num.shape[0] zero_expert_token_num = token_num * topk_num <= BLOCK_SIZE + if shared_expert_weight is not None: + shared_expert_weight = shared_expert_weight.view(token_num, 1) grid = (triton.cdiv(token_num * topk_num, BLOCK_SIZE),) moe_align_fused_kernel[grid]( topk_ids, topk_weights, + shared_expert_weight if shared_expert_weight is not None else topk_weights, expert_to_token_index, expert_to_weight, expert_token_num, token_num, + routed_topk_num, expert_num, topk_num, + shared_expert_id, BLOCK_SIZE=BLOCK_SIZE, ZERO_EXPERT_TOKEN_NUM=zero_expert_token_num, BLOCK_EXPERT=triton.next_power_of_2(expert_num), + HAS_SHARED_EXPERT_WEIGHT=shared_expert_weight is not None, num_warps=num_warps, ) return expert_to_token_index, expert_to_weight, expert_token_num @@ -924,8 +962,8 @@ def fused_experts_impl( layout="blocked", limit=None, alpha=None, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, + shared_expert_id: int = -1, + shared_expert_weight: Optional[torch.Tensor] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -937,7 +975,8 @@ def fused_experts_impl( num_tokens, _ = hidden_states.shape E, N, _ = w1.shape CHUNK_SIZE = FFN_MOE_CHUNK_SIZE - topk_num = topk_ids.shape[1] + routed_topk_num = topk_ids.shape[1] + topk_num = routed_topk_num + (1 if shared_expert_id >= 0 else 0) M = min(num_tokens, CHUNK_SIZE) intermediate_cache13_shared = alloc_tensor_func( @@ -969,6 +1008,9 @@ def fused_experts_impl( curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + curr_shared_expert_weight = ( + shared_expert_weight[begin_chunk_idx:end_chunk_idx] if shared_expert_weight is not None else None + ) expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda") expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda") @@ -984,10 +1026,12 @@ def fused_experts_impl( expert_token_num=expert_to_token_num, topk_ids=curr_topk_ids, topk_weights=curr_topk_weights, + shared_expert_id=shared_expert_id, + shared_expert_weight=curr_shared_expert_weight, ) reused_mblock_infos = grouped_matmul( - curr_topk_ids.numel(), + tokens_in_chunk * topk_num, curr_hidden_states, a1_scale, expert_to_token_num, @@ -1013,7 +1057,7 @@ def fused_experts_impl( ) grouped_matmul( - curr_topk_ids.numel(), + tokens_in_chunk * topk_num, intermediate_cache2.view(-1, N // 2), a2_scale, expert_to_token_num, @@ -1031,12 +1075,9 @@ def fused_experts_impl( bias=w2_bias, ) - has_shared_gate = shared_expert_out is not None moe_sum_reduce( intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx], - shared=None if not has_shared_gate else shared_expert_out[begin_chunk_idx:end_chunk_idx], - gate=None if not has_shared_gate else shared_expert_gate[begin_chunk_idx:end_chunk_idx], ) return out_hidden_states @@ -1059,8 +1100,7 @@ def inplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, + shared_expert_id: int = -1, ) -> None: fused_experts_impl( hidden_states, @@ -1080,8 +1120,7 @@ def inplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, - shared_expert_out=shared_expert_out, - shared_expert_gate=shared_expert_gate, + shared_expert_id=shared_expert_id, ) @@ -1103,8 +1142,7 @@ def inplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, + shared_expert_id: int = -1, ) -> None: pass @@ -1135,9 +1173,8 @@ def outplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, -) -> None: + shared_expert_id: int = -1, +) -> torch.Tensor: return fused_experts_impl( hidden_states, w1, @@ -1156,8 +1193,7 @@ def outplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, - shared_expert_out=shared_expert_out, - shared_expert_gate=shared_expert_gate, + shared_expert_id=shared_expert_id, ) @@ -1179,9 +1215,8 @@ def outplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, -) -> None: + shared_expert_id: int = -1, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1212,8 +1247,7 @@ def fused_experts( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, - shared_expert_out: Optional[torch.Tensor] = None, - shared_expert_gate: Optional[torch.Tensor] = None, + shared_expert_id: int = -1, ): if inplace: torch.ops.lightllm.inplace_fused_experts_impl( @@ -1233,8 +1267,7 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, - shared_expert_out=shared_expert_out, - shared_expert_gate=shared_expert_gate, + shared_expert_id=shared_expert_id, ) return hidden_states else: @@ -1255,6 +1288,49 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, - shared_expert_out=shared_expert_out, - shared_expert_gate=shared_expert_gate, + shared_expert_id=shared_expert_id, ) + + +def fused_shared_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_expert_weight: torch.Tensor, + inplace: bool = True, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + layout: str = "blocked", + alpha: Optional[float] = None, + limit: Optional[float] = None, + shared_expert_id: int = -1, +): + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + use_fp8_w8a8, + use_int8_w8a16, + w1_bias, + w2_bias, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + layout=layout, + alpha=alpha, + limit=limit, + shared_expert_id=shared_expert_id, + shared_expert_weight=shared_expert_weight, + ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py index 97cda5cb37..28221344b8 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py @@ -14,20 +14,12 @@ def _moe_sum_reduce_kernel( output_ptr, output_stride_0, output_stride_1, - shared_ptr, - shared_stride_0, - shared_stride_1, - gate_ptr, - gate_stride_0, - gate_stride_1, token_num: int, topk_num: int, hidden_dim: int, BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr, NUM_STAGE: tl.constexpr, - HAS_SHARED_GATE: tl.constexpr, - GATE_DIM: tl.constexpr, ): input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) @@ -50,37 +42,15 @@ def _moe_sum_reduce_kernel( for i in tl.range(0, topk_num, num_stages=NUM_STAGE): tmp = tl.load(input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0) accumulator += tmp - if HAS_SHARED_GATE: - shared = tl.load( - shared_ptr + token_index * shared_stride_0 + offs_dim * shared_stride_1, - mask=offs_dim < dim_end, - other=0.0, - ).to(tl.float32) - if GATE_DIM == 1: - gate = tl.load(gate_ptr + token_index * gate_stride_0).to(tl.float32) + tl.zeros( - (BLOCK_DIM,), dtype=tl.float32 - ) - else: - gate = tl.load( - gate_ptr + token_index * gate_stride_0 + offs_dim * gate_stride_1, - mask=offs_dim < dim_end, - other=0.0, - ).to(tl.float32) - gate = tl.sigmoid(gate) - accumulator += shared * gate store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim tl.store(store_t_ptr, accumulator.to(input_ptr.dtype.element_ty), mask=offs_dim < dim_end) -def _get_moe_sum_reduce_static_key( - input: torch.Tensor, output: torch.Tensor, shared: torch.Tensor = None, gate: torch.Tensor = None -): +def _get_moe_sum_reduce_static_key(input: torch.Tensor, output: torch.Tensor): return { "topk_num": input.shape[1], "hidden_dim": input.shape[2], "out_dtype": str(output.dtype), - "has_shared_gate": shared is not None, - "gate_dim": 0 if gate is None else gate.shape[-1], } @@ -101,20 +71,12 @@ def _get_moe_sum_reduce_configs(): run_key_func=lambda input: input.shape[0], mutates_args=["output"], ) -def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, shared=None, gate=None, run_config: Dict = None): +def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = None): assert input.is_contiguous() assert output.is_contiguous() token_num, topk_num, hidden_dim = input.shape assert output.shape[0] == token_num and output.shape[1] == hidden_dim - has_shared_gate = shared is not None - if has_shared_gate: - assert gate is not None - shared = shared.view(token_num, hidden_dim) - gate = gate.view(token_num, gate.shape[-1]) - assert shared.is_contiguous() - assert gate.stride(1) == 1 - assert gate.shape[1] in (1, hidden_dim) if not run_config: run_config = { @@ -139,20 +101,12 @@ def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, shared=None, gate= *input.stride(), output, *output.stride(), - shared if has_shared_gate else output, - shared.stride(0) if has_shared_gate else 0, - shared.stride(1) if has_shared_gate else 0, - gate if has_shared_gate else output, - gate.stride(0) if has_shared_gate else 0, - gate.stride(1) if has_shared_gate else 0, token_num=token_num, topk_num=topk_num, hidden_dim=hidden_dim, BLOCK_M=BLOCK_M, BLOCK_DIM=BLOCK_DIM, NUM_STAGE=NUM_STAGE, - HAS_SHARED_GATE=has_shared_gate, - GATE_DIM=gate.shape[1] if has_shared_gate else 0, num_warps=num_warps, ) return diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 3eb09f9176..c5e88cf2cf 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -37,9 +37,9 @@ def _parse_config(self): self.num_attention_heads = self.network_config_["num_attention_heads"] self.kv_lora_rank = self.network_config_["kv_lora_rank"] self.num_fused_shared_experts = 0 - if get_env_start_args().enable_fused_shared_experts and self.is_moe: - # enable_fused_shared_experts can only work with tensor parallelism - assert not get_env_start_args().enable_ep_moe, "enable_fused_shared_experts can only work with tp mode." + if not get_env_start_args().disable_fused_shared_experts and self.is_moe: + # fused shared experts can only work with tensor parallelism + assert not get_env_start_args().enable_ep_moe, "fused shared experts can only work with tp mode." self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) self.n_embed = self.network_config_["hidden_size"] self.n_inter = self.network_config_["intermediate_size"] diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 86c3c65ba0..b5ae5c394e 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -13,8 +13,7 @@ from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating from lightllm.models.qwen3next.triton_kernel.gdn_decode_pack import conv_pack_gdn_decode_inputs -from lightllm.models.qwen3next.triton_kernel.shared_expert_gate import add_shared_expert_gate_, sigmoid_mul_ -from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.models.qwen3next.triton_kernel.shared_expert_gate import sigmoid_mul_ from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule from lightllm.distributed import all_reduce @@ -116,28 +115,34 @@ def _compute_shared_expert( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): input = input.view(-1, self.embed_dim_) - if getattr(layer_weight, "fused_shared_expert_gate", False): - up_gate_and_gate = layer_weight.gate_up_proj.mm(input) - up_gate_dim = layer_weight.gate_up_proj.out_dims[0] + layer_weight.gate_up_proj.out_dims[1] - gate = up_gate_and_gate[:, up_gate_dim : up_gate_dim + 1] - up_gate_out = up_gate_and_gate[:, :up_gate_dim] - ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) - silu_and_mul_fwd(up_gate_out, ffn1_out) - shared_expert_out = layer_weight.down_proj.mm(ffn1_out) - else: - shared_expert_out = LlamaTransformerLayerInfer._ffn_tp(self, input, infer_state, layer_weight) - gate = layer_weight.ffn_gate.mm(input) - return shared_expert_out, gate + shared_expert_out = LlamaTransformerLayerInfer._ffn_tp(self, input, infer_state, layer_weight) + gate = layer_weight.ffn_gate.mm(input) + sigmoid_mul_(shared_expert_out, gate) + return shared_expert_out def _moe_ffn_tp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out, gate = self._compute_shared_expert(input, infer_state, layer_weight) - hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) + if getattr(layer_weight, "num_fused_shared_experts", 0) > 0: + shared_expert_gate = layer_weight.ffn_gate.mm(hidden_states) + layer_weight.experts.fused_shared_experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + shared_expert_weight=shared_expert_gate, + ) + hidden_states = hidden_states.view(num_tokens, hidden_dim) + return hidden_states + + shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) layer_weight.experts.experts( hidden_states, router_logits=router_logits, @@ -146,16 +151,15 @@ def _moe_ffn_tp( use_grouped_topk=False, topk_group=None, num_expert_group=None, - shared_expert_out=shared_expert_out, - shared_expert_gate=gate, ) hidden_states = hidden_states.view(num_tokens, hidden_dim) + hidden_states.add_(shared_expert_out) return hidden_states def _moe_ffn_edp( self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): - shared_expert_out, gate = self._compute_shared_expert(input, infer_state, layer_weight) + shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) hidden_states = input token_num, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -170,7 +174,7 @@ def _moe_ffn_edp( is_prefill=infer_state.is_prefill, ) ep_output = ep_output.view(token_num, hidden_dim) - add_shared_expert_gate_(ep_output, shared_expert_out, gate) + ep_output.add_(shared_expert_out) return ep_output def _get_qkv( diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index 42e18c7f02..cc3747111c 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -10,6 +10,7 @@ TpParameterWeight, QKVROWNMMWeight, QKGEMMANormWeight, + FusedMoeWeight, ) from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightTpl from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import get_row_slice_mixin @@ -90,56 +91,6 @@ def _get_tp_padded_head_num(self, head_num): ) -class SharedGateUpGateROWMMWeight(MMWeightTpl): - gate_pad_dim = 16 - - def __init__( - self, - in_dim, - inter_size, - weight_names, - data_type, - quant_method=None, - tp_rank=None, - tp_world_size=None, - ): - self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() - self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() - assert ( - inter_size % self.tp_world_size_ == 0 - ), f"inter_size must be divisible by tp_world_size_, found {inter_size} % {self.tp_world_size_}" - super().__init__( - in_dim=in_dim, - out_dims=[inter_size // self.tp_world_size_, inter_size // self.tp_world_size_, self.gate_pad_dim], - weight_names=weight_names, - bias_names=None, - data_type=data_type, - quant_method=quant_method, - tp_rank=self.tp_rank_, - tp_world_size=self.tp_world_size_, - ) - self.tp_param_slicer = get_row_slice_mixin( - self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ - ) - self.gate_param_slicer = get_row_slice_mixin("none", tp_rank=0, tp_world_size=1) - - def _get_param_slicer(self, sub_child_index): - if sub_child_index == 2: - return self.gate_param_slicer - return self.tp_param_slicer - - def load_hf_weights(self, weights): - for sub_child_index, param_name in enumerate(self.weight_names): - if sub_child_index != 2: - self._load_weight(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - continue - if param_name in weights: - weight_pack = self.mm_param_list[sub_child_index] - weight_pack.weight.zero_() - weight_pack.weight[:1].copy_(weights[param_name]) - weight_pack.load_ok[0] = True - - class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): num_full_attention_layers = network_config["full_attention_interval"] @@ -198,10 +149,54 @@ def _init_weight(self): self._init_norm() def _init_moe(self): - super()._init_moe() + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.num_fused_shared_experts = 1 if self._can_fuse_shared_expert_as_moe() else 0 + self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], + weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + if self.num_fused_shared_experts > 0: + self.ffn_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.num_fused_shared_experts], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + num_fused_shared_experts=self.num_fused_shared_experts, + layer_num=self.layer_num_, + network_config=self.network_config_, + ) self._init_gated_ffn() return + def _can_fuse_shared_expert_as_moe(self): + start_args = get_env_start_args() + if not self.is_moe or start_args.enable_ep_moe or start_args.disable_fused_shared_experts: + return False + return ( + self.network_config_.get("shared_expert_intermediate_size") == self.network_config_["moe_intermediate_size"] + ) + def _init_norm(self): hidden_size = self.network_config_["hidden_size"] self.att_norm_weight_ = NoTpGEMMANormWeight( @@ -226,6 +221,8 @@ def _init_gated_ffn(self): hidden_size = self.network_config_["hidden_size"] if "shared_expert_intermediate_size" not in self.network_config_: return + if getattr(self, "num_fused_shared_experts", 0) > 0: + return prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" inter_size = self.network_config_["shared_expert_intermediate_size"] if get_env_start_args().enable_ep_moe: @@ -248,29 +245,13 @@ def _init_gated_ffn(self): tp_world_size=1, ) else: - gate_up_quant = self.get_quant_method("gate_up_proj") - if gate_up_quant.method_name == "none": - self.gate_up_proj = SharedGateUpGateROWMMWeight( - in_dim=hidden_size, - inter_size=inter_size, - weight_names=[ - f"{prefix}.gate_proj.weight", - f"{prefix}.up_proj.weight", - f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", - ], - data_type=self.data_type_, - quant_method=gate_up_quant, - ) - self.fused_shared_expert_gate = True - else: - self.gate_up_proj = ROWMMWeight( - in_dim=hidden_size, - out_dims=[inter_size, inter_size], - weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], - data_type=self.data_type_, - quant_method=gate_up_quant, - ) - self.fused_shared_expert_gate = False + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("gate_up_proj"), + ) self.down_proj = COLMMWeight( in_dim=inter_size, out_dims=[hidden_size], @@ -279,17 +260,16 @@ def _init_gated_ffn(self): quant_method=self.get_quant_method("down_proj"), ) - if not getattr(self, "fused_shared_expert_gate", False): - self.ffn_gate = ROWMMWeight( - in_dim=hidden_size, - out_dims=[1], - weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", - data_type=self.data_type_, - bias_names=None, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) + self.ffn_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) def _split_q_with_gate(self, weights): if self._q_weight_name in weights: @@ -301,6 +281,24 @@ def _split_q_with_gate(self, weights): weights[self._q_weight_name] = _q_proj weights[self._o_gate_weight_name] = _gate_proj + def _rename_shared_expert_to_moe_expert(self, weights): + if getattr(self, "num_fused_shared_experts", 0) == 0: + return + old_prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + new_prefix = f"model.layers.{self.layer_num_}.mlp.experts.{self.n_routed_experts}" + suffixes = [ + self.experts.quant_method.weight_suffix, + self.experts.quant_method.weight_scale_suffix, + self.experts.quant_method.weight_zero_point_suffix, + ] + for proj_name in ("gate_proj", "up_proj", "down_proj"): + for suffix in suffixes: + if suffix is None: + continue + old_name = f"{old_prefix}.{proj_name}.{suffix}" + if old_name in weights: + weights[f"{new_prefix}.{proj_name}.{suffix}"] = weights[old_name] + def _parse_config(self): super()._parse_config() self.linear_num_v_heads = self.network_config_["linear_num_value_heads"] @@ -446,6 +444,7 @@ def _parse_linear_conv1d(self, weight): def load_hf_weights(self, weights): self._split_q_with_gate(weights) + self._rename_shared_expert_to_moe_expert(weights) if self.is_linear_attention_layer: self._preprocess_weight(weights) super().load_hf_weights(weights) diff --git a/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py index fd89edd2f1..e5b0e282ab 100644 --- a/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py +++ b/lightllm/models/qwen3next/triton_kernel/shared_expert_gate.py @@ -3,37 +3,6 @@ import triton.language as tl -@triton.jit -def _add_shared_expert_gate_kernel( - hidden, - shared, - gate, - stride_h_m: tl.constexpr, - stride_h_n: tl.constexpr, - stride_s_m: tl.constexpr, - stride_s_n: tl.constexpr, - stride_g_m: tl.constexpr, - stride_g_n: tl.constexpr, - N: tl.constexpr, - GATE_N: tl.constexpr, - BLOCK_N: tl.constexpr, -): - row = tl.program_id(0) - offs = tl.arange(0, BLOCK_N) - mask = offs < N - - hidden_ptrs = hidden + row * stride_h_m + offs * stride_h_n - shared_vals = tl.load(shared + row * stride_s_m + offs * stride_s_n, mask=mask, other=0.0).to(tl.float32) - if GATE_N == 1: - gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) - else: - gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) - hidden_vals = tl.load(hidden_ptrs, mask=mask, other=0.0).to(tl.float32) - gate_vals = tl.sigmoid(gate_vals) - out = hidden_vals + shared_vals * gate_vals - tl.store(hidden_ptrs, out.to(hidden.dtype.element_ty), mask=mask) - - @triton.jit def _sigmoid_mul_kernel( x, @@ -59,33 +28,6 @@ def _sigmoid_mul_kernel( tl.store(x_ptrs, (x_vals * gate_vals).to(x.dtype.element_ty), mask=mask) -@torch.no_grad() -def add_shared_expert_gate_(hidden: torch.Tensor, shared: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: - hidden_arg = hidden.view(-1, hidden.shape[-1]) - shared_arg = shared.view(-1, hidden.shape[-1]) - gate_arg = gate.view(-1, gate.shape[-1]) - assert hidden_arg.shape == shared_arg.shape - assert gate_arg.shape[0] == hidden_arg.shape[0] and gate_arg.shape[1] in (1, hidden_arg.shape[1]) - _, n = hidden_arg.shape - block_n = triton.next_power_of_2(n) - _add_shared_expert_gate_kernel[(hidden_arg.shape[0],)]( - hidden_arg, - shared_arg, - gate_arg, - hidden_arg.stride(0), - hidden_arg.stride(1), - shared_arg.stride(0), - shared_arg.stride(1), - gate_arg.stride(0), - gate_arg.stride(1), - n, - gate_arg.shape[1], - BLOCK_N=block_n, - num_warps=8, - ) - return hidden - - @torch.no_grad() def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: x_arg = x.view(-1, x.shape[-1]) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 1bdf8f3427..9de8993de8 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -681,9 +681,12 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""Whether to update the redundant expert for deepseekv3 model by online expert used counter.""", ) parser.add_argument( - "--enable_fused_shared_experts", + "--disable_fused_shared_experts", action="store_true", - help="""Whether to enable fused shared experts for deepseekv3 model. only work when tensor parallelism""", + help=( + "Disable fused shared experts for supported MoE models. " + "It is enabled by default and only works with tensor parallelism." + ), ) parser.add_argument( "--mtp_mode", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 40c8028158..46959827f4 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -157,6 +157,7 @@ class StartArgs: enable_ep_moe: bool = field(default=False) ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) + disable_fused_shared_experts: bool = field(default=False) mtp_mode: Optional[str] = field( default=None, metadata={