Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,35 @@ def experts(
per_expert_scale=self.per_expert_scale,
)

def fused_shared_experts(
    self,
    input_tensor: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool,
    topk_group: int,
    num_expert_group: int,
    shared_expert_weight: torch.Tensor,
    is_prefill: Optional[bool] = None,
) -> torch.Tensor:
    """Delegate a fused shared-expert MoE call to ``self.fuse_moe_impl``.

    The caller-supplied arguments are forwarded unchanged by keyword. The
    layer contributes its own state on top of them: ``self.w13``,
    ``self.w2``, ``self.e_score_correction_bias`` (passed as
    ``correction_bias``), ``self.scoring_func`` and
    ``self.per_expert_scale``.

    Args:
        input_tensor: Forwarded as ``input_tensor``.
        router_logits: Forwarded as ``router_logits``.
        top_k: Forwarded as ``top_k``.
        renormalize: Forwarded as ``renormalize``.
        use_grouped_topk: Forwarded as ``use_grouped_topk``.
        topk_group: Forwarded as ``topk_group``.
        num_expert_group: Forwarded as ``num_expert_group``.
        shared_expert_weight: Forwarded as ``shared_expert_weight``.
        is_prefill: Forwarded as ``is_prefill``; defaults to ``None``.

    Returns:
        Whatever ``self.fuse_moe_impl.fused_shared_experts`` returns.
    """
    backend_call = self.fuse_moe_impl.fused_shared_experts

    # State owned by this layer rather than by the caller.
    layer_state = {
        "w13": self.w13,
        "w2": self.w2,
        "correction_bias": self.e_score_correction_bias,
        "scoring_func": self.scoring_func,
        "per_expert_scale": self.per_expert_scale,
    }

    # Routing configuration passed straight through from the caller.
    routing_options = {
        "top_k": top_k,
        "renormalize": renormalize,
        "use_grouped_topk": use_grouped_topk,
        "topk_group": topk_group,
        "num_expert_group": num_expert_group,
    }

    return backend_call(
        input_tensor=input_tensor,
        router_logits=router_logits,
        shared_expert_weight=shared_expert_weight,
        is_prefill=is_prefill,
        **layer_state,
        **routing_options,
    )

def low_latency_dispatch(
self,
hidden_states: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,6 @@ def _select_experts(
topk_weights.mul_(self.routed_scaling_factor)
if per_expert_scale is not None:
topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype)
if self.num_fused_shared_experts > 0:
pad_topk_ids = (
torch.arange(
start=self.n_routed_experts,
end=self.n_routed_experts + self.num_fused_shared_experts,
step=1,
dtype=topk_ids.dtype,
device="cuda",
)
.view(1, self.num_fused_shared_experts)
.repeat(topk_ids.shape[0], 1)
)
pad_topk_weights = torch.full(
(topk_weights.shape[0], self.num_fused_shared_experts),
fill_value=1.0,
device="cuda",
dtype=topk_weights.dtype,
)

topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1)
return topk_weights, topk_ids

def _fused_experts(
Expand Down Expand Up @@ -111,6 +90,7 @@ def _fused_experts(
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w13_scale,
w2_scale=w2_scale,
shared_expert_id=self.n_routed_experts if self.num_fused_shared_experts > 0 else -1,
)
return input_tensor

Expand Down Expand Up @@ -152,3 +132,56 @@ def __call__(
is_prefill=is_prefill,
)
return output

def fused_shared_experts(
    self,
    input_tensor: torch.Tensor,
    router_logits: torch.Tensor,
    w13: WeightPack,
    w2: WeightPack,
    correction_bias: Optional[torch.Tensor],
    scoring_func: str,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool,
    topk_group: int,
    num_expert_group: int,
    shared_expert_weight: torch.Tensor,
    is_prefill: Optional[bool] = None,
    per_expert_scale: Optional[torch.Tensor] = None,
):
    """Route tokens, then run the Triton ``fused_shared_experts`` kernel.

    Steps:
      1. Check that ``self`` is exactly ``FuseMoeTriton``.
      2. Obtain ``(topk_weights, topk_ids)`` from ``self._select_experts``.
      3. Call the Triton ``fused_shared_experts`` kernel with
         ``inplace=True`` and ``shared_expert_id=self.n_routed_experts``.

    ``is_prefill`` is accepted for signature compatibility but is not read
    by this method.

    Returns:
        ``input_tensor`` — the same object that was passed in. The kernel is
        invoked with ``inplace=True``, so it presumably holds the result;
        confirm against the kernel implementation.

    Raises:
        AssertionError: if ``type(self)`` is not ``FuseMoeTriton``.
    """
    assert type(self) is FuseMoeTriton, (
        "fused shared expert as MoE is only supported by the Triton fused MoE implementation"
    )

    routing_weights, routing_ids = self._select_experts(
        input_tensor=input_tensor,
        router_logits=router_logits,
        correction_bias=correction_bias,
        top_k=top_k,
        renormalize=renormalize,
        use_grouped_topk=use_grouped_topk,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        scoring_func=scoring_func,
        per_expert_scale=per_expert_scale,
    )

    up_gate_weight = w13.weight
    up_gate_scale = w13.weight_scale
    down_weight = w2.weight
    down_scale = w2.weight_scale

    # The fp8 path is chosen purely from the dtype of the w13 weight.
    is_fp8 = up_gate_weight.dtype == torch.float8_e4m3fn

    # Imported at call time, as in the original implementation.
    from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import (
        fused_shared_experts as shared_experts_kernel,
    )

    shared_experts_kernel(
        hidden_states=input_tensor,
        w1=up_gate_weight,
        w2=down_weight,
        topk_weights=routing_weights,
        topk_ids=routing_ids,
        shared_expert_weight=shared_expert_weight,
        inplace=True,
        use_fp8_w8a8=is_fp8,
        w1_scale=up_gate_scale,
        w2_scale=down_scale,
        shared_expert_id=self.n_routed_experts,
    )
    return input_tensor
Loading
Loading