diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 72cd6dc7c48..727816a6c96 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -76,6 +76,8 @@ def _validate_split_kv_size(value: int) -> int:
     "FD_PD_CHANGEABLE": lambda: os.getenv("FD_PD_CHANGEABLE", "0"),
     # Whether to use DeepGemm for FP8 blockwise MoE.
     "FD_USE_DEEP_GEMM": lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "0"))),
+    # Whether to use the Blackwell (SM100) grouped GEMM backend for FP8 blockwise MoE.
+    "FD_USE_BLACKWELL_GEMM": lambda: bool(int(os.getenv("FD_USE_BLACKWELL_GEMM", "0"))),
     # Whether to use PFCCLab/DeepEP.
     "FD_USE_PFCC_DEEP_EP": lambda: bool(int(os.getenv("FD_USE_PFCC_DEEP_EP", "0"))),
     # Whether to use aggregate send.
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_blackwell_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_blackwell_backend.py
new file mode 100644
index 00000000000..bc1bee531fb
--- /dev/null
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_blackwell_backend.py
@@ -0,0 +1,1000 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+import threading
+from typing import Callable
+
+import paddle
+from paddle import nn
+from paddleformers.utils.log import logger
+
+import fastdeploy
+from fastdeploy.model_executor.layers.moe.ep import deep_ep
+from fastdeploy.model_executor.layers.quantization.fp8_utils import (
+    deep_gemm,
+    paddlefleet_ops,
+)
+from fastdeploy.model_executor.layers.utils import get_tensor
+from fastdeploy.model_executor.ops.gpu import (
+    count_tokens_per_expert_func,
+    depermute_prefill_combine,
+    prefill_permute_to_masked_gemm,
+)
+from fastdeploy.platforms import current_platform
+from fastdeploy.utils import register_custom_python_op
+from fastdeploy.worker.tbo import let_another_thread_run
+
+from .fused_moe_backend_base import MoEMethodBase
+from .fused_moe_triton_backend import BlockWiseFP8MoEMethod
+
+if current_platform.is_cuda():
+    try:
+        m_grouped_fp8_gemm_nt_contiguous = deep_gemm.m_grouped_fp8_gemm_nt_contiguous
+        m_grouped_fp8_gemm_nt_masked = deep_gemm.m_grouped_fp8_gemm_nt_masked
+    except AttributeError:  # fall back to the older DeepGEMM API names
+        m_grouped_fp8_gemm_nt_contiguous = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
+        m_grouped_fp8_gemm_nt_masked = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked
+else:
+    m_grouped_fp8_gemm_nt_contiguous = None
+    m_grouped_fp8_gemm_nt_masked = None
+
+from ..utils import get_sm_version
+
+if current_platform.is_cuda() and get_sm_version() >= 100:
+    try:
+        import blackwell_ops
+        from blackwell_ops import group_gemm_masked
+    except ImportError:  # blackwell_ops is optional; masked grouped GEMM is unavailable without it
+        group_gemm_masked = None
+
+global_values = {}
+
+
+def call_prefill_permute_to_masked_gemm(
+    x: paddle.Tensor,
+    scale: paddle.Tensor,
+    topk_ids: paddle.Tensor,
+    num_local_experts: int,
+    max_token_num: int,
+):
+    """
+    Permute input tokens and scales from token-major to expert-major layout
+    for MoE masked GEMM operations.
+
+    Args:
+        x: Input hidden states [num_tokens, hidden].
+        scale: Input scales [num_tokens, hidden_scale].
+ topk_ids: Expert routing indices [num_tokens, topk] (int64 or int32). + num_local_experts: Number of local experts on this device. + max_token_num: Maximum tokens per expert buffer. + + Returns: + tuple: (permute_x, permute_scale, permuted_indice_map, token_nums_per_expert) + """ + if topk_ids.dtype != paddle.int64: + topk_ids = topk_ids.cast(paddle.int64) + + results = prefill_permute_to_masked_gemm(x, scale, topk_ids, num_local_experts, max_token_num) + + return results[0], results[1], results[2], results[3] + + +def call_depermute_prefill_combine( + x: paddle.Tensor, + indice_map: paddle.Tensor, + topk_weights: paddle.Tensor, + num_worst_tokens: int, +): + """ + Depermute and combine expert outputs back to token-major layout. + + Args: + x: Expert outputs [num_local_experts, max_tokens_per_expert, hidden]. + indice_map: Flat index tensor [num_worst_tokens, topk] (int32). + topk_weights: Combination weights [num_worst_tokens, topk] (float32). + num_worst_tokens: Number of output tokens to produce. + + Returns: + depermuted_x: Combined output [num_worst_tokens, hidden]. + """ + results = depermute_prefill_combine(x, indice_map, topk_weights, num_worst_tokens) + + return results + + +def m_grouped_fp8_gemm_nt_contiguous_custom_python_op_infermeta( + permute_input: "paddle.static.MetaTensor", + permute_scale: "paddle.static.MetaTensor", + layer_added_weight_attrs_0: "paddle.static.MetaTensor", + layer_added_scale_attrs_0: "paddle.static.MetaTensor", + m_indices: "paddle.static.MetaTensor", + layer_added_weight_attrs_1: "paddle.static.MetaTensor", + layer_added_scale_attrs_1: "paddle.static.MetaTensor", + quant_config_weight_block_size_0: int, +): + return paddle.static.MetaTensor( + shape=[permute_input.shape[0], layer_added_weight_attrs_1.shape[1]], dtype=paddle.bfloat16 + ) + + +@register_custom_python_op( + name="m_grouped_fp8_gemm_nt_contiguous_custom", + infer_meta=m_grouped_fp8_gemm_nt_contiguous_custom_python_op_infermeta, + input_names=[ + "permute_input", + "permute_scale", + "layer_added_weight_attrs_0", + "layer_added_scale_attrs_0", + "m_indices", + "layer_added_weight_attrs_1", + "layer_added_scale_attrs_1", + ], + output_names=["ffn_new_out"], + inplace_map={}, +) +def m_grouped_fp8_gemm_nt_contiguous_custom_python_op( + permute_input: paddle.Tensor, + permute_scale: paddle.Tensor, + layer_added_weight_attrs_0: paddle.Tensor, # getattr(layer, self.added_weight_attrs[0]) + layer_added_scale_attrs_0: paddle.Tensor, # getattr(layer, self.added_scale_attrs[0]) + m_indices: paddle.Tensor, + layer_added_weight_attrs_1: paddle.Tensor, # getattr(layer, self.added_weight_attrs[1]) + layer_added_scale_attrs_1: paddle.Tensor, # getattr(layer, self.added_scale_attrs[1]) + quant_config_weight_block_size_0: int, # self.quant_config.weight_block_size[0] + disable_ue8m0_cast: bool, + dst_weights: paddle.Tensor, +): + + # up_gate_proj + ffn_out = paddle.empty( + (permute_input.shape[0], layer_added_weight_attrs_0.shape[1]), + dtype=paddle.bfloat16, + ) + # if disable_ue8m0_cast: + if permute_scale.strides[0] != 1: + permute_scale = permute_scale.transpose([1, 0]).contiguous() + permute_scale = permute_scale.transpose([1, 0]) + # disable_ue8m0_cast is False for SM100 + m_grouped_fp8_gemm_nt_contiguous( + (permute_input, permute_scale), + (layer_added_weight_attrs_0, layer_added_scale_attrs_0), + ffn_out, + m_indices, + ) + + # swiglu + if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE: + ffn_in_x, ffn_in_x_scale_tensor = paddlefleet_ops.fuse_weighted_swiglu_fp8_quant( + ffn_out, dst_weights, 
using_pow2_scaling=True, use_ue8m0=not disable_ue8m0_cast + ) + + ffn_in_x_scale_tensor = paddle.transpose(paddle.transpose(ffn_in_x_scale_tensor, [1, 0]).contiguous(), [1, 0]) + else: + ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out) + + # down_proj + if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT: + ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant( + ffn_out, quant_config_weight_block_size_0, not disable_ue8m0_cast + ) + + ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous() + ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]) + else: + ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( + ffn_out, + using_pow2_scale=not disable_ue8m0_cast, + using_ue8m0_scale=not disable_ue8m0_cast, + ) + ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]] + + ffn_out = paddle.empty( + (permute_input.shape[0], layer_added_weight_attrs_1.shape[1]), + dtype=paddle.bfloat16, + ) + # disable_ue8m0_cast is False for SM100 + m_grouped_fp8_gemm_nt_contiguous( + (ffn_in_x, ffn_in_x_scale_tensor), + (layer_added_weight_attrs_1, layer_added_scale_attrs_1), + ffn_out, + m_indices, + ) + return ffn_out + + +def moe_topk_select( + gating_output: paddle.Tensor, + n_group: int, + topk_group: int, + top_k: int, + routed_scaling_factor: float, + e_score_correction_bias: paddle.Tensor, + renormalize: bool = False, +): + """ + Topk selection using paddle PHI topk API. + + Args: + gating_output: gate output logits, shape [seq_len, n_experts] + n_group: number of expert groups + topk_group: number of top-k groups to select + top_k: number of top experts per token + routed_scaling_factor: scaling factor for routed experts + e_score_correction_bias: bias for expert selection + renormalize: whether to renormalize topk probabilities + + Returns: + topk_weights: normalized topk probabilities, shape [seq_len, top_k] + topk_ids: topk expert indices, shape [seq_len, top_k] + """ + # compute gate probs via sigmoid + gate_probs = paddle.nn.functional.sigmoid(gating_output) + # probs_for_choice includes correction bias for topk selection + probs_for_choice = gate_probs + e_score_correction_bias if e_score_correction_bias is not None else gate_probs + # group-based topk selection + n_group = n_group if n_group > 0 else 1 + topk_group = topk_group if topk_group > 0 else 1 + if n_group > 1 and topk_group < n_group: + seq_length, n_experts = probs_for_choice.shape + group_scores = ( + probs_for_choice.reshape([seq_length, n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) + ) # [seq_len, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [seq_len, topk_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis( + group_idx, paddle.to_tensor(1.0, dtype=group_scores.dtype), axis=-1 + ) + score_mask = ( + group_mask.unsqueeze(-1).expand([seq_length, n_group, n_experts // n_group]).reshape([seq_length, -1]) + ) # [seq_len, n_experts] + probs_for_choice = probs_for_choice.masked_fill(~score_mask.astype(paddle.bool), float("-inf")) + + _, topk_ids = paddle.topk(probs_for_choice, top_k, axis=-1) + topk_weights = paddle.take_along_axis(gate_probs, topk_ids, axis=-1) + + # normalize combine weights + if renormalize: + topk_weights = topk_weights / paddle.clip(topk_weights.sum(-1, keepdim=True), min=1e-12) + + # apply routed scaling factor + if routed_scaling_factor: + topk_weights = topk_weights * routed_scaling_factor + + return topk_weights, topk_ids + + +class 
BlackwellGemmFusedMoeMethod(MoEMethodBase):
+    """
+    BlackwellGemmFusedMoeMethod implements the MoEMethodBase interface with Blackwell (SM100) FP8 grouped GEMM kernels.
+    """
+
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
+        """
+        Create blockwise FP8 MoE weight parameters (delegates to BlockWiseFP8MoEMethod).
+        """
+        BlockWiseFP8MoEMethod.create_weights(self, layer, **extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer):
+        """Post-process MoE weights after loading (delegates to BlockWiseFP8MoEMethod)."""
+        BlockWiseFP8MoEMethod.process_weights_after_loading(self, layer)
+
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
+        """
+        Quantize loaded BF16 expert weights to blockwise FP8 and copy them into the layer parameters.
+        """
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
+
+        self.check(layer, up_gate_proj_weights, down_proj_weights)
+
+        for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
+            weight_name = self.added_weight_attrs[idx]
+            scale_name = self.added_scale_attrs[idx]
+
+            weight_list = []
+            weight_scale_list = []
+            for i in range(layer.num_local_experts):
+                from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
+
+                quant_weight, scale = per_block_cast_to_fp8(weight_tensor[i], self.quant_config.weight_block_size)
+
+                weight_list.append(quant_weight)
+                weight_scale_list.append(scale)
+            quanted_weight = paddle.stack(weight_list, axis=0)
+            quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
+            getattr(layer, weight_name).copy_(quanted_weight, False)
+
+            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
+            quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
+            getattr(layer, scale_name).set_value(quanted_weight_scale)
+
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
+        """
+        Load pre-quantized FP8 expert weights and their block scales.
+ """ + up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None) + down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None) + up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None) + down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None) + + up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight( + state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange + ) + # self.check(layer, up_gate_proj_weights, down_proj_weights) + up_gate_proj_weight_scale = [] + down_proj_weight_scale = [] + + if isinstance(state_dict, list): + state_dict = dict(state_dict) + + for expert_idx in logical_expert_ids: + up_gate_proj_expert_weight_scale_key_name = up_gate_proj_expert_weight_scale_key.format(expert_idx) + down_proj_expert_weight_scale_key_name = down_proj_expert_weight_scale_key.format(expert_idx) + + up_gate_proj_weight_scale.append( + get_tensor( + ( + state_dict.pop(up_gate_proj_expert_weight_scale_key_name) + if up_gate_proj_expert_weight_scale_key_name in state_dict + else up_gate_proj_expert_weight_scale_key_name + ), + layer.fd_config.model_config.model, + ) + ) + down_proj_weight_scale.append( + get_tensor( + ( + state_dict.pop(down_proj_expert_weight_scale_key_name) + if down_proj_expert_weight_scale_key_name in state_dict + else down_proj_expert_weight_scale_key_name + ), + layer.fd_config.model_config.model, + ) + ) + + if not self.quant_config.deepgemm_scale_ue8m0: + up_gate_proj_weight = ( + paddle.stack(up_gate_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn") + ) + down_proj_weight = ( + paddle.stack(down_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn") + ) + up_gate_proj_weight_scale = ( + paddle.stack(up_gate_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous() + ) + down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous() + else: + up_gate_proj_weight = ( + paddle.stack(up_gate_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn") + ) + down_proj_weight = ( + paddle.stack(down_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn") + ) + up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).transpose([0, 2, 1]) + down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).transpose([0, 2, 1]) + + name_tensor_map = { + "up_gate_proj_weight": up_gate_proj_weight, + "down_proj_weight": down_proj_weight, + "up_gate_proj_weight_scale_inv": up_gate_proj_weight_scale, + "down_proj_weight_scale_inv": down_proj_weight_scale, + } + for name, tensor in name_tensor_map.items(): + getattr(layer, name).data = tensor + + def apply_ep_prefill( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + topk_ids_hookfunc: Callable = None, + shared_experts: nn.Layer = None, + ) -> paddle.Tensor: + """ + Apply the EP prefill method. + """ + gate_out = gate(x) + gate_out = gate_out.cast("float32") + + hidden_size = x.shape[1] + + # 1. 
Select topk experts and weights
+        if (
+            fastdeploy.envs.FD_USE_PHI_MOE_TOPK
+            and layer.redundant_table_manger is None
+            and layer.topk_method == "noaux_tc"
+        ):
+            topk_weights, topk_idx = moe_topk_select(
+                gate_out,
+                layer.n_group,
+                layer.topk_group,
+                layer.top_k,
+                layer.routed_scaling_factor,
+                layer.gate_correction_bias,
+                getattr(layer, "renormalize", True),
+            )
+        else:
+            topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
+
+        if topk_ids_hookfunc is not None:
+            topk_ids_hookfunc(topk_ids=topk_idx)
+
+        # 2. Dynamically compute blockwise quantization scales
+        if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
+            x_fp8, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
+                x, self.quant_config.weight_block_size[0], self.quant_config.deepgemm_scale_ue8m0
+            )
+        else:
+            x_fp8, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                x,
+                using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
+                output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
+                using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
+            )
+            x_scale_tensor = (
+                x_scale_tensor[: x.shape[0]]
+                if not self.quant_config.deepgemm_scale_ue8m0
+                else x_scale_tensor.T[: x.shape[0]]
+            )
+
+        event = deep_ep.Buffer.capture()
+
+        if self.ep_prefill_runner.num_worst_tokens <= 0:
+            let_another_thread_run()
+        # 3. EP Dispatch
+        (
+            recv_x,
+            recv_topk_idx,
+            recv_topk_weights,
+            recv_num_tokens_per_expert_list,
+            handle,
+            event,
+        ) = self.ep_prefill_runner.dispatch(
+            x_fp8, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128, previous_event=event
+        )
+
+        if self.ep_prefill_runner.num_worst_tokens > 0:
+            let_another_thread_run()
+
+        thread_name = threading.current_thread().name
+
+        if self.ep_prefill_runner.ep_engine.async_finish:
+            event.current_stream_wait()
+
+        global global_values
+
+        if thread_name not in global_values:
+            global_values[thread_name] = {}
+
+        (recv_x_value, recv_x_scale) = recv_x
+
+        global_values[thread_name]["x"] = x
+        global_values[thread_name]["topk_idx"] = topk_idx
+        global_values[thread_name]["topk_weights"] = topk_weights
+        global_values[thread_name]["x_scale_tensor"] = x_scale_tensor
+
+        global_values[thread_name]["recv_x_value"] = recv_x_value
+        global_values[thread_name]["recv_x_scale"] = recv_x_scale
+        global_values[thread_name]["recv_topk_idx"] = recv_topk_idx
+        global_values[thread_name]["recv_topk_weights"] = recv_topk_weights
+        global_values[thread_name]["handle"] = handle
+        global_values[thread_name]["recv_num_tokens_per_expert_list"] = recv_num_tokens_per_expert_list
+
+        token_all_num = sum(recv_num_tokens_per_expert_list)
+
+        # Note(ZKK): the code below contains many explicit `del` statements.
+        # MoE prefill reaches peak GPU memory here, so each intermediate
+        # tensor is freed manually as soon as it is no longer needed.
+
+        # 4. 
Compute ffn + if self.ep_prefill_runner.num_worst_tokens > 0: + token_split_factor = 2 if int(os.getenv("USE_TBO", "0")) == 1 else 1 + max_tokens_per_rank = ( + layer.fd_config.scheduler_config.max_num_batched_tokens + // layer.fd_config.parallel_config.tensor_parallel_size + // token_split_factor + ) + + logger.debug(f"max_tokens_per_rank {max_tokens_per_rank}") + + permute_input, permute_scale, permuted_indice_map, token_nums_per_expert = ( + call_prefill_permute_to_masked_gemm( + x=recv_x_value, + scale=recv_x_scale, + topk_ids=recv_topk_idx, + num_local_experts=layer.num_local_experts, + max_token_num=layer.ep_size * max_tokens_per_rank, + ) + ) + + up_gate_proj_out = paddle.zeros( + [ + layer.num_local_experts, + layer.ep_size * max_tokens_per_rank, + layer.moe_intermediate_size * 2, + ], + dtype=paddle.bfloat16, + ) + permute_input = permute_input.reshape([-1, permute_input.shape[-1]]) + + permute_scale_new = blackwell_ops.unpack_and_convert_scale(permute_scale, token_nums_per_expert) + + # masked group gemm + # a: [num_local_experts * expected_m, k] + # b: [num_local_experts, n, k] + # sfa: [num_local_experts * expected_m, k // 32] + # sfb: [num_local_experts, n, k // 32] + # masked_m: [num_local_experts] + # out: [num_local_experts * expected_m, n] + # bias: [num_local_experts, n] Optional + # max_m_per_expert: int + # sm_cout: -1 + group_gemm_masked( + permute_input, + getattr(layer, self.added_weight_attrs[0]), + permute_scale_new, + getattr(layer, self.added_scale_attrs[0] + "_bw"), # weight_scale_new + token_nums_per_expert, + up_gate_proj_out, + None, + layer.ep_size * max_tokens_per_rank, + -1, + ) + + act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.fused_mask_swiglu_fp8_quant( + up_gate_proj_out, + token_nums_per_expert, + self.quant_config.weight_block_size[0], + use_ue8m0=self.quant_config.deepgemm_scale_ue8m0, + ) + + if layer.hidden_size == layer.moe_intermediate_size * 2: + ffn_out = up_gate_proj_out + else: + ffn_out = paddle.zeros( + [ + layer.num_local_experts, + layer.ep_size * max_tokens_per_rank, + layer.hidden_size, + ], + dtype=paddle.bfloat16, + ) + + act_out_fp8 = act_out_fp8.reshape([-1, act_out_fp8.shape[-1]]) + + act_out_fp8_scale = blackwell_ops.unpack_and_convert_scale(scale, token_nums_per_expert) + + group_gemm_masked( + act_out_fp8, + getattr(layer, self.added_weight_attrs[1]), + act_out_fp8_scale, + getattr(layer, self.added_scale_attrs[1] + "_bw"), # weight_scale_new + token_nums_per_expert, + ffn_out, + None, + layer.ep_size * max_tokens_per_rank, + -1, + ) + + tmp_ffn_out = call_depermute_prefill_combine( + x=ffn_out, + indice_map=permuted_indice_map, + topk_weights=recv_topk_weights, + num_worst_tokens=recv_x_value.shape[0], + ) + + elif token_all_num > 0: + logger.debug(f"token_all_num {token_all_num}") + + if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE: + recv_topk_idx = recv_topk_idx.astype(paddle.int32) + ( + permute_input, + permute_indices_per_token, # == zipped_expertwise_rowmap + dst_weights, + permute_scale, + m_indices, + ) = paddle.nn.functional.moe_permute( + hidden_states=recv_x_value, + scale=recv_x_scale, + expert_routemap_topk=recv_topk_idx, + expert_prob_topk=recv_topk_weights, + num_experts=layer.num_local_experts, + tokens_per_expert=[], + padding_alignment=128, + return_expert_indices=True, + override_buffer_size=token_all_num, + using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, + ) + else: + token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts) + ( + permute_input, + 
permute_scale, + permute_indices_per_token, + recv_num_tokens_per_expert_list_cumsum, + recv_num_tokens_per_expert_list_padded_cumsum, + dst_weights, + dst_indices, + cumsum_idx_gpu, + m_indices, + ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8( + recv_x_value, + recv_x_scale, + recv_topk_idx, + recv_topk_weights, + token_nums_this_rank[0], + token_nums_this_rank[1], + True, # use_in_ep + token_all_num, + ) + + assert permute_input.shape[0] == token_all_num + + if permute_scale.strides[0] != 1: + permute_scale = permute_scale.transpose([1, 0]).contiguous().transpose([1, 0]) + + # up_gate_proj + ffn_out = paddle.empty( + (token_all_num, getattr(layer, self.added_weight_attrs[0]).shape[1]), + dtype=paddle.bfloat16, + ) + m_grouped_fp8_gemm_nt_contiguous( + (permute_input, permute_scale), + (getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])), + ffn_out, + m_indices, + ) + + if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE: + ffn_in_x, ffn_in_x_scale_tensor = paddlefleet_ops.fuse_weighted_swiglu_fp8_quant( + ffn_out, dst_weights, using_pow2_scaling=True, use_ue8m0=self.quant_config.deepgemm_scale_ue8m0 + ) + + ffn_in_x_scale_tensor = paddle.transpose( + paddle.transpose(ffn_in_x_scale_tensor, [1, 0]).contiguous(), [1, 0] + ) + else: + # swiglu + ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None) + + # down_proj + if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT: + ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant( + ffn_out, self.quant_config.weight_block_size[0], self.quant_config.deepgemm_scale_ue8m0 + ) + ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous().transpose([1, 0]) + else: + ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( + ffn_out, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, + ) + ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]] + + ffn_out = paddle.empty( + (token_all_num, getattr(layer, self.added_weight_attrs[1]).shape[1]), + dtype=paddle.bfloat16, + ) + m_grouped_fp8_gemm_nt_contiguous( + (ffn_in_x, ffn_in_x_scale_tensor), + (getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])), + ffn_out, + m_indices, + ) + if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE: + tmp_ffn_out, out_probs = paddle.nn.functional.moe_unpermute( + hidden_states_unzipped=ffn_out, + zipped_expertwise_rowmap=permute_indices_per_token, + expert_routemap_topk=recv_topk_idx, + token_prob_unzipped=dst_weights, + total_zipped_tokens=recv_x_value.shape[0], + num_experts=layer.num_local_experts, + using_weighted_combine=not fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE, + ) + + else: + # prmt back per rank + tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine( + ffn_out, + dst_weights, + permute_indices_per_token, + dst_indices, + None, # down_proj_bias + False, # norm_topk_prob + 1.0, + ) + else: + tmp_ffn_out = paddle.empty([0, hidden_size], paddle.bfloat16) + + if shared_experts is not None: + s_x = shared_experts(x) + + # 5. 
EP combine + event = deep_ep.Buffer.capture() + if self.ep_prefill_runner.num_worst_tokens <= 0: + let_another_thread_run() + + global_values[thread_name]["combine_in"] = tmp_ffn_out + tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights, event) + + if self.ep_prefill_runner.num_worst_tokens > 0: + let_another_thread_run() + + if self.ep_prefill_runner.ep_engine.async_finish: + event.current_stream_wait() + + global_values[thread_name]["combine_out"] = tmp_ffn_out + if shared_experts is not None: + tmp_ffn_out += s_x + + return tmp_ffn_out + + def apply_ep_decode( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + topk_ids_hookfunc: Callable = None, + shared_experts: nn.Layer = None, + ) -> paddle.Tensor: + """ + Apply the EP decoder method. + """ + gate_out = gate(x) + gate_out = gate_out.cast("float32") + # 1. Select topk experts and weights + topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + + # 2. EP Dispatch + permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch( + x, topk_idx, topk_weights, use_fp8=True, use_ue8m0=self.quant_config.deepgemm_scale_ue8m0 + ) + # 3. Compute ffn + assert isinstance(permute_input, tuple) + up_gate_proj_out = paddle.zeros( + [ + layer.num_local_experts, + layer.ep_size * layer.fd_config.model_config.num_max_dispatch_tokens_per_rank, + layer.moe_intermediate_size * 2, + ], + dtype=paddle.bfloat16, + ) + + ffn_out = paddle.empty( + [ + layer.num_local_experts, + layer.ep_size * layer.fd_config.model_config.num_max_dispatch_tokens_per_rank, + layer.hidden_size, + ], + dtype=paddle.bfloat16, + ) + + expected_m = 128 + + # group_gemm_masked( + # permute_input[0], + # getattr(layer, self.added_weight_attrs[0]), + # permute_input[1], + # getattr(layer, self.added_scale_attrs[0]), + # up_gate_proj_out, + # token_nums_per_expert, + # None, + # layer.ep_size * max_tokens_per_rank, + # -1, + # ) + + # disable_ue8m0_cast is False for SM100 + m_grouped_fp8_gemm_nt_masked( + permute_input, + ( + getattr(layer, self.added_weight_attrs[0]), + getattr(layer, self.added_scale_attrs[0]), + ), + up_gate_proj_out, + token_nums_per_expert, + expected_m, + ) + + act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.fused_mask_swiglu_fp8_quant( + up_gate_proj_out, + token_nums_per_expert, + self.quant_config.weight_block_size[0], + use_ue8m0=self.quant_config.deepgemm_scale_ue8m0, + ) + + # group_gemm_masked( + # act_out_fp8, + # getattr(layer, self.added_weight_attrs[1]), + # scale + # getattr(layer, self.added_scale_attrs[1]), + # ffn_out, + # token_nums_per_expert, + # None, + # layer.ep_size * max_tokens_per_rank, + # -1, + # ) + # disable_ue8m0_cast is False for SM100 + m_grouped_fp8_gemm_nt_masked( + (act_out_fp8, scale), + ( + getattr(layer, self.added_weight_attrs[1]), + getattr(layer, self.added_scale_attrs[1]), + ), + ffn_out, + token_nums_per_expert, + expected_m, + ) + + if shared_experts is not None: + s_x = shared_experts(x) + + # 4. EP combine + out = self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle) + + if shared_experts is not None: + out += s_x + return out + + def apply_tp( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + topk_ids_hookfunc: Callable = None, + ) -> paddle.Tensor: + """ + Paddle Use DeepGemm compute Fused MoE. + below is TP compute method. 
+ """ + gate_out = gate(x) + gate_out = gate_out.cast("float32") + + if layer.topk_method == "noaux_tc": + + if not fastdeploy.envs.FD_USE_PHI_MOE_TOPK: + _, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + ) + else: + topk_weights, topk_ids = moe_topk_select( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + ) + + else: + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, # apply_norm_weight + False, + ) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + + if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT: + recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant( + x, 128, self.quant_config.deepgemm_scale_ue8m0 + ) + else: + recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0, + using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, + ) + recv_x_scale = ( + recv_x_scale[: recv_x.shape[0]] + if not self.quant_config.deepgemm_scale_ue8m0 + else recv_x_scale.T[: recv_x.shape[0]] + ) + + if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE: + topk_ids = topk_ids.astype(paddle.int32) + override_buffer_size = recv_x.shape[0] * layer.top_k + layer.num_experts * (128 - 1) + ( + permute_input, + permute_indices_per_token, # == zipped_expertwise_rowmap + dst_weights, + permute_scale, + m_indices, + ) = paddle.nn.functional.moe_permute( + hidden_states=recv_x, + scale=recv_x_scale, + expert_routemap_topk=topk_ids, + expert_prob_topk=topk_weights, + num_experts=layer.num_experts, + tokens_per_expert=[], + padding_alignment=128, + return_expert_indices=True, + override_buffer_size=override_buffer_size, + using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, + ) + else: + tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts) + ( + permute_input, + permute_scale, + permute_indices_per_token, + recv_num_tokens_per_expert_list_cumsum, + recv_num_tokens_per_expert_list_padded_cumsum, + dst_weights, + dst_indices, + cumsum_idx_gpu, + m_indices, + ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8( + recv_x, + recv_x_scale, + topk_ids, + topk_weights, + tmp[0], + tmp[1], + False, # use_in_ep + -1, + ) + + ffn_out = m_grouped_fp8_gemm_nt_contiguous_custom_python_op( + permute_input, + permute_scale, + getattr(layer, self.added_weight_attrs[0]), + getattr(layer, self.added_scale_attrs[0]), + m_indices, + getattr(layer, self.added_weight_attrs[1]), + getattr(layer, self.added_scale_attrs[1]), + self.quant_config.weight_block_size[0], + disable_ue8m0_cast=not self.quant_config.deepgemm_scale_ue8m0, + dst_weights=dst_weights if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE else None, + ) + + # prmt back per rank + if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE: + tmp_ffn_out, out_probs = paddle.nn.functional.moe_unpermute( + hidden_states_unzipped=ffn_out, + zipped_expertwise_rowmap=permute_indices_per_token, + expert_routemap_topk=topk_ids, + token_prob_unzipped=dst_weights, + total_zipped_tokens=recv_x.shape[0], + num_experts=layer.num_experts, + using_weighted_combine=not fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE, + ) + 
else: + tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine( + ffn_out, + dst_weights, + permute_indices_per_token, + dst_indices, + None, + False, # norm_topk_prob + 1.0, + ) + return tmp_ffn_out diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index d1db43a3241..ed89e3ed3bb 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -20,6 +20,7 @@ from paddle import nn import fastdeploy +from fastdeploy import envs from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import ( TensorTracker, @@ -1672,6 +1673,7 @@ def _process_quantize(weight_idx): weight[expert_id].copy_(w_q, False) scale_list.append(s_ue8m0) scale = paddle.to_tensor(scale_list) + scale = scale.transpose([0, 2, 1]).contiguous().transpose([0, 2, 1]) free_tensor(getattr(layer, unquantized_weight_name)) free_tensor(getattr(layer, weight_name)) @@ -1700,7 +1702,22 @@ def _process_quantize(weight_idx): else: getattr(layer, weight_name).copy_(weight, False) scale_param = getattr(layer, scale_name) - scale_param.data = scale.transpose([0, 2, 1]).contiguous().transpose([0, 2, 1]) + scale_param.data = scale + if bool(envs.FD_USE_BLACKWELL_GEMM): + import blackwell_ops + + scale_bw = blackwell_ops.unpack_and_convert_scale(scale, None) + scale_bw_name = scale_name + "_bw" + setattr( + layer, + scale_bw_name, + layer.create_parameter( + shape=scale_bw.shape, + dtype=scale_bw.dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + getattr(layer, scale_bw_name).copy_(scale_bw, False) if self.quant_config.is_checkpoint_bf16: # dynamic quantize diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py index 007cc0fddd2..de4ea3458ea 100644 --- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py @@ -66,6 +66,7 @@ def __init__(self, weight_block_size: list = [-1, -1], is_checkpoint_bf16: bool self.quant_min_bound = -448 self.quant_round_type = 1 self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM) + self.use_blackwell_gemm = bool(envs.FD_USE_BLACKWELL_GEMM) self.is_checkpoint_bf16 = is_checkpoint_bf16 self.deepgemm_scale_ue8m0 = True if get_sm_version() == 100 else False @@ -83,7 +84,16 @@ def get_quant_method(self, layer) -> Optional[QuantMethodBase]: Get quantization method. 
""" if isinstance(layer, FusedMoE): - if layer.ep_size > 1 or self.use_deep_gemm: + if self.use_blackwell_gemm: + assert ( + self.use_deep_gemm + ), "Blackwell gemm is supported only for prefill moe, please set FD_USE_DEEP_GEMM=1 as well" + from fastdeploy.model_executor.layers.moe.fused_moe_blackwell_backend import ( + BlackwellGemmFusedMoeMethod, + ) + + return BlackwellGemmFusedMoeMethod(self) + elif layer.ep_size > 1 or self.use_deep_gemm: from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import ( DeepGemmFusedMoeMethod, ) diff --git a/fastdeploy/model_executor/layers/quantization/fp8_utils.py b/fastdeploy/model_executor/layers/quantization/fp8_utils.py index 65d30d4004d..bdeda8640e5 100644 --- a/fastdeploy/model_executor/layers/quantization/fp8_utils.py +++ b/fastdeploy/model_executor/layers/quantization/fp8_utils.py @@ -26,6 +26,8 @@ if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import per_token_group_fp8_quant +import numpy as np + from ..utils import get_sm_version @@ -259,3 +261,82 @@ def fused_stack_transpose_quant(expert_weight_list, use_ue8m0=False): raise RuntimeError("'fuse_stack_transpose_fp8_quant' is not available in the current paddlefleet_ops.") return w, scale + + +def reorder_sf_to_cutlass(sf_logical, mn_dim, kb_dim): + """ + 将逻辑布局的 scale factor (UE8M0 uint8) 重排为 CUTLASS SfAtom 交错布局。 + + CUTLASS Sm1xxBlockScaledConfig 要求 scale factor 按照 SfAtom 的特殊交错模式排列: + SfAtom shape: ((32, 4), (SFVecSize, 4)) + SfAtom stride: ((16, 4), (0, 1)) + 即每 128 行的 MN 元素被分成 (32, 4) 的子块,并与 K 维度的 4 个 scale 交错存储。 + + 参数: + sf_logical: 逻辑布局的 scale factor tensor, shape (..., mn_dim, kb_dim), dtype=uint8 + mn_dim: M 或 N 维度大小 + kb_dim: K // block_size (scale factor 的 K 维度) + 返回: + 重排后的 scale factor tensor, 相同 shape 和 dtype + """ + sf_np = sf_logical.numpy() + orig_shape = sf_np.shape + flat = sf_np.reshape(-1, mn_dim, kb_dim) + + # 向量化计算: 构建所有 (n, kb) 索引到 (n_phys, kb_phys) 的映射 + n_idx = np.arange(mn_dim) + kb_idx = np.arange(kb_dim) + n_grid, kb_grid = np.meshgrid(n_idx, kb_idx, indexing="ij") # (mn_dim, kb_dim) + + n_tile = n_grid // 128 + n_local = n_grid % 128 + mn_i = n_local % 32 + mn_j = n_local // 32 + k_tile = kb_grid // 4 + sf_l = kb_grid % 4 + num_k_tiles = kb_dim // 4 + + cutlass_byte = (n_tile * num_k_tiles + k_tile) * 512 + mn_i * 16 + mn_j * 4 + sf_l + n_phys = cutlass_byte // kb_dim + kb_phys = cutlass_byte % kb_dim + + # 用高级索引一次性完成重排 + result = np.empty_like(flat) + result[:, n_phys, kb_phys] = flat[:, n_grid, kb_grid] + + return paddle.to_tensor(result.reshape(orig_shape), dtype=paddle.uint8) + + +def unpack_and_convert_scale(x, masked_m=None): + """ + 方案4:使用 tile 实现重复 + """ + if hasattr(x, "numpy"): + x_np = x.numpy() + else: + x_np = np.array(x) + + e, m, n = x_np.shape + + # 确保数组是连续的 + if not x_np.flags["C_CONTIGUOUS"]: + x_np = np.ascontiguousarray(x_np) + + # 将 int32 数组视为 uint8 数组(直接内存重解释) + x_uint8 = x_np.view(np.uint8).reshape(e, m, n, 4) + + # 使用 repeat 重复每个值4次 + # 先增加一个维度,然后 repeat + x_expanded = x_uint8[:, :, :, :, np.newaxis] # [e, m, n, 4, 1] + x_repeated = np.repeat(x_expanded, 4, axis=-1) # [e, m, n, 4, 4] + + # 重新排列维度 + x_reshaped = x_repeated.reshape(e, m, n, 16) + + # 合并最后两维 + output_np = x_reshaped.reshape(e, m, n * 16) + + output = paddle.to_tensor(output_np, dtype=paddle.uint8) + output = reorder_sf_to_cutlass(output, m, n * 16) + + return output