From 41f3947092033eaa00dcecce62a4a7b903d44bf2 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Fri, 8 May 2026 17:34:09 +0800 Subject: [PATCH 01/13] feat: deep_ep v2 and upgrade cuda 13.0 --- docker/Dockerfile | 53 +-- docker/scripts/build.sh | 14 +- .../layer_infer/cache_tensor_manager.py | 5 +- .../fused_moe/fused_moe_weight.py | 27 ++ .../fused_moe/impl/deepgemm_impl.py | 154 ++++++-- .../fused_moe/grouped_fused_moe_ep.py | 75 ++-- lightllm/common/quantization/deepgemm.py | 72 ++++ lightllm/distributed/communication_op.py | 93 +++-- .../layer_infer/transformer_layer_infer.py | 15 +- lightllm/models/deepseek2/model.py | 7 +- lightllm/models/gemma4/model.py | 7 +- lightllm/models/glm4_moe_lite/model.py | 7 +- .../layer_infer/transformer_layer_infer.py | 15 +- lightllm/models/qwen3_moe/model.py | 7 +- lightllm/models/qwen3next/model.py | 7 - lightllm/server/api_start.py | 6 +- lightllm/utils/device_utils.py | 5 + lightllm/utils/dist_check_utils.py | 4 +- lightllm/utils/envs_utils.py | 17 +- requirements.txt | 20 +- test/benchmark/service/benchmark_client.py | 9 +- test/benchmark/service/benchmark_multiturn.py | 347 +++++++++++++----- test/benchmark/service/benchmark_qps.py | 9 +- 23 files changed, 724 insertions(+), 251 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 439ecddb34..bba404c965 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,14 +1,17 @@ -ARG CUDA_VERSION=12.8.0 +ARG CUDA_VERSION=13.0.0 FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 ARG PYTHON_VERSION=3.10 ARG MAMBA_VERSION=24.7.1-0 -ARG VLLM_VERSION=0.16.0 +ARG VLLM_VERSION=0.21.0 +ARG NIXL_REF=v1.1.0 ARG FLASH_MLA_REF=47c35a7 +ARG DEEPGEMM_REF=891d57b4db1071624b5c8fa0d1e51cb317fa709f ARG TARGETPLATFORM ARG ENABLE_DEEPEP=1 ARG ENABLE_NIXL=1 ARG ENABLE_CACHE=1 +ARG ENABLE_SM100=0 ENV PATH=/opt/conda/bin:$PATH \ CONDA_PREFIX=/opt/conda @@ -44,13 +47,18 @@ WORKDIR /root COPY ./requirements.txt /lightllm/requirements.txt RUN pip install -U pip -RUN pip install -r /lightllm/requirements.txt --no-cache-dir -RUN pip install --no-cache-dir vllm==${VLLM_VERSION} -RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \ +RUN pip install --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cu130 \ + vllm==${VLLM_VERSION} +RUN pip install -r /lightllm/requirements.txt --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cu130 +RUN export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:/usr/local/cuda/targets/x86_64-linux/include${CPATH:+:${CPATH}} && \ + git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \ cd /root/FlashMLA && \ git checkout ${FLASH_MLA_REF} && \ git submodule update --init --recursive && \ - FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir . + FLASH_MLA_DISABLE_SM100="$(if [ "${ENABLE_SM100}" = "1" ]; then echo 0; else echo 1; fi)" \ + pip install --no-cache-dir . RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* @@ -78,27 +86,20 @@ RUN if [ "${ENABLE_NIXL}" = "1" ] || [ "${ENABLE_DEEPEP}" = "1" ]; then \ RUN if [ "${ENABLE_DEEPEP}" = "1" ]; then \ set -e; \ ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so; \ - NVSHMEM_VERSION=3.3.9; \ - CUDA_ARCHS=90; \ - wget https://developer.download.nvidia.com/compute/redist/nvshmem/${NVSHMEM_VERSION}/source/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \ - && tar -xf nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz && mv nvshmem_src nvshmem \ - && cd nvshmem \ - && rm -f /root/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \ - && NVSHMEM_SHMEM_SUPPORT=0 \ - NVSHMEM_UCX_SUPPORT=0 \ - NVSHMEM_USE_NCCL=0 \ - NVSHMEM_MPI_SUPPORT=0 \ - NVSHMEM_IBGDA_SUPPORT=1 \ - NVSHMEM_PMIX_SUPPORT=0 \ - NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ - NVSHMEM_USE_GDRCOPY=1 \ - cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS} \ - && cmake --build build --target install -j64; \ - DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58; \ - cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd ..; \ - cd /root/DeepEP && NVSHMEM_DIR=/root/nvshmem/install python setup.py install; \ + python -m pip install --upgrade --no-deps \ + "nvidia-nccl-cu13==2.30.4" \ + "nvidia-nvshmem-cu13==3.6.5"; \ + cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout b306af06afd412c88e51e71802951606e40b7358; \ + ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so; \ + ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so.2 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so; \ + pip install --no-build-isolation .; \ fi +RUN cd /root && git clone https://github.com/deepseek-ai/DeepGEMM.git && \ + cd DeepGEMM && git checkout ${DEEPGEMM_REF} && \ + git submodule update --init --recursive && \ + pip install --no-build-isolation . + RUN if [ "${ENABLE_NIXL}" = "1" ]; then \ apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \ DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \ @@ -126,7 +127,7 @@ RUN if [ "${ENABLE_NIXL}" = "1" ]; then \ apt-get update && apt-get install -y pkg-config tmux net-tools && \ cd /usr/local/src; \ pip install --upgrade meson pybind11 patchelf; \ - git clone https://github.com/ai-dynamo/nixl.git -b main && \ + git clone https://github.com/ai-dynamo/nixl.git -b ${NIXL_REF} && \ cd nixl && \ rm -rf build && \ mkdir build && \ diff --git a/docker/scripts/build.sh b/docker/scripts/build.sh index 355d6c65b3..bc1fd73da3 100644 --- a/docker/scripts/build.sh +++ b/docker/scripts/build.sh @@ -18,21 +18,23 @@ set -euo pipefail # --no-nixl Disable NIXL (default: enabled) # --no-cache Disable cache (default: enabled) # --lite Disable DEEPEP, NIXL and cache in one shot -# --cuda-version CUDA version (default: 12.8.0) +# --cuda-version CUDA version (default: 13.0.0) # --image-prefix Image prefix (default: lightllm) # --image-tag Image tag (default: generated from enabled features) +# --enable-sm100 Enable SM100 support (default: disabled) # -h / --help Show help ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" cd "${ROOT_DIR}" IMAGE_PREFIX="${IMAGE_PREFIX:-lightllm}" -CUDA_VERSION="${CUDA_VERSION:-12.8.0}" +CUDA_VERSION="${CUDA_VERSION:-13.0.0}" IMAGE_TAG="${IMAGE_TAG:-}" ENABLE_DEEPEP="${ENABLE_DEEPEP:-1}" ENABLE_NIXL="${ENABLE_NIXL:-1}" ENABLE_CACHE="${ENABLE_CACHE:-1}" +ENABLE_SM100="${ENABLE_SM100:-0}" print_help() { sed -n '1,80p' "$0" | sed 's/^# \{0,1\}//' @@ -43,6 +45,7 @@ while [[ $# -gt 0 ]]; do --no-deepep) ENABLE_DEEPEP=0 ;; --no-nixl) ENABLE_NIXL=0 ;; --no-cache) ENABLE_CACHE=0 ;; + --enable-sm100) ENABLE_SM100=1 ;; --lite) ENABLE_DEEPEP=0 ENABLE_NIXL=0 @@ -78,13 +81,16 @@ done # - Other combos: composed from enabled feature names if [[ -z "${IMAGE_TAG}" ]]; then tag_parts=() + if [[ "${ENABLE_SM100}" -eq 1 ]]; then + tag_parts+=("sm100") + fi if [[ "${ENABLE_NIXL}" -eq 1 ]]; then tag_parts+=("nixl") fi if [[ "${ENABLE_DEEPEP}" -eq 1 ]]; then tag_parts+=("deepep") fi - if [[ "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then + if [[ "${ENABLE_SM100}" -eq 0 && "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then IMAGE_TAG="cuda${CUDA_VERSION}" else prefix="" @@ -100,6 +106,6 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile \ --build-arg ENABLE_DEEPEP="${ENABLE_DEEPEP}" \ --build-arg ENABLE_NIXL="${ENABLE_NIXL}" \ --build-arg ENABLE_CACHE="${ENABLE_CACHE}" \ + --build-arg ENABLE_SM100="${ENABLE_SM100}" \ --progress=plain \ -t "${IMAGE_PREFIX}:${IMAGE_TAG}" . - diff --git a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py index 7889e8090e..8bcf99b992 100644 --- a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py +++ b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py @@ -33,6 +33,7 @@ class BufNode: inner_tensor: torch.Tensor shape_key: Tuple[int, torch.dtype] storage_weak_ptr: int + free_use_count_bias: int = 0 shape_to_tensor: Dict[Union[torch.Size, Iterable[int]], torch.Tensor] = field(default_factory=dict) def __del__(self): @@ -99,7 +100,8 @@ def alloc_tensor( # 回收可能消亡的 tensor for ptr in self.changed_ptr: t_buf_node = self.ptr_to_bufnode[ptr] - if self.use_count(ptr) == 1 + len(t_buf_node.shape_to_tensor): + free_use_count = t_buf_node.free_use_count_bias + 1 + len(t_buf_node.shape_to_tensor) + if self.use_count(ptr) <= free_use_count: self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node) self.changed_ptr.clear() @@ -131,6 +133,7 @@ def alloc_tensor( self.ptr_to_bufnode[storage_weak_ptr] = buf_node if shape not in buf_node.shape_to_tensor: buf_node.shape_to_tensor[shape] = buf_node.inner_tensor.view(shape) + buf_node.free_use_count_bias = self.use_count(storage_weak_ptr) - (1 + len(buf_node.shape_to_tensor)) mark_tensor = buf_node.shape_to_tensor[shape] ans = mark_tensor.data # 返回一个新的引用, 否则引用计数会无法判断 ans.storage_weak_ptr = buf_node.storage_weak_ptr diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index fca9b80fcf..375725d124 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -11,6 +11,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.impl import select_fuse_moe_impl from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args +from lightllm.utils.device_utils import is_sm100_gpu from lightllm.utils.dist_utils import get_global_world_size, get_global_rank from lightllm.utils.log_utils import init_logger @@ -52,6 +53,7 @@ def __init__( self.quant_method = quant_method assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." self.enable_ep_moe = get_env_start_args().enable_ep_moe + self.quant_method = self._maybe_upgrade_quant_method_for_ep_moe(self.quant_method) self.n_routed_experts = n_routed_experts self.num_fused_shared_experts = num_fused_shared_experts self._init_config(network_config) @@ -70,6 +72,28 @@ def __init__( self.lock = threading.Lock() self._create_weight() + def _maybe_upgrade_quant_method_for_ep_moe(self, quant_method: QuantizationMethod) -> QuantizationMethod: + if not self.enable_ep_moe: + return quant_method + + target_method = "deepgemm-fp8fp4-b32" if is_sm100_gpu() else "deepgemm-fp8w8a8-b128" + if quant_method.method_name == "none": + from lightllm.common.quantization.registry import QUANTMETHODS + + logger.info( + f"enable_ep_moe requires DeepGEMM MoE expert weights; " + f"auto-upgrading fused_moe quantization from `none` to `{target_method}`." + ) + quant_method = QUANTMETHODS.get(target_method) + + if quant_method.method_name != target_method: + raise ValueError( + f"enable_ep_moe currently requires `{target_method}` for fused_moe on this GPU, " + f"but got `{quant_method.method_name}`." + ) + + return quant_method + def _init_config(self, network_config: Dict[str, Any]): self.n_group = network_config.get("n_group", 0) self.use_grouped_topk = self.n_group > 0 @@ -152,6 +176,9 @@ def experts( per_expert_scale=self.per_expert_scale, ) + def use_sm100_mega_moe(self) -> bool: + return bool(getattr(self.fuse_moe_impl, "_use_sm100_fp4_moe", lambda: False)()) + def low_latency_dispatch( self, hidden_states: torch.Tensor, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index c9b8cfa3eb..2adc4343e2 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -4,11 +4,14 @@ from lightllm.distributed import dist_group_manager from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.common.quantization.quantize_method import WeightPack -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.utils.envs_utils import ( + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, +) from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( fused_experts_impl, masked_group_gemm, - _deepgemm_grouped_fp8_nt_contiguous, + deepgemm_grouped_fp8_nt_contiguous, ) from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, @@ -17,9 +20,84 @@ from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair +from lightllm.utils.device_utils import is_sm100_gpu class FuseMoeDeepGEMM(FuseMoeTriton): + def _get_ep_num_sms(self) -> int: + return getattr(dist_group_manager, "ep_num_sms", None) or 0 + + def _use_sm100_fp4_moe(self) -> bool: + return is_sm100_gpu() and self.quant_method.method_name == "deepgemm-fp8fp4-b32" + + def _get_mega_moe_weights(self, w13: WeightPack, w2: WeightPack): + cache_key = ( + w13.weight.data_ptr(), + w13.weight_scale.data_ptr(), + w2.weight.data_ptr(), + w2.weight_scale.data_ptr(), + ) + if getattr(self, "_mega_moe_weight_cache_key", None) != cache_key: + import deep_gemm + + self._mega_moe_weight_cache = deep_gemm.transform_weights_for_mega_moe( + (w13.weight, w13.weight_scale), + (w2.weight, w2.weight_scale), + ) + self._mega_moe_weight_cache_key = cache_key + return self._mega_moe_weight_cache + + def _get_mega_moe_stats(self, num_local_experts: int, device: torch.device): + stats = getattr(self, "_mega_moe_stats", None) + if stats is None or stats.numel() != num_local_experts or stats.device != device: + stats = torch.zeros((num_local_experts,), device=device, dtype=torch.int32) + self._mega_moe_stats = stats + return stats + + def _mega_moe( + self, + hidden_states: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor: + import deep_gemm + from deep_gemm.utils import per_token_cast_to_fp8 + + buffer = getattr(dist_group_manager, "ep_mega_moe_buffer", None) + if buffer is None: + raise RuntimeError("SM100 Mega MoE requires dist_group_manager.ep_mega_moe_buffer to be initialized") + + num_tokens = hidden_states.shape[0] + if num_tokens > buffer.num_max_tokens_per_rank: + raise RuntimeError( + f"Mega MoE got {num_tokens} tokens, exceeding num_max_tokens_per_rank={buffer.num_max_tokens_per_rank}" + ) + + qinput_tensor = per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=self.quant_method.block_size, + use_packed_ue8m0=True, + ) + l1_weights, l2_weights = self._get_mega_moe_weights(w13, w2) + cumulative_stats = self._get_mega_moe_stats(w13.weight.shape[0], hidden_states.device) + buffer.x[:num_tokens].copy_(qinput_tensor[0]) + buffer.x_sf[:num_tokens].copy_(qinput_tensor[1]) + buffer.topk_idx[:num_tokens].copy_(topk_ids) + buffer.topk_weights[:num_tokens].copy_(topk_weights) + + output = torch.empty_like(hidden_states) + deep_gemm.fp8_fp4_mega_moe( + output, + l1_weights, + l2_weights, + buffer, + cumulative_local_expert_recv_stats=cumulative_stats, + ) + return output + def _select_experts( self, input_tensor: torch.Tensor, @@ -74,7 +152,11 @@ def _fused_experts( ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale + if self._use_sm100_fp4_moe(): + return self._mega_moe(input_tensor, w13, w2, topk_weights, topk_ids.to(torch.long)) + use_fp8_w8a8 = self.quant_method.method_name != "none" + buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer output = fused_experts_impl( hidden_states=input_tensor, w1=w13_weight, @@ -82,7 +164,7 @@ def _fused_experts( topk_weights=topk_weights, topk_idx=topk_ids.to(torch.long), num_experts=self.total_expert_num_contain_redundancy, # number of all experts contain redundancy - buffer=dist_group_manager.ep_buffer, + buffer=buffer, is_prefill=is_prefill, use_fp8_w8a8=use_fp8_w8a8, use_fp8_all2all=use_fp8_w8a8, @@ -118,13 +200,13 @@ def low_latency_dispatch( ) topk_idx = topk_idx.to(torch.long) - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() use_fp8_w8a8 = self.quant_method.method_name != "none" - recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch( - hidden_states, - topk_idx, - num_max_dispatch_tokens_per_rank, - self.total_expert_num_contain_redundancy, + recv_x, masked_m, handle, event, hook = dist_group_manager.ep_low_latency_buffer.low_latency_dispatch( + topk_idx=topk_idx, + x=hidden_states, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + num_experts=self.total_expert_num_contain_redundancy, use_fp8=use_fp8_w8a8, async_finish=False, return_recv_hook=True, @@ -156,6 +238,17 @@ def select_experts_and_quant_input( scoring_func=scoring_func, ) w13_weight, w13_scale = w13.weight, w13.weight_scale + if self._use_sm100_fp4_moe(): + from deep_gemm.utils import per_token_cast_to_fp8 + + qinput_tensor = per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=self.quant_method.block_size, + use_packed_ue8m0=True, + ) + return topk_weights, topk_idx.to(torch.long), qinput_tensor + block_size_k = 0 if w13_weight.ndim == 3: block_size_k = w13_weight.shape[2] // w13_scale.shape[2] @@ -171,38 +264,26 @@ def dispatch( overlap_event: Optional[Any] = None, ): buffer = dist_group_manager.ep_buffer - # get_dispatch_layout - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - topk_idx, - self.total_expert_num_contain_redundancy, - previous_event=overlap_event, - async_finish=True, - allocate_on_comm_stream=True, - ) - recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( + num_max_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill() + recv_x, recv_topk_idx, recv_topk_weights, handle, event = buffer.dispatch( qinput_tensor, topk_idx=topk_idx, topk_weights=topk_weights, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=True, - allocate_on_comm_stream=True, + num_experts=self.total_expert_num_contain_redundancy, + num_max_tokens_per_rank=num_max_tokens_per_rank, expert_alignment=128, + num_sms=self._get_ep_num_sms(), + previous_event=overlap_event, + async_with_compute_stream=True, + allocate_on_comm_stream=True, + do_cpu_sync=True, + do_handle_copy=False, ) def hook(): event.current_stream_wait() - return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook + return recv_x, recv_topk_idx, recv_topk_weights, handle.num_recv_tokens_per_expert_list, handle, hook def masked_group_gemm( self, @@ -281,7 +362,7 @@ def prefilled_group_gemm( # groupgemm (contiguous layout) gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype) - _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices) + deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + qaunt # TODO fused kernel @@ -295,7 +376,7 @@ def prefilled_group_gemm( # groupgemm (contiguous layout) gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype) - _deepgemm_grouped_fp8_nt_contiguous( + deepgemm_grouped_fp8_nt_contiguous( (qsilu_out, qsilu_out_scale), (w2_weight, w2_scale), gemm_out_b, m_indices ) # gather and local reduce @@ -319,7 +400,7 @@ def low_latency_combine( topk_weights: torch.Tensor, handle: Any, ): - combined_x, event_overlap, hook = dist_group_manager.ep_buffer.low_latency_combine( + combined_x, event_overlap, hook = dist_group_manager.ep_low_latency_buffer.low_latency_combine( gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True ) return combined_x, hook @@ -335,8 +416,9 @@ def combine( gemm_out_b, handle, topk_weights=None, - async_finish=True, + num_sms=self._get_ep_num_sms(), previous_event=overlap_event, + async_with_compute_stream=True, allocate_on_comm_stream=True, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 2c6d013bd5..77705b1755 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -1,10 +1,7 @@ """Fused MoE kernel.""" -import os import torch import triton -import triton.language as tl from typing import Any, Callable, Dict, Optional, Tuple -import torch.distributed as dist from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_mix_quant_ep import ( @@ -15,9 +12,11 @@ tma_align_input_scale, ) from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.utils.envs_utils import ( + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, +) from lightllm.common.triton_utils.autotuner import Autotuner -import numpy as np logger = init_logger(__name__) @@ -66,14 +65,14 @@ def fused_experts_impl( topk_weights: torch.Tensor, # [M, topk] topk_idx: torch.Tensor, # [M, topk] num_experts: int, - buffer: "Buffer", + buffer: Any, is_prefill: bool, use_fp8_w8a8: bool = False, use_fp8_all2all: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, - previous_event: Optional["EventOverlap"] = None, + previous_event: Optional[Any] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -99,39 +98,27 @@ def fused_experts_impl( combined_x = None if is_prefill: qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w1.dtype) - - # get_dispatch_layout - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - topk_idx, num_experts, previous_event=previous_event, async_finish=False, allocate_on_comm_stream=False - ) - + allocate_on_comm_stream = previous_event is not None # normal dispatch # recv_x [recive_num_tokens, hidden] recv_x_scale [recive_num_tokens, hidden // block_size] # recv_topk_idx [recive_num_tokens, topk_num] # recv_topk_weights [recive_num_tokens, topk_num] # num_recv_tokens_per_expert_list list [cur_node_expert_num] padding with expert_alignment=128 - recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( + recv_x, recv_topk_idx, recv_topk_weights, handle, _ = buffer.dispatch( (qinput_tensor, input_scale), topk_idx=topk_idx, topk_weights=topk_weights, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=False, - allocate_on_comm_stream=False, + num_experts=num_experts, + num_max_tokens_per_rank=get_deepep_num_max_dispatch_tokens_per_rank_prefill(), expert_alignment=128, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream, + do_cpu_sync=True, + do_handle_copy=False, ) # scatter - all_tokens = sum(num_recv_tokens_per_expert_list) # calcu padding all nums. + all_tokens = sum(handle.num_recv_tokens_per_expert_list) # calcu padding all nums. # gather_out shape [recive_num_tokens, hidden] gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype) if all_tokens > 0: @@ -149,7 +136,7 @@ def fused_experts_impl( output_index = torch.empty_like(recv_topk_idx) num_recv_tokens_per_expert = torch.tensor( - num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" + handle.num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" ).cuda(non_blocking=True) expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) @@ -169,7 +156,7 @@ def fused_experts_impl( # groupgemm (contiguous layout) gemm_out_a = torch.empty((all_tokens, N), device=hidden_states.device, dtype=hidden_states.dtype) input_tensor[1] = tma_align_input_scale(input_tensor[1]) - _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) + deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + qaunt # TODO fused kernel @@ -183,7 +170,7 @@ def fused_experts_impl( # groupgemm (contiguous layout) gemm_out_b = torch.empty((all_tokens, K), device=hidden_states.device, dtype=hidden_states.dtype) - _deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices) + deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices) # gather and local reduce ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out) @@ -202,13 +189,12 @@ def fused_experts_impl( gather_out, handle, topk_weights=None, - async_finish=False, previous_event=previous_event, - allocate_on_comm_stream=False, + allocate_on_comm_stream=allocate_on_comm_stream, ) else: # low latency dispatch - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() expected_m = triton.cdiv(hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1], num_experts) recv_x, masked_m, handle, event, hook = buffer.low_latency_dispatch( hidden_states, @@ -228,7 +214,7 @@ def fused_experts_impl( return combined_x -def _deepgemm_grouped_fp8_nt_contiguous( +def deepgemm_grouped_fp8_nt_contiguous( input_tuple: Tuple[torch.Tensor, torch.Tensor], w_tuple: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor, @@ -255,3 +241,22 @@ def _deepgemm_grouped_fp8_nt_masked( if hasattr(deep_gemm, "m_grouped_gemm_fp8_fp8_bf16_nt_masked"): return deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(input_tuple, w_tuple, out, masked_m, expected_m) raise RuntimeError("deep_gemm does not provide grouped_gemm_fp8 NT contiguous GEMM kernel in this version") + + +def deepgemm_grouped_fp8_fp4_nt_contiguous( + input_tuple: Tuple[torch.Tensor, torch.Tensor], + w_tuple: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, + grouped_layout: torch.Tensor, + use_psum_layout: bool = False, +): + if HAS_DEEPGEMM and hasattr(deep_gemm, "m_grouped_fp8_fp4_gemm_nt_contiguous"): + return deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous( + input_tuple, + w_tuple, + out, + grouped_layout, + use_psum_layout=use_psum_layout, + recipe=(1, 1, 32), + ) + raise RuntimeError("deep_gemm does not provide grouped fp8-fp4 NT contiguous GEMM kernel") diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 137455a821..3b29951f28 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -126,6 +126,78 @@ def _create_weight( return mm_param, mm_param_list +@QUANTMETHODS.register(["deepgemm-fp8fp4-b32"], platform="cuda") +class DeepGEMMFP8FP4B32QuantizationMethod(DeepGEMMBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.block_size = 32 + self.weight_suffix = "weight" + self.weight_zero_point_suffix = None + self.weight_scale_suffix = None + self.has_weight_scale = True + self.has_weight_zero_point = False + + @property + def method_name(self): + return "deepgemm-fp8fp4-b32" + + def quantize(self, weight: torch.Tensor, output: WeightPack): + from deep_gemm.utils import per_token_cast_to_fp4 + import deep_gemm + + weight = weight.cuda(output.weight.device) + if weight.dim() == 2: + n, k = weight.shape + packed_weight, weight_scale = per_token_cast_to_fp4(weight, use_ue8m0=True, gran_k=self.block_size) + weight_scale = deep_gemm.transform_sf_into_required_layout(weight_scale, n, k, (1, self.block_size), None) + else: + num_groups, n, k = weight.shape + packed_weight = torch.empty((num_groups, n, k // 2), device=weight.device, dtype=torch.int8) + weight_scale = torch.empty((num_groups, n, k // self.block_size), device=weight.device, dtype=torch.float32) + for i in range(num_groups): + packed_weight[i], weight_scale[i] = per_token_cast_to_fp4( + weight[i], use_ue8m0=True, gran_k=self.block_size + ) + weight_scale = deep_gemm.transform_sf_into_required_layout( + weight_scale, n, k, (1, self.block_size), num_groups + ) + output.weight.copy_(packed_weight) + output.weight_scale.copy_(weight_scale) + return + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "WeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("deepgemm-fp8fp4-b32 is only implemented for fused MoE expert weights") + + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims + assert in_dim % 2 == 0, "FP4 packed weight requires even input dimension" + assert in_dim % self.block_size == 0, "FP4 scale dimension must be divisible by block_size" + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim // 2), dtype=torch.int8).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.int32).cuda( + device_id + ) + mm_param = WeightPack(weight=weight, weight_scale=weight_scale) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + weight_scale_out_dims=out_dims, + weight_scale_split_dim=-2, + ) + return mm_param, mm_param_list + + def _deepgemm_fp8_nt(a_tuple, b_tuple, out): if HAS_DEEPGEMM: if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index f01f1c87f7..f15badde25 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -27,7 +27,8 @@ from lightllm.utils.device_utils import has_nvlink from lightllm.utils.envs_utils import ( get_env_start_args, - get_deepep_num_max_dispatch_tokens_per_rank, + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, get_redundancy_expert_num, ) from lightllm.utils.dist_utils import ( @@ -36,7 +37,7 @@ create_new_group_for_current_dp, create_dp_special_inter_group, ) -from lightllm.utils.device_utils import get_device_sm_count +from lightllm.utils.device_utils import get_device_sm_count, is_sm100_gpu from lightllm.utils.torch_dtype_utils import get_torch_dtype logger = init_logger(__name__) @@ -106,6 +107,10 @@ def all_gather_into_tensor(self, output_: torch.Tensor, input_: torch.Tensor, as class DistributeGroupManager: def __init__(self): self.groups = [] + self.ep_buffer = None + self.ep_low_latency_buffer = None + self.ep_mega_moe_buffer = None + self.ep_num_sms = None def __len__(self): return len(self.groups) @@ -127,52 +132,92 @@ def get_default_group(self) -> CustomProcessGroup: def get_group(self, group_index: int) -> CustomProcessGroup: return self.groups[group_index] - def new_deepep_group(self, n_routed_experts, hidden_size): + def new_deepep_group( + self, + n_routed_experts, + hidden_size, + num_experts_per_tok: int = 1, + moe_intermediate_size: Optional[int] = None, + ): enable_ep_moe = get_env_start_args().enable_ep_moe - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + prefill_num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill() + decode_num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() if not enable_ep_moe: self.ep_buffer = None + self.ep_low_latency_buffer = None + self.ep_mega_moe_buffer = None + self.ep_num_sms = None return assert HAS_DEEPEP, "deep_ep is required for expert parallelism" - self._set_num_sms_for_deep_gemm() global_world_size = get_global_world_size() deepep_group = dist.new_group(list(range(global_world_size))) - low_latency_mode, num_rdma_bytes = True, 0 - if low_latency_mode: - self.ll_num_tokens, self.ll_hidden = num_max_dispatch_tokens_per_rank, hidden_size - self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size - num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - self.ll_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts - ) - self.ep_buffer = deep_ep.Buffer( + self.ll_num_tokens = prefill_num_max_dispatch_tokens_per_rank + self.ll_decode_num_tokens = decode_num_max_dispatch_tokens_per_rank + self.ll_hidden = hidden_size + self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size + self.ep_buffer = deep_ep.ElasticBuffer( deepep_group, - int(1e9), - num_rdma_bytes, - low_latency_mode=low_latency_mode, - num_qps_per_rank=(self.ll_num_experts // global_world_size if low_latency_mode else 1), + num_max_tokens_per_rank=self.ll_num_tokens, + hidden=self.ll_hidden, + num_topk=num_experts_per_tok, + use_fp8_dispatch=True, + allow_multiple_reduction=False, ) + self.ep_mega_moe_buffer = None + self.ep_low_latency_buffer = None + if not is_sm100_gpu(): + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + self.ll_decode_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts + ) + self.ep_low_latency_buffer = deep_ep.Buffer( + deepep_group, + int(1e9), + num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=(self.ll_num_experts // global_world_size), + ) + else: + if moe_intermediate_size is None: + raise ValueError("SM100 Mega MoE requires moe_intermediate_size or intermediate_size in model config") + + import deep_gemm + + self.ep_mega_moe_buffer = deep_gemm.get_symm_buffer_for_mega_moe( + deepep_group, + self.ll_num_experts, + self.ll_num_tokens, + num_experts_per_tok, + self.ll_hidden, + moe_intermediate_size, + ) + theoretical_sms = self.ep_buffer.get_theoretical_num_sms(self.ll_num_experts, num_experts_per_tok) + self._set_num_sms_for_deep_gemm(theoretical_sms) - def _set_num_sms_for_deep_gemm(self): + def _set_num_sms_for_deep_gemm(self, deepep_sms: int): try: try: from deep_gemm.jit_kernels.utils import set_num_sms except: from deep_gemm import set_num_sms - deepep_sms = int(os.getenv("DEEPEP_SMS", deep_ep.Buffer.num_sms)) device_sms = get_device_sm_count() - deep_ep.Buffer.set_num_sms(deepep_sms) - set_num_sms(device_sms - deepep_sms) + deepep_sms = max(0, min(deepep_sms, max(device_sms - 2, 0))) + self.ep_num_sms = deepep_sms + if self.ep_low_latency_buffer is not None: + deep_ep.Buffer.set_num_sms(deepep_sms - deepep_sms % 2) + set_num_sms(max(device_sms - deepep_sms, 2)) except BaseException as e: logger.warning(f"set num sms for deep_gemm failed: {e}") def clear_deepep_buffer(self): """ - prefill 之后需要clean 一下,ep buffer 才能正常执行 decode。 + Prefill after using ElasticBuffer may leave the legacy low-latency buffer dirty for decode. """ - if hasattr(self, "ep_buffer") and self.ep_buffer is not None: - self.ep_buffer.clean_low_latency_buffer(self.ll_num_tokens, self.ll_hidden, self.ll_num_experts) + if self.ep_low_latency_buffer is not None: + self.ep_low_latency_buffer.clean_low_latency_buffer( + self.ll_decode_num_tokens, self.ll_hidden, self.ll_num_experts + ) def all_reduce( diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index fa2dee444f..4547ad529a 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -295,7 +295,7 @@ def overlap_tpsp_token_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -421,7 +421,7 @@ def overlap_tpsp_context_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -447,9 +447,9 @@ def overlap_tpsp_context_forward( _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _0_input1, _0_router_logits ) - from deep_ep import Buffer + from deep_ep import ElasticBuffer - _0_overlap_event = Buffer.capture() + _0_overlap_event = ElasticBuffer.capture() # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) @@ -486,8 +486,7 @@ def overlap_tpsp_context_forward( _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _1_input1, _1_router_logits ) - - _1_overlap_event = Buffer.capture() + _1_overlap_event = ElasticBuffer.capture() # 0 shared expert if self.n_shared_experts is not None: @@ -518,7 +517,7 @@ def overlap_tpsp_context_forward( infer_state1.hook() infer_state1.hook = None - _0_combine_event = Buffer.capture() + _0_combine_event = ElasticBuffer.capture() # 0 combine execute _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event) infer_state.hook = _0_hook @@ -533,7 +532,7 @@ def overlap_tpsp_context_forward( infer_state.hook() infer_state.hook = None - _1_combine_event = Buffer.capture() + _1_combine_event = ElasticBuffer.capture() if self.n_shared_experts is not None: _0_ffn_out.add_(_0_shared_output) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e596eed97c..ea6620b4e4 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -48,7 +48,12 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() - dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py index e1df1ec7fd..10b1958b0e 100644 --- a/lightllm/models/gemma4/model.py +++ b/lightllm/models/gemma4/model.py @@ -130,7 +130,12 @@ def _init_att_backend1(self): def _init_custom(self): self._init_to_get_rotary_gemma4() if self.config.get("enable_moe_block", False): - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["num_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", self.config.get("top_k_experts", 1)), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) self._init_ple_static_buffer() def _init_ple_static_buffer(self): diff --git a/lightllm/models/glm4_moe_lite/model.py b/lightllm/models/glm4_moe_lite/model.py index a8fe49ac5e..1e31306aea 100644 --- a/lightllm/models/glm4_moe_lite/model.py +++ b/lightllm/models/glm4_moe_lite/model.py @@ -25,7 +25,12 @@ def _init_config(self): def _init_custom(self): self._init_to_get_yarn_rotary() - dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) def _init_to_get_yarn_rotary(self): rope_scaling = self.config.get("rope_scaling") diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 54e4373652..a39d2f9297 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -133,7 +133,7 @@ def overlap_tpsp_token_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -245,7 +245,7 @@ def overlap_tpsp_context_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -270,9 +270,9 @@ def overlap_tpsp_context_forward( _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _0_input1, _0_router_logits ) - from deep_ep import Buffer + from deep_ep import ElasticBuffer - _0_overlap_event = Buffer.capture() + _0_overlap_event = ElasticBuffer.capture() # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) @@ -308,8 +308,7 @@ def overlap_tpsp_context_forward( _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _1_input1, _1_router_logits ) - - _1_overlap_event = Buffer.capture() + _1_overlap_event = ElasticBuffer.capture() # 0 moe calu _0_moe_out = layer_weight.experts.prefilled_group_gemm( @@ -332,7 +331,7 @@ def overlap_tpsp_context_forward( infer_state1.hook() infer_state1.hook = None - _0_combine_event = Buffer.capture() + _0_combine_event = ElasticBuffer.capture() # 0 combine execute _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event) infer_state.hook = _0_hook @@ -347,7 +346,7 @@ def overlap_tpsp_context_forward( infer_state.hook() infer_state.hook = None - _1_combine_event = Buffer.capture() + _1_combine_event = ElasticBuffer.capture() input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_)) diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index b71d7f4878..0d4b45bfe6 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -27,4 +27,9 @@ def _init_custom(self): super()._init_custom() # Only initialize DeepEP group for MoE models with num_experts if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["num_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 4a8ee80a46..e3c51f3617 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -12,7 +12,6 @@ ) from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo from lightllm.utils.log_utils import init_logger -from lightllm.distributed.communication_op import dist_group_manager from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager from lightllm.server.core.objs.start_args_type import StartArgs @@ -56,12 +55,6 @@ def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) - def _init_custom(self): - super()._init_custom() - # Only initialize DeepEP group for MoE models with num_experts - if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) - def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 start_args: StartArgs = get_env_start_args() diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 249839b0a7..654ba0f3e5 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -4,6 +4,7 @@ import uuid import subprocess import signal +import math from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker from lightllm.utils.start_utils import process_manager, kill_recursive from .metrics.manager import start_metric_manager @@ -291,7 +292,10 @@ def normal_or_p_d_start(args): # linear att cache 参数自动设置 if args.linear_att_cache_size is None: # linear_att_cache_size 只会在 qwen3.5 等混合线性层模型中生效。 - args.linear_att_cache_size = args.running_max_req_size * 2 + default_cache_size = args.running_max_req_size * 2 + dp_size_in_node = max(1, args.dp // args.nnodes) + per_dp_cache_size = max(1, math.ceil(args.running_max_req_size / dp_size_in_node) * 2) + args.linear_att_cache_size = min(default_cache_size, per_dp_cache_size) if args.enable_cpu_cache and is_linear_att_mixed_model(args.model_dir): args.cpu_cache_token_page_size = args.linear_att_hash_page_size * args.linear_att_page_block_num diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 43b10ec88b..58bff90560 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -40,6 +40,11 @@ def get_device_sm_count(): return properties["multiprocessor_count"] +@lru_cache(maxsize=None) +def is_sm100_gpu(): + return torch.cuda.get_device_capability()[0] == 10 + + @lru_cache(maxsize=None) def get_device_sm_regs_num(): import triton diff --git a/lightllm/utils/dist_check_utils.py b/lightllm/utils/dist_check_utils.py index e11da07c8c..12b0b81993 100644 --- a/lightllm/utils/dist_check_utils.py +++ b/lightllm/utils/dist_check_utils.py @@ -17,7 +17,7 @@ logger = init_logger(__name__) _CUSTOM_ALLREDUCE_WORLD_SIZES = (2, 4, 6, 8) -_TWO_GPU_CHECK_TIMEOUT_SECONDS = 60.0 +_TWO_GPU_CHECK_TIMEOUT_SECONDS = 600.0 def _start_two_gpu_check_timeout_watchdog(backend_name: str) -> threading.Event: @@ -84,6 +84,8 @@ def _flashinfer_two_gpu_check_worker(process_rank: int, init_tcp_port: int) -> N input_tensor = torch.zeros(2, 64, device=cuda_device, dtype=torch.bfloat16) else: input_tensor = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) + if not flashinfer_all_reduce.should_use(input_tensor): + raise RuntimeError("FlashInferAllReduce unsupported for probe tensor") output_tensor = flashinfer_all_reduce.all_reduce(input_tensor) dist.barrier() expected_reduced = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 350507e897..2bdd4005fa 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -69,9 +69,22 @@ def enable_env_vars(args): @lru_cache(maxsize=None) -def get_deepep_num_max_dispatch_tokens_per_rank(): +def get_deepep_num_max_dispatch_tokens_per_rank_prefill(): + # 该参数需要大于单卡最大batch size,且是8的倍数。该参数与显存占用直接相关,值越大,显存占用越大。 + # 如果未显式配置,则默认至少覆盖当前进程的 `batch_max_tokens`,避免 DeepEP V2 在 autotune + # warmup 或大 prefill batch 时因为 buffer 上界过小而报错。 + configured = os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK_PREFILL", None) + if configured is not None: + return int(configured) + + batch_max_tokens = get_env_start_args().batch_max_tokens or 256 + return ((int(batch_max_tokens) + 7) // 8) * 8 + + +@lru_cache(maxsize=None) +def get_deepep_num_max_dispatch_tokens_per_rank_decode(): # 该参数需要大于单卡最大batch size,且是8的倍数。该参数与显存占用直接相关,值越大,显存占用越大,如果出现显存不足,可以尝试调小该值 - return int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK", 256)) + return int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK_DECODE", 256)) def get_lightllm_gunicorn_keep_alive(): diff --git a/requirements.txt b/requirements.txt index d37ae05690..31a88629ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,7 +33,7 @@ mpmath==1.3.0 multiprocessing-logging==0.3.4 networkx==3.1 ninja==1.11.1 -numpy==1.25.1 +numpy==2.1.3 packaging==24.2 pip==23.0.1 pluggy==1.2.0 @@ -59,7 +59,7 @@ six==1.16.0 sniffio==1.3.0 sortedcontainers==2.4.0 toolz==0.12.0 -torch==2.9.1 +torch==2.11.0 tqdm==4.65.0 transformers==4.57.1 tokenizers==0.22.1 @@ -71,7 +71,7 @@ zstandard==0.23.0 safetensors==0.4.5 Pillow==10.4.0 tiktoken==0.7.0 -matplotlib==3.8.2 +matplotlib==3.10.0 psutil==5.9.4 prometheus_client==0.20.0 cchardet==2.1.7 @@ -81,19 +81,21 @@ atomics==1.0.3 easydict==1.13 hypercorn==0.18.0 flashinfer-python==0.6.8.post1 -sgl-kernel==0.3.21 +flashinfer-cubin==0.6.8.post1 +sglang-kernel==0.4.2.post1 httpx==0.28.1 librosa==0.11.0 -cuda_bindings==12.9.0 +cuda_bindings==13.2.0 orjson==3.11.2 setproctitle==1.3.6 xxhash==3.6.0 -torchvision==0.24.1 +torchvision==0.26.0 interegular==0.3.3 partial_json_parser==0.2.1.1.post6 websockets==15.0.1 -cupy-cuda12x==13.6.0 -nixl==0.8.0 -xformers==0.0.33.post2 +cupy-cuda13x==14.0.1 +nixl==1.1.0 +xformers==0.0.35 redis==7.3.0 litellm>=1.52.0,<1.85 +flash-attn-4[13]==4.0.0b14 diff --git a/test/benchmark/service/benchmark_client.py b/test/benchmark/service/benchmark_client.py index 09009fc9e1..3f55bcab1e 100644 --- a/test/benchmark/service/benchmark_client.py +++ b/test/benchmark/service/benchmark_client.py @@ -27,6 +27,13 @@ def get_tokenizer( return tokenizer +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + def get_output_length(input_num: int, output_len: int) -> List[int]: min_len, max_len = 2, output_len * 2 mean = (min_len + max_len) * 0.5 @@ -162,7 +169,7 @@ def main(): return assert args.tokenizer_path is not None - model_name.append(args.tokenizer_path) + model_name.append(normalize_model_name(args.tokenizer_path)) seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path) diff --git a/test/benchmark/service/benchmark_multiturn.py b/test/benchmark/service/benchmark_multiturn.py index 7387237f4d..83cbf934d3 100644 --- a/test/benchmark/service/benchmark_multiturn.py +++ b/test/benchmark/service/benchmark_multiturn.py @@ -39,6 +39,8 @@ import random import threading import time +import urllib.parse +import urllib.request from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, List, Optional, Tuple, Union @@ -46,6 +48,15 @@ import requests from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +_DEFAULT_TRANSIENT_RETRIES = 2 +_PROMPT_LEN_OVERLAP_CHARS = 512 +_TRANSIENT_STREAM_ERRORS = ( + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError, + requests.exceptions.ReadTimeout, + requests.exceptions.Timeout, +) + def seed_all(seed: int) -> None: if not seed: @@ -59,6 +70,85 @@ def get_tokenizer(tokenizer_name: str) -> Union[PreTrainedTokenizer, PreTrainedT return AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + +def get_models_url(completions_url: str) -> str: + parsed = urllib.parse.urlsplit(completions_url) + path = parsed.path.rstrip("/") + for suffix in ("/chat/completions", "/completions"): + if path.endswith(suffix): + path = path[: -len(suffix)] + "/models" + return urllib.parse.urlunsplit(parsed._replace(path=path, query="", fragment="")) + return urllib.parse.urlunsplit(parsed._replace(path="/v1/models", query="", fragment="")) + + +def fetch_served_model_names(completions_url: str, timeout_s: int = 10) -> List[str]: + models_url = get_models_url(completions_url) + request = urllib.request.Request(models_url, headers={"Accept": "application/json"}) + with urllib.request.urlopen(request, timeout=timeout_s) as response: + payload = json.loads(response.read().decode("utf-8")) + return [item["id"] for item in payload.get("data", []) if item.get("id")] + + +def resolve_model_name( + completions_url: str, + requested_model_name: str, + explicit_model_name: bool, +) -> Tuple[str, Optional[str]]: + normalized_name = normalize_model_name(requested_model_name) + if normalized_name != requested_model_name: + note = f"Normalized model name from `{requested_model_name}` to `{normalized_name}`." + else: + note = None + + try: + served_model_names = fetch_served_model_names(completions_url) + except Exception as exc: + if note is not None: + note = f"{note} Failed to query served models: {exc}." + return normalized_name, note + + if requested_model_name in served_model_names: + return requested_model_name, note + if normalized_name in served_model_names: + if normalized_name != requested_model_name: + return normalized_name, ( + f"Normalized model name from `{requested_model_name}` to `{normalized_name}` " "to match `/v1/models`." + ) + return normalized_name, note + + requested_basename = os.path.basename(normalized_name) + basename_matches = [ + served_name + for served_name in served_model_names + if os.path.basename(normalize_model_name(served_name)) == requested_basename + ] + if len(basename_matches) == 1: + matched_name = basename_matches[0] + return matched_name, ( + f"Resolved model name `{requested_model_name}` to served model `{matched_name}` " "via `/v1/models`." + ) + + if not explicit_model_name and len(served_model_names) == 1: + matched_name = served_model_names[0] + return matched_name, ( + f"Using the only served model `{matched_name}` returned by `/v1/models` " + f"instead of `{requested_model_name}`." + ) + + if note is not None: + note = ( + f"{note} Available served models: {', '.join(served_model_names) or '(none)'}. " + f"Using `{normalized_name}`." + ) + return normalized_name, note + + def gen_random_token_ids(tokenizer, n: int, rng: random.Random) -> List[int]: vocab = tokenizer.vocab_size return [rng.randint(0, vocab - 1) for _ in range(n)] @@ -87,6 +177,7 @@ def gen_session_initial_prompt( def append_turn_input( tokenizer, prompt: str, + prompt_token_len: int, generated_text: str, turn_input_increment: int, rng: random.Random, @@ -98,17 +189,34 @@ def append_turn_input( new_text = decode_ids(tokenizer, new_ids) else: new_text = "" - new_prompt = prompt + generated_text + new_text - new_len = len(tokenizer.encode(new_prompt, add_special_tokens=False)) + + appended_text = generated_text + new_text + new_prompt = prompt + appended_text + if not appended_text: + return new_prompt, prompt_token_len + + # Token merges only depend on a small boundary window, so avoid + # re-encoding the entire prompt on every turn. + overlap_text = prompt[-_PROMPT_LEN_OVERLAP_CHARS:] + if overlap_text: + overlap_token_len = len(tokenizer.encode(overlap_text, add_special_tokens=False)) + merged_token_len = len(tokenizer.encode(overlap_text + appended_text, add_special_tokens=False)) + appended_token_len = max(merged_token_len - overlap_token_len, 0) + else: + appended_token_len = len(tokenizer.encode(appended_text, add_special_tokens=False)) + new_len = prompt_token_len + appended_token_len return new_prompt, new_len def stream_one_turn( + tokenizer, url: str, model_name: str, prompt: str, + prompt_token_len: int, max_new_tokens: int, request_timeout_s: int, + max_retries: int = _DEFAULT_TRANSIENT_RETRIES, ) -> Optional[Dict]: """Send one streaming completion request, return per-turn stats: { @@ -117,6 +225,8 @@ def stream_one_turn( "prompt_tokens": int, "completion_tokens": int, "cached_tokens": int, + "cached_tokens_reported": bool, + "usage_estimated": bool, "generated_text": str, } Returns None on failure.""" @@ -131,79 +241,121 @@ def stream_one_turn( } headers = {"Content-Type": "application/json"} - start_time = time.time() - first_token_time: Optional[float] = None - last_token_time: Optional[float] = None - decode_times: List[float] = [] - generated_text_parts: List[str] = [] - prompt_tokens = 0 - completion_tokens = 0 - cached_tokens = 0 - - with requests.Session() as req_session: - req_session.trust_env = False - with req_session.post( - url, - headers=headers, - json=payload, - stream=True, - timeout=(10, request_timeout_s), - ) as response: - if response.status_code != 200: - err = response.text - raise RuntimeError(f"stream_one_turn failed: status={response.status_code}, body={err[:200]}") - - for raw in response.iter_lines(): - if not raw: - continue - line = raw.strip() - if not line.startswith(b"data:"): - continue - data_str = line[len(b"data:") :].strip() - if data_str == b"[DONE]": - break - try: - chunk = json.loads(data_str) - except Exception: - continue - - # Final usage-only chunk: choices == [] and usage present - usage = chunk.get("usage") - choices = chunk.get("choices") or [] - if usage is not None and not choices: - prompt_tokens = usage.get("prompt_tokens", prompt_tokens) - completion_tokens = usage.get("completion_tokens", completion_tokens) - details = usage.get("prompt_tokens_details") or {} - cached_tokens = details.get("cached_tokens", cached_tokens) - continue - - # Token-bearing chunk - if not choices: - continue - text_piece = choices[0].get("text", "") - if text_piece == "" and choices[0].get("finish_reason") is None: - continue - - now = time.time() - if first_token_time is None: - first_token_time = now - else: - decode_times.append(now - last_token_time) - last_token_time = now - if text_piece: - generated_text_parts.append(text_piece) - - if first_token_time is None: - raise RuntimeError("stream_one_turn failed: no token received from stream") - - return { - "ttft": first_token_time - start_time, - "decode_times": decode_times, - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "cached_tokens": cached_tokens, - "generated_text": "".join(generated_text_parts), - } + for attempt in range(max_retries + 1): + start_time = time.time() + first_token_time: Optional[float] = None + last_token_time: Optional[float] = None + decode_times: List[float] = [] + generated_text_parts: List[str] = [] + prompt_tokens = 0 + completion_tokens = 0 + cached_tokens = 0 + cached_tokens_reported = False + + try: + with requests.Session() as req_session: + req_session.trust_env = False + with req_session.post( + url, + headers=headers, + json=payload, + stream=True, + timeout=(10, request_timeout_s), + ) as response: + if response.status_code != 200: + err = response.text + if response.status_code >= 500 and attempt < max_retries: + time.sleep(0.2 * (attempt + 1)) + continue + print(f"\n[turn failed] status={response.status_code} body={err[:200]}") + return None + + for raw in response.iter_lines(): + if not raw: + continue + line = raw.strip() + if not line.startswith(b"data:"): + continue + data_str = line[len(b"data:") :].strip() + if data_str == b"[DONE]": + break + try: + chunk = json.loads(data_str) + except Exception: + continue + + # Final usage-only chunk: choices == [] and usage present + usage = chunk.get("usage") + choices = chunk.get("choices") or [] + if usage is not None and not choices: + prompt_tokens = usage.get("prompt_tokens", prompt_tokens) + completion_tokens = usage.get("completion_tokens", completion_tokens) + details = usage.get("prompt_tokens_details") + if isinstance(details, dict) and details.get("cached_tokens") is not None: + cached_tokens = details["cached_tokens"] + cached_tokens_reported = True + continue + + # Token-bearing chunk + if not choices: + continue + text_piece = choices[0].get("text", "") + if text_piece == "" and choices[0].get("finish_reason") is None: + continue + + now = time.time() + if first_token_time is None: + first_token_time = now + else: + decode_times.append(now - last_token_time) + last_token_time = now + if text_piece: + generated_text_parts.append(text_piece) + except _TRANSIENT_STREAM_ERRORS as e: + if first_token_time is None and attempt < max_retries: + time.sleep(0.2 * (attempt + 1)) + continue + + if first_token_time is not None: + generated_text = "".join(generated_text_parts) + estimated_completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False)) + estimated_completion_tokens = max(estimated_completion_tokens, len(generated_text_parts)) + print(f"\n[turn warning] {e}; keeping partial turn with estimated usage (attempt={attempt + 1})") + return { + "ttft": first_token_time - start_time, + "decode_times": decode_times, + "prompt_tokens": prompt_tokens or prompt_token_len, + "completion_tokens": completion_tokens or estimated_completion_tokens, + "cached_tokens": cached_tokens, + "cached_tokens_reported": cached_tokens_reported, + "usage_estimated": completion_tokens == 0 or prompt_tokens == 0, + "generated_text": generated_text, + } + + print(f"\n[turn exception] {e}") + return None + except Exception as e: + print(f"\n[turn exception] {e}") + return None + + if first_token_time is None: + if attempt < max_retries: + time.sleep(0.2 * (attempt + 1)) + continue + return None + + return { + "ttft": first_token_time - start_time, + "decode_times": decode_times, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "cached_tokens": cached_tokens, + "cached_tokens_reported": cached_tokens_reported, + "usage_estimated": False, + "generated_text": "".join(generated_text_parts), + } + + return None def run_session( @@ -234,9 +386,11 @@ def run_session( while turn_idx < max_turns and prompt_len < max_input_len: turn_output_len = rng.randint(min_output_len, output_len) result = stream_one_turn( + tokenizer=tokenizer, url=url, model_name=model_name, prompt=prompt, + prompt_token_len=prompt_len, max_new_tokens=turn_output_len, request_timeout_s=request_timeout_s, ) @@ -255,6 +409,7 @@ def run_session( prompt, prompt_len = append_turn_input( tokenizer, prompt, + result["prompt_tokens"] or prompt_len, result["generated_text"], turn_input_len, rng, @@ -358,13 +513,14 @@ def summarize( prompt_tokens = sum(t["prompt_tokens"] for t in turns) completion_tokens = sum(t["completion_tokens"] for t in turns) cached_tokens = sum(t["cached_tokens"] for t in turns) + cached_tokens_reported_turns = sum(1 for t in turns if t.get("cached_tokens_reported")) + usage_estimated_turns = sum(1 for t in turns if t.get("usage_estimated")) total_tokens = prompt_tokens + completion_tokens qps = len(turns) / wall_time tpm_total = total_tokens / wall_time * 60.0 tpm_prompt = prompt_tokens / wall_time * 60.0 tpm_completion = completion_tokens / wall_time * 60.0 - cache_hit_ratio = cached_tokens / prompt_tokens if prompt_tokens else 0.0 out["QPS"] = round(qps, 4) out["TPM_total"] = round(tpm_total, 2) @@ -373,7 +529,18 @@ def summarize( out["total_prompt_tokens"] = prompt_tokens out["total_completion_tokens"] = completion_tokens out["total_cached_prompt_tokens"] = cached_tokens - out["cache_hit_ratio"] = round(cache_hit_ratio, 6) + out["cached_tokens_reported_turns"] = cached_tokens_reported_turns + out["usage_estimated_turns"] = usage_estimated_turns + if cached_tokens_reported_turns > 0: + cache_hit_ratio = cached_tokens / prompt_tokens if prompt_tokens else 0.0 + out["cache_hit_ratio"] = round(cache_hit_ratio, 6) + else: + out["cache_hit_ratio"] = None + out["cache_hit_ratio_note"] = ( + "Server did not return usage.prompt_tokens_details.cached_tokens. " + "For vLLM OpenAI-compatible APIs, start the server with " + "--enable-prompt-tokens-details to expose cache-hit stats." + ) out["avg_prompt_tokens_per_turn"] = round(prompt_tokens / len(turns), 2) out["avg_completion_tokens_per_turn"] = round(completion_tokens / len(turns), 2) @@ -406,10 +573,16 @@ def print_summary(summary: Dict) -> None: print(f" TPM (total) : {summary['TPM_total']}") print(f" TPM (prompt) : {summary['TPM_prompt']}") print(f" TPM (completion) : {summary['TPM_completion']}") - print( - f" Cache hit ratio : {summary['cache_hit_ratio'] * 100:.2f}% " - f"({summary['total_cached_prompt_tokens']} / {summary['total_prompt_tokens']})" - ) + if summary["cache_hit_ratio"] is None: + print(" Cache hit ratio : n/a") + print(f" Cache hit note : {summary['cache_hit_ratio_note']}") + else: + print( + f" Cache hit ratio : {summary['cache_hit_ratio'] * 100:.2f}% " + f"({summary['total_cached_prompt_tokens']} / {summary['total_prompt_tokens']})" + ) + if summary.get("usage_estimated_turns"): + print(f" Usage estimated : {summary['usage_estimated_turns']} turns") print(f" Avg prompt tokens : {summary['avg_prompt_tokens_per_turn']}") print(f" Avg output tokens : {summary['avg_completion_tokens_per_turn']}") ttft = summary["TTFT_ms"] @@ -432,7 +605,7 @@ def main() -> None: parser.add_argument( "--url", type=str, - default="http://127.0.0.1:8088/v1/completions", + default="http://127.0.0.1:8000/v1/completions", help="Streaming OpenAI completion endpoint. The benchmark relies on " "the final SSE `usage` chunk to obtain cached_tokens.", ) @@ -499,12 +672,19 @@ def main() -> None: return seed_all(args.seed) - model_name = args.model_name or args.tokenizer_path + requested_model_name = args.model_name or args.tokenizer_path + model_name, model_name_note = resolve_model_name( + args.url, + requested_model_name, + explicit_model_name=args.model_name is not None, + ) tokenizer = get_tokenizer(args.tokenizer_path) concurrency_levels = [int(x) for x in args.concurrency_levels.split(",") if x.strip()] print(f"URL : {args.url}") print(f"Model : {model_name}") + if model_name_note: + print(f"Model note : {model_name_note}") print(f"Concurrency levels : {concurrency_levels}") print(f"start_input_len : {args.start_input_len}") print(f"max_input_len : {args.max_input_len}") @@ -538,6 +718,7 @@ def main() -> None: "config": { "url": args.url, "model_name": model_name, + "requested_model_name": requested_model_name, "tokenizer_path": args.tokenizer_path, "concurrency_levels": concurrency_levels, "start_input_len": args.start_input_len, diff --git a/test/benchmark/service/benchmark_qps.py b/test/benchmark/service/benchmark_qps.py index 8249ae2c49..3249ebcbda 100644 --- a/test/benchmark/service/benchmark_qps.py +++ b/test/benchmark/service/benchmark_qps.py @@ -31,6 +31,13 @@ def get_tokenizer( return tokenizer +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + def get_random_length(reqs_num: int, length: int, range_ratio: float) -> List[int]: lens = [] lens = np.random.randint( @@ -429,7 +436,7 @@ def main(): return assert args.tokenizer_path is not None - model_name.append(args.tokenizer_path) + model_name.append(normalize_model_name(args.tokenizer_path)) seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path) From f91690db05cde74a8e948ec330abfed27ec00032 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Sat, 23 May 2026 21:19:27 +0800 Subject: [PATCH 02/13] fix: fix bugs --- .../fused_moe/fused_moe_weight.py | 2 +- .../fused_moe/impl/deepgemm_impl.py | 2 +- .../kv_cache_mem_manager/mem_manager.py | 3 +- lightllm/common/quantization/deepgemm.py | 6 ++-- requirements.txt | 1 - test/benchmark/service/benchmark_multiturn.py | 33 +++++++++---------- 6 files changed, 22 insertions(+), 25 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 375725d124..4fec94d41f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -76,7 +76,7 @@ def _maybe_upgrade_quant_method_for_ep_moe(self, quant_method: QuantizationMetho if not self.enable_ep_moe: return quant_method - target_method = "deepgemm-fp8fp4-b32" if is_sm100_gpu() else "deepgemm-fp8w8a8-b128" + target_method = "deepgemm-fp4fp8-b32" if is_sm100_gpu() else "deepgemm-fp8w8a8-b128" if quant_method.method_name == "none": from lightllm.common.quantization.registry import QUANTMETHODS diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index 2adc4343e2..29eaa8730f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -28,7 +28,7 @@ def _get_ep_num_sms(self) -> int: return getattr(dist_group_manager, "ep_num_sms", None) or 0 def _use_sm100_fp4_moe(self) -> bool: - return is_sm100_gpu() and self.quant_method.method_name == "deepgemm-fp8fp4-b32" + return is_sm100_gpu() and self.quant_method.method_name == "deepgemm-fp4fp8-b32" def _get_mega_moe_weights(self, w13: WeightPack, w2: WeightPack): cache_key = ( diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 0454c86628..0a1deba499 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -67,8 +67,7 @@ def profile_size(self, mem_fraction): torch.cuda.empty_cache() world_size = dist.get_world_size() - - available_memory = get_available_gpu_memory(world_size) * mem_fraction + available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction) cell_size = self.get_cell_size() self.size = int(available_memory * 1024 ** 3 / cell_size) if world_size > 1: diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 3b29951f28..ec1ee90fd4 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -126,7 +126,7 @@ def _create_weight( return mm_param, mm_param_list -@QUANTMETHODS.register(["deepgemm-fp8fp4-b32"], platform="cuda") +@QUANTMETHODS.register(["deepgemm-fp4fp8-b32"], platform="cuda") class DeepGEMMFP8FP4B32QuantizationMethod(DeepGEMMBaseQuantizationMethod): def __init__(self): super().__init__() @@ -139,7 +139,7 @@ def __init__(self): @property def method_name(self): - return "deepgemm-fp8fp4-b32" + return "deepgemm-fp4fp8-b32" def quantize(self, weight: torch.Tensor, output: WeightPack): from deep_gemm.utils import per_token_cast_to_fp4 @@ -174,7 +174,7 @@ def apply( use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - raise NotImplementedError("deepgemm-fp8fp4-b32 is only implemented for fused MoE expert weights") + raise NotImplementedError("deepgemm-fp4fp8-b32 is only implemented for fused MoE expert weights") def _create_weight( self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 diff --git a/requirements.txt b/requirements.txt index 31a88629ed..f2f5bcf504 100644 --- a/requirements.txt +++ b/requirements.txt @@ -98,4 +98,3 @@ nixl==1.1.0 xformers==0.0.35 redis==7.3.0 litellm>=1.52.0,<1.85 -flash-attn-4[13]==4.0.0b14 diff --git a/test/benchmark/service/benchmark_multiturn.py b/test/benchmark/service/benchmark_multiturn.py index 83cbf934d3..7019654c38 100644 --- a/test/benchmark/service/benchmark_multiturn.py +++ b/test/benchmark/service/benchmark_multiturn.py @@ -317,20 +317,8 @@ def stream_one_turn( continue if first_token_time is not None: - generated_text = "".join(generated_text_parts) - estimated_completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False)) - estimated_completion_tokens = max(estimated_completion_tokens, len(generated_text_parts)) - print(f"\n[turn warning] {e}; keeping partial turn with estimated usage (attempt={attempt + 1})") - return { - "ttft": first_token_time - start_time, - "decode_times": decode_times, - "prompt_tokens": prompt_tokens or prompt_token_len, - "completion_tokens": completion_tokens or estimated_completion_tokens, - "cached_tokens": cached_tokens, - "cached_tokens_reported": cached_tokens_reported, - "usage_estimated": completion_tokens == 0 or prompt_tokens == 0, - "generated_text": generated_text, - } + print(f"\n[turn warning] {e}; discarding partial turn (attempt={attempt + 1})") + return None print(f"\n[turn exception] {e}") return None @@ -344,6 +332,16 @@ def stream_one_turn( continue return None + generated_text = "".join(generated_text_parts) + usage_estimated = False + if prompt_tokens == 0: + prompt_tokens = prompt_token_len + usage_estimated = True + if completion_tokens == 0: + estimated_completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False)) + completion_tokens = max(estimated_completion_tokens, len(generated_text_parts)) + usage_estimated = True + return { "ttft": first_token_time - start_time, "decode_times": decode_times, @@ -351,8 +349,8 @@ def stream_one_turn( "completion_tokens": completion_tokens, "cached_tokens": cached_tokens, "cached_tokens_reported": cached_tokens_reported, - "usage_estimated": False, - "generated_text": "".join(generated_text_parts), + "usage_estimated": usage_estimated, + "generated_text": generated_text, } return None @@ -402,8 +400,9 @@ def run_session( print( f"\rconc={progress_state['concurrency']} " f"finished_turns={progress_state['finished_turns']} " - f"active_sessions={progress_state['active_sessions']}", + f"active_sessions={progress_state['active_sessions']}\033[K", end="", + flush=True, ) turn_input_len = rng.randint(min_turn_input_increment, turn_input_increment) prompt, prompt_len = append_turn_input( From 97d45b031498a9a8a262dfc70ba92d0474e3e780 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Thu, 28 May 2026 15:30:38 +0800 Subject: [PATCH 03/13] feat: refine --- .../fused_moe/fused_moe_weight.py | 3 +- .../fused_moe/impl/deepgemm_impl.py | 88 ++----------------- .../fused_moe/grouped_fused_moe_ep.py | 86 ++++++++++++++++++ 3 files changed, 96 insertions(+), 81 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 4fec94d41f..8d42ad0856 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -177,7 +177,8 @@ def experts( ) def use_sm100_mega_moe(self) -> bool: - return bool(getattr(self.fuse_moe_impl, "_use_sm100_fp4_moe", lambda: False)()) + quant_method = getattr(self.fuse_moe_impl, "quant_method", None) + return is_sm100_gpu() and getattr(quant_method, "method_name", None) == "deepgemm-fp4fp8-b32" def low_latency_dispatch( self, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index 29eaa8730f..91e1e29cea 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -10,8 +10,11 @@ ) from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( fused_experts_impl, + get_ep_num_sms, masked_group_gemm, deepgemm_grouped_fp8_nt_contiguous, + mega_moe_impl, + use_sm100_fp4_moe, ) from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, @@ -20,84 +23,9 @@ from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair -from lightllm.utils.device_utils import is_sm100_gpu class FuseMoeDeepGEMM(FuseMoeTriton): - def _get_ep_num_sms(self) -> int: - return getattr(dist_group_manager, "ep_num_sms", None) or 0 - - def _use_sm100_fp4_moe(self) -> bool: - return is_sm100_gpu() and self.quant_method.method_name == "deepgemm-fp4fp8-b32" - - def _get_mega_moe_weights(self, w13: WeightPack, w2: WeightPack): - cache_key = ( - w13.weight.data_ptr(), - w13.weight_scale.data_ptr(), - w2.weight.data_ptr(), - w2.weight_scale.data_ptr(), - ) - if getattr(self, "_mega_moe_weight_cache_key", None) != cache_key: - import deep_gemm - - self._mega_moe_weight_cache = deep_gemm.transform_weights_for_mega_moe( - (w13.weight, w13.weight_scale), - (w2.weight, w2.weight_scale), - ) - self._mega_moe_weight_cache_key = cache_key - return self._mega_moe_weight_cache - - def _get_mega_moe_stats(self, num_local_experts: int, device: torch.device): - stats = getattr(self, "_mega_moe_stats", None) - if stats is None or stats.numel() != num_local_experts or stats.device != device: - stats = torch.zeros((num_local_experts,), device=device, dtype=torch.int32) - self._mega_moe_stats = stats - return stats - - def _mega_moe( - self, - hidden_states: torch.Tensor, - w13: WeightPack, - w2: WeightPack, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ) -> torch.Tensor: - import deep_gemm - from deep_gemm.utils import per_token_cast_to_fp8 - - buffer = getattr(dist_group_manager, "ep_mega_moe_buffer", None) - if buffer is None: - raise RuntimeError("SM100 Mega MoE requires dist_group_manager.ep_mega_moe_buffer to be initialized") - - num_tokens = hidden_states.shape[0] - if num_tokens > buffer.num_max_tokens_per_rank: - raise RuntimeError( - f"Mega MoE got {num_tokens} tokens, exceeding num_max_tokens_per_rank={buffer.num_max_tokens_per_rank}" - ) - - qinput_tensor = per_token_cast_to_fp8( - hidden_states, - use_ue8m0=True, - gran_k=self.quant_method.block_size, - use_packed_ue8m0=True, - ) - l1_weights, l2_weights = self._get_mega_moe_weights(w13, w2) - cumulative_stats = self._get_mega_moe_stats(w13.weight.shape[0], hidden_states.device) - buffer.x[:num_tokens].copy_(qinput_tensor[0]) - buffer.x_sf[:num_tokens].copy_(qinput_tensor[1]) - buffer.topk_idx[:num_tokens].copy_(topk_ids) - buffer.topk_weights[:num_tokens].copy_(topk_weights) - - output = torch.empty_like(hidden_states) - deep_gemm.fp8_fp4_mega_moe( - output, - l1_weights, - l2_weights, - buffer, - cumulative_local_expert_recv_stats=cumulative_stats, - ) - return output - def _select_experts( self, input_tensor: torch.Tensor, @@ -152,8 +80,8 @@ def _fused_experts( ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale - if self._use_sm100_fp4_moe(): - return self._mega_moe(input_tensor, w13, w2, topk_weights, topk_ids.to(torch.long)) + if use_sm100_fp4_moe(self.quant_method): + return mega_moe_impl(input_tensor, w13, w2, topk_weights, topk_ids.to(torch.long), self.quant_method) use_fp8_w8a8 = self.quant_method.method_name != "none" buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer @@ -238,7 +166,7 @@ def select_experts_and_quant_input( scoring_func=scoring_func, ) w13_weight, w13_scale = w13.weight, w13.weight_scale - if self._use_sm100_fp4_moe(): + if use_sm100_fp4_moe(self.quant_method): from deep_gemm.utils import per_token_cast_to_fp8 qinput_tensor = per_token_cast_to_fp8( @@ -272,7 +200,7 @@ def dispatch( num_experts=self.total_expert_num_contain_redundancy, num_max_tokens_per_rank=num_max_tokens_per_rank, expert_alignment=128, - num_sms=self._get_ep_num_sms(), + num_sms=get_ep_num_sms(), previous_event=overlap_event, async_with_compute_stream=True, allocate_on_comm_stream=True, @@ -416,7 +344,7 @@ def combine( gemm_out_b, handle, topk_weights=None, - num_sms=self._get_ep_num_sms(), + num_sms=get_ep_num_sms(), previous_event=overlap_event, async_with_compute_stream=True, allocate_on_comm_stream=True, diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 77705b1755..7c35a85c68 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -2,6 +2,7 @@ import torch import triton from typing import Any, Callable, Dict, Optional, Tuple +from lightllm.distributed import dist_group_manager from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_mix_quant_ep import ( @@ -17,8 +18,10 @@ get_deepep_num_max_dispatch_tokens_per_rank_decode, ) from lightllm.common.triton_utils.autotuner import Autotuner +from lightllm.utils.device_utils import is_sm100_gpu logger = init_logger(__name__) +_MEGA_MOE_STATES: Dict[Tuple[int, int, int, int], Dict[str, Any]] = {} try: from deep_ep import Buffer, EventOverlap @@ -30,6 +33,14 @@ HAS_DEEPGEMM = False +def get_ep_num_sms() -> int: + return getattr(dist_group_manager, "ep_num_sms", None) or 0 + + +def use_sm100_fp4_moe(quant_method: Any) -> bool: + return is_sm100_gpu() and quant_method.method_name == "deepgemm-fp4fp8-b32" + + def masked_group_gemm( recv_x: Tuple[torch.Tensor, torch.Tensor], masked_m: torch.Tensor, @@ -58,6 +69,81 @@ def masked_group_gemm( return gemm_out_b +def _get_mega_moe_cache_state(w13: Any, w2: Any): + state_key = ( + w13.weight.data_ptr(), + w13.weight_scale.data_ptr(), + w2.weight.data_ptr(), + w2.weight_scale.data_ptr(), + ) + return _MEGA_MOE_STATES.setdefault(state_key, {}) + + +def _get_mega_moe_weights(w13: Any, w2: Any, state: Dict[str, Any]): + if "weight_cache" not in state: + state["weight_cache"] = deep_gemm.transform_weights_for_mega_moe( + (w13.weight, w13.weight_scale), + (w2.weight, w2.weight_scale), + ) + return state["weight_cache"] + + +def _get_mega_moe_cumulative_stats(num_local_experts: int, device: torch.device, state: Dict[str, Any]): + stats = state.get("stats") + if stats is None or stats.numel() != num_local_experts or stats.device != device: + stats = torch.zeros((num_local_experts,), device=device, dtype=torch.int32) + state["stats"] = stats + return stats + + +def mega_moe_impl( + hidden_states: torch.Tensor, + w13: Any, + w2: Any, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + quant_method: Any, +): + if not (HAS_DEEPGEMM and hasattr(deep_gemm, "fp8_fp4_mega_moe")): + raise RuntimeError("deep_gemm does not provide fp8-fp4 Mega MoE kernel") + + from deep_gemm.utils import per_token_cast_to_fp8 + + buffer = getattr(dist_group_manager, "ep_mega_moe_buffer", None) + if buffer is None: + raise RuntimeError("SM100 Mega MoE requires dist_group_manager.ep_mega_moe_buffer to be initialized") + + num_tokens = hidden_states.shape[0] + if num_tokens > buffer.num_max_tokens_per_rank: + raise RuntimeError( + f"Mega MoE got {num_tokens} tokens, exceeding num_max_tokens_per_rank={buffer.num_max_tokens_per_rank}" + ) + + qinput_tensor = per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=quant_method.block_size, + use_packed_ue8m0=True, + ) + state = _get_mega_moe_cache_state(w13, w2) + l1_weights, l2_weights = _get_mega_moe_weights(w13, w2, state) + stats = _get_mega_moe_cumulative_stats(w13.weight.shape[0], hidden_states.device, state) + buffer.x[:num_tokens].copy_(qinput_tensor[0]) + buffer.x_sf[:num_tokens].copy_(qinput_tensor[1]) + buffer.topk_idx[:num_tokens].copy_(topk_ids) + buffer.topk_weights[:num_tokens].copy_(topk_weights) + + output = torch.empty_like(hidden_states) + deep_gemm.fp8_fp4_mega_moe( + output, + l1_weights, + l2_weights, + buffer, + cumulative_local_expert_recv_stats=stats, + ) + return output + + def fused_experts_impl( hidden_states: torch.Tensor, # [M, K] w1: torch.Tensor, # [group, N, K] From aa4b0c13df5b1addaf402a90503ddbe2fb94216a Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 28 May 2026 07:37:09 +0000 Subject: [PATCH 04/13] slime fuse_moe_weight --- .../fused_moe/fused_moe_weight.py | 24 ------------------- lightllm/common/quantization/__init__.py | 18 ++++++++++++++ 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 4fec94d41f..0079f1d2e8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -11,7 +11,6 @@ from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.impl import select_fuse_moe_impl from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args -from lightllm.utils.device_utils import is_sm100_gpu from lightllm.utils.dist_utils import get_global_world_size, get_global_rank from lightllm.utils.log_utils import init_logger @@ -53,7 +52,6 @@ def __init__( self.quant_method = quant_method assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." self.enable_ep_moe = get_env_start_args().enable_ep_moe - self.quant_method = self._maybe_upgrade_quant_method_for_ep_moe(self.quant_method) self.n_routed_experts = n_routed_experts self.num_fused_shared_experts = num_fused_shared_experts self._init_config(network_config) @@ -72,28 +70,6 @@ def __init__( self.lock = threading.Lock() self._create_weight() - def _maybe_upgrade_quant_method_for_ep_moe(self, quant_method: QuantizationMethod) -> QuantizationMethod: - if not self.enable_ep_moe: - return quant_method - - target_method = "deepgemm-fp4fp8-b32" if is_sm100_gpu() else "deepgemm-fp8w8a8-b128" - if quant_method.method_name == "none": - from lightllm.common.quantization.registry import QUANTMETHODS - - logger.info( - f"enable_ep_moe requires DeepGEMM MoE expert weights; " - f"auto-upgrading fused_moe quantization from `none` to `{target_method}`." - ) - quant_method = QUANTMETHODS.get(target_method) - - if quant_method.method_name != target_method: - raise ValueError( - f"enable_ep_moe currently requires `{target_method}` for fused_moe on this GPU, " - f"but got `{quant_method.method_name}`." - ) - - return quant_method - def _init_config(self, network_config: Dict[str, Any]): self.n_group = network_config.get("n_group", 0) self.use_grouped_topk = self.n_group > 0 diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 1f08432c6a..39674d2bba 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -44,6 +44,24 @@ def _mapping_quant_method(self): else: self.quant_type = "vllm-fp8w8a8-b128" logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}") + + # fp8 量化下,部分 MoE 模型(如 DeepSeek-V4),可以单独声明 expert 权重精度, + # 按其值给 fused_moe 选用对应的 deepgemm 量化方法。 + expert_dtype = self.network_config_.get("expert_dtype", None) + if expert_dtype is not None: + expert_dtype_to_quant_type = { + "fp4": "deepgemm-fp4fp8-b32", + "fp8": "deepgemm-fp8w8a8-b128", + } + target = expert_dtype_to_quant_type.get(expert_dtype) + if target is None: + raise ValueError( + f"unsupported expert_dtype `{expert_dtype}`; " + f"expected one of {sorted(expert_dtype_to_quant_type)}" + ) + for layer_num in range(self.layer_num): + self.quant_cfg[layer_num].setdefault("fused_moe", target) + logger.info(f"select fused_moe quant way from expert_dtype=`{expert_dtype}`: {target}") elif self.hf_quantization_method == "awq": self.quant_type = "awq" if is_awq_marlin_compatible(self.hf_quantization_config): From 735dbf9ccc90decdba3eb999481535e98cb592ae Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 28 May 2026 07:38:43 +0000 Subject: [PATCH 05/13] remove use_sm100_mega_moe --- .../layer_weights/meta_weights/fused_moe/fused_moe_weight.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 0b78c097c7..fca9b80fcf 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -152,10 +152,6 @@ def experts( per_expert_scale=self.per_expert_scale, ) - def use_sm100_mega_moe(self) -> bool: - quant_method = getattr(self.fuse_moe_impl, "quant_method", None) - return is_sm100_gpu() and getattr(quant_method, "method_name", None) == "deepgemm-fp4fp8-b32" - def low_latency_dispatch( self, hidden_states: torch.Tensor, From 05ec73ee358eeff2607a13cd17f97f93cb711c0b Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Thu, 28 May 2026 16:23:38 +0800 Subject: [PATCH 06/13] feat: add do_fused_experts --- .../fused_moe/impl/deepgemm_impl.py | 23 +++---------- .../fused_moe/grouped_fused_moe_ep.py | 34 +++++++++++++++++++ 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index 91e1e29cea..2327c7f445 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -9,11 +9,10 @@ get_deepep_num_max_dispatch_tokens_per_rank_decode, ) from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( - fused_experts_impl, + do_fused_experts, get_ep_num_sms, masked_group_gemm, deepgemm_grouped_fp8_nt_contiguous, - mega_moe_impl, use_sm100_fp4_moe, ) from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( @@ -78,27 +77,15 @@ def _fused_experts( router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, ): - w13_weight, w13_scale = w13.weight, w13.weight_scale - w2_weight, w2_scale = w2.weight, w2.weight_scale - if use_sm100_fp4_moe(self.quant_method): - return mega_moe_impl(input_tensor, w13, w2, topk_weights, topk_ids.to(torch.long), self.quant_method) - - use_fp8_w8a8 = self.quant_method.method_name != "none" - buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer - output = fused_experts_impl( + output = do_fused_experts( hidden_states=input_tensor, - w1=w13_weight, - w2=w2_weight, + w13=w13, + w2=w2, topk_weights=topk_weights, topk_idx=topk_ids.to(torch.long), num_experts=self.total_expert_num_contain_redundancy, # number of all experts contain redundancy - buffer=buffer, + quant_method=self.quant_method, is_prefill=is_prefill, - use_fp8_w8a8=use_fp8_w8a8, - use_fp8_all2all=use_fp8_w8a8, - use_int8_w8a16=False, # default to False - w1_scale=w13_scale, - w2_scale=w2_scale, previous_event=None, # for overlap ) return output diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 7c35a85c68..184e1a2822 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -144,6 +144,40 @@ def mega_moe_impl( return output +def do_fused_experts( + hidden_states: torch.Tensor, + w13: Any, + w2: Any, + topk_weights: torch.Tensor, + topk_idx: torch.Tensor, + num_experts: int, + quant_method: Any, + is_prefill: Optional[bool], + previous_event: Optional[Any] = None, +): + if use_sm100_fp4_moe(quant_method): + return mega_moe_impl(hidden_states, w13, w2, topk_weights, topk_idx, quant_method) + + use_fp8_w8a8 = quant_method.method_name != "none" + buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer + return fused_experts_impl( + hidden_states=hidden_states, + w1=w13.weight, + w2=w2.weight, + topk_weights=topk_weights, + topk_idx=topk_idx, + num_experts=num_experts, + buffer=buffer, + is_prefill=is_prefill, + use_fp8_w8a8=use_fp8_w8a8, + use_fp8_all2all=use_fp8_w8a8, + use_int8_w8a16=False, + w1_scale=w13.weight_scale, + w2_scale=w2.weight_scale, + previous_event=previous_event, + ) + + def fused_experts_impl( hidden_states: torch.Tensor, # [M, K] w1: torch.Tensor, # [group, N, K] From 966c8acffbd474adb63d8260af1ab4f6057ce0be Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Thu, 28 May 2026 17:48:41 +0800 Subject: [PATCH 07/13] feat: add --expert_dtype param --- docs/CN/source/tutorial/api_server_args.rst | 8 ++++ docs/EN/source/tutorial/api_server_args.rst | 8 ++++ lightllm/common/basemodel/basemodel.py | 3 +- .../fused_moe/impl/deepgemm_impl.py | 4 +- .../fused_moe/grouped_fused_moe_ep.py | 26 +++++++++--- lightllm/common/quantization/__init__.py | 40 +++++++++++++------ .../layer_infer/transformer_layer_infer.py | 5 ++- .../layer_infer/transformer_layer_infer.py | 5 ++- lightllm/server/api_cli.py | 8 ++++ lightllm/server/core/objs/start_args_type.py | 1 + lightllm/server/router/manager.py | 1 + .../model_infer/mode_backend/base_backend.py | 2 + 12 files changed, 87 insertions(+), 24 deletions(-) diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst index c19cc92667..8e7f9d78e8 100644 --- a/docs/CN/source/tutorial/api_server_args.rst +++ b/docs/CN/source/tutorial/api_server_args.rst @@ -464,6 +464,14 @@ PD 分离模式参数 示例可以在 test/advanced_config/mixed_quantization/llamacls-mix-down.yaml 中找到。 +.. option:: --expert_dtype + + EP MoE 专家量化类型,可选值: + + * ``fp8`` + * ``fp4``,仅支持 SM100 GPU + * ``None`` (默认) + .. option:: --vit_quant_type ViT 量化方法,可选值: diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst index ad5b381304..84785de3b7 100644 --- a/docs/EN/source/tutorial/api_server_args.rst +++ b/docs/EN/source/tutorial/api_server_args.rst @@ -465,6 +465,14 @@ Quantization Parameters Examples can be found in test/advanced_config/mixed_quantization/llamacls-mix-down.yaml. +.. option:: --expert_dtype + + Expert quantization dtype for EP MoE, optional values: + + * ``fp8`` + * ``fp4``: SM100 GPUs only + * ``None`` (default) + .. option:: --vit_quant_type ViT quantization method, optional values: diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 23e8d36da9..473dcbafda 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -85,6 +85,7 @@ def __init__(self, kvargs): self.disable_cudagraph = kvargs.get("disable_cudagraph", False) self.quant_type = kvargs.get("quant_type", "none") self.quant_cfg_path = kvargs.get("quant_cfg", None) + self.expert_dtype = kvargs.get("expert_dtype", None) self.mem_fraction = kvargs.get("mem_fraction", 0.9) self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode @@ -156,7 +157,7 @@ def _verify_params(self): return def _init_quant(self): - self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path) + self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path, self.expert_dtype) logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") def _init_weights(self, start_layer_index=0): diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index 2327c7f445..cd4ece1306 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -13,7 +13,7 @@ get_ep_num_sms, masked_group_gemm, deepgemm_grouped_fp8_nt_contiguous, - use_sm100_fp4_moe, + use_sm100_mega_moe, ) from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, @@ -153,7 +153,7 @@ def select_experts_and_quant_input( scoring_func=scoring_func, ) w13_weight, w13_scale = w13.weight, w13.weight_scale - if use_sm100_fp4_moe(self.quant_method): + if use_sm100_mega_moe(self.quant_method): from deep_gemm.utils import per_token_cast_to_fp8 qinput_tensor = per_token_cast_to_fp8( diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 184e1a2822..82c14e5639 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -22,6 +22,7 @@ logger = init_logger(__name__) _MEGA_MOE_STATES: Dict[Tuple[int, int, int, int], Dict[str, Any]] = {} +SUPPORTED_EP_EXPERT_DTYPES = ("deepgemm-fp8w8a8-b128", "deepgemm-fp4fp8-b32") try: from deep_ep import Buffer, EventOverlap @@ -37,10 +38,25 @@ def get_ep_num_sms() -> int: return getattr(dist_group_manager, "ep_num_sms", None) or 0 -def use_sm100_fp4_moe(quant_method: Any) -> bool: +def use_sm100_mega_moe(quant_method: Any) -> bool: return is_sm100_gpu() and quant_method.method_name == "deepgemm-fp4fp8-b32" +def check_ep_expert_dtype(quant_method: Any): + expert_dtype = getattr(quant_method, "method_name", None) + if expert_dtype not in SUPPORTED_EP_EXPERT_DTYPES: + raise ValueError( + "EP MoE requires --expert_dtype to be one of ['fp8', 'fp4'], " + f"but the resolved fused_moe quant method is `{expert_dtype}`. " + "Please start with --expert_dtype fp8 or --expert_dtype fp4. " + "Note that --expert_dtype fp4 is only supported on SM100 GPUs." + ) + if expert_dtype == "deepgemm-fp4fp8-b32" and not is_sm100_gpu(): + raise RuntimeError( + "--expert_dtype fp4 requires an SM100 GPU for EP MoE; " "please use --expert_dtype fp8 on non-SM100 GPUs." + ) + + def masked_group_gemm( recv_x: Tuple[torch.Tensor, torch.Tensor], masked_m: torch.Tensor, @@ -155,10 +171,10 @@ def do_fused_experts( is_prefill: Optional[bool], previous_event: Optional[Any] = None, ): - if use_sm100_fp4_moe(quant_method): + check_ep_expert_dtype(quant_method) + if use_sm100_mega_moe(quant_method): return mega_moe_impl(hidden_states, w13, w2, topk_weights, topk_idx, quant_method) - use_fp8_w8a8 = quant_method.method_name != "none" buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer return fused_experts_impl( hidden_states=hidden_states, @@ -169,8 +185,8 @@ def do_fused_experts( num_experts=num_experts, buffer=buffer, is_prefill=is_prefill, - use_fp8_w8a8=use_fp8_w8a8, - use_fp8_all2all=use_fp8_w8a8, + use_fp8_w8a8=True, + use_fp8_all2all=True, use_int8_w8a16=False, w1_scale=w13.weight_scale, w2_scale=w2.weight_scale, diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 39674d2bba..a439323895 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -7,17 +7,42 @@ from .awq import * from .no_quant import * from lightllm.utils.log_utils import init_logger +from lightllm.utils.device_utils import is_sm100_gpu logger = init_logger(__name__) +EXPERT_DTYPE_TO_QUANT_TYPE = { + "fp8": "deepgemm-fp8w8a8-b128", + "fp4": "deepgemm-fp4fp8-b32", +} +SUPPORTED_EXPERT_DTYPES = tuple(EXPERT_DTYPE_TO_QUANT_TYPE) + class Quantcfg: - def __init__(self, network_config, quant_type="none", custom_cfg_path=None): + def __init__(self, network_config, quant_type="none", custom_cfg_path=None, expert_dtype=None): self.layer_num = network_config["n_layer"] self.quant_type = quant_type + self.expert_dtype = expert_dtype self.network_config_ = network_config self._parse_custom_cfg(custom_cfg_path) self._parse_network_config(network_config) + self._apply_custom_expert_dtype(expert_dtype) + + def _apply_custom_expert_dtype(self, expert_dtype): + if expert_dtype is None: + return + quant_type = self._get_expert_quant_type(expert_dtype, "--expert_dtype") + for layer_num in range(self.layer_num): + self.quant_cfg[layer_num]["fused_moe"] = quant_type + logger.info(f"select fused_moe quant way from --expert_dtype=`{expert_dtype}`: {quant_type}") + + def _get_expert_quant_type(self, expert_dtype, source): + quant_type = EXPERT_DTYPE_TO_QUANT_TYPE.get(expert_dtype) + if quant_type is None: + raise ValueError(f"unsupported {source} `{expert_dtype}`; expected one of {list(SUPPORTED_EXPERT_DTYPES)}") + if expert_dtype == "fp4" and not is_sm100_gpu(): + raise RuntimeError(f"{source} `fp4` requires an SM100 GPU; please use `fp8` on non-SM100 GPUs.") + return quant_type def _parse_network_config(self, network_config): hf_quantization_config = network_config.get("quantization_config", None) @@ -47,18 +72,9 @@ def _mapping_quant_method(self): # fp8 量化下,部分 MoE 模型(如 DeepSeek-V4),可以单独声明 expert 权重精度, # 按其值给 fused_moe 选用对应的 deepgemm 量化方法。 - expert_dtype = self.network_config_.get("expert_dtype", None) + expert_dtype = None if self.expert_dtype is not None else self.network_config_.get("expert_dtype", None) if expert_dtype is not None: - expert_dtype_to_quant_type = { - "fp4": "deepgemm-fp4fp8-b32", - "fp8": "deepgemm-fp8w8a8-b128", - } - target = expert_dtype_to_quant_type.get(expert_dtype) - if target is None: - raise ValueError( - f"unsupported expert_dtype `{expert_dtype}`; " - f"expected one of {sorted(expert_dtype_to_quant_type)}" - ) + target = self._get_expert_quant_type(expert_dtype, "network config expert_dtype") for layer_num in range(self.layer_num): self.quant_cfg[layer_num].setdefault("fused_moe", target) logger.info(f"select fused_moe quant way from expert_dtype=`{expert_dtype}`: {target}") diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 4547ad529a..be819c94a0 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -7,6 +7,7 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import use_sm100_mega_moe from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale from lightllm.utils.envs_utils import get_env_start_args @@ -295,7 +296,7 @@ def overlap_tpsp_token_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -421,7 +422,7 @@ def overlap_tpsp_context_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index a39d2f9297..8879aa2d27 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -6,6 +6,7 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import use_sm100_mega_moe from lightllm.utils.dist_utils import get_global_world_size from lightllm.utils.envs_utils import get_env_start_args @@ -133,7 +134,7 @@ def overlap_tpsp_token_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -245,7 +246,7 @@ def overlap_tpsp_context_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe or layer_weight.experts.use_sm100_mega_moe(): + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 6b30ab6874..2db6c67e77 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -620,6 +620,14 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""Path of quantization config. It can be used for mixed quantization. Examples can be found in test/advanced_config/mixed_quantization/llamacls-mix-down.yaml.""", ) + parser.add_argument( + "--expert_dtype", + type=str, + default=None, + choices=["fp8", "fp4"], + help="""Expert quantization dtype for EP MoE. Supported values are + fp8 and fp4. Note that fp4 is only supported on SM100 GPUs.""", + ) parser.add_argument( "--vit_quant_type", type=str, diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 05ff2658e1..6d0ee07465 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -133,6 +133,7 @@ class StartArgs: graph_max_len_in_batch: int = field(default=0) quant_type: Optional[str] = field(default=None) quant_cfg: Optional[str] = field(default=None) + expert_dtype: Optional[str] = field(default=None, metadata={"choices": ["fp8", "fp4"]}) vit_quant_type: Optional[str] = field(default=None) vit_quant_cfg: Optional[str] = field(default=None) llm_prefill_att_backend: List[str] = field( diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 045723d073..a41c2f265a 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -171,6 +171,7 @@ async def wait_to_model_ready(self): "batch_max_tokens": self.args.batch_max_tokens, "quant_type": self.args.quant_type, "quant_cfg": self.args.quant_cfg, + "expert_dtype": self.args.expert_dtype, "pd_rpyc_ports": self.args.pd_node_infer_rpyc_ports, # 非 pd 模式可以不设置 } diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index e47717747e..4fb6a0db9d 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -164,6 +164,7 @@ def init_model(self, kvargs): "batch_max_tokens": kvargs.get("batch_max_tokens", None), "quant_type": kvargs.get("quant_type", None), "quant_cfg": kvargs.get("quant_cfg", None), + "expert_dtype": kvargs.get("expert_dtype", None), "run_mode": self.run_mode, } self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) @@ -338,6 +339,7 @@ def init_mtp_draft_model(self, main_kvargs: dict): "batch_max_tokens": main_kvargs.get("batch_max_tokens", None), "quant_type": main_kvargs.get("quant_type", None), "quant_cfg": main_kvargs.get("quant_cfg", None), + "expert_dtype": main_kvargs.get("expert_dtype", None), "run_mode": "normal", "main_model": self.model, "mtp_previous_draft_models": self.draft_models.copy(), From 414df1d9f4c6a2a1409a8d0241f4b142e125455d Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 28 May 2026 11:50:10 +0000 Subject: [PATCH 08/13] slim code --- .../fused_moe/impl/deepgemm_impl.py | 26 +++------------- .../fused_moe/grouped_fused_moe_ep.py | 25 ++++++++++++++- lightllm/common/quantization/__init__.py | 31 +++++++++---------- 3 files changed, 43 insertions(+), 39 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index cd4ece1306..4d4614c007 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -9,11 +9,11 @@ get_deepep_num_max_dispatch_tokens_per_rank_decode, ) from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( - do_fused_experts, + fused_experts, get_ep_num_sms, masked_group_gemm, deepgemm_grouped_fp8_nt_contiguous, - use_sm100_mega_moe, + quantize_fused_experts_input, ) from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, @@ -77,7 +77,7 @@ def _fused_experts( router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, ): - output = do_fused_experts( + output = fused_experts( hidden_states=input_tensor, w13=w13, w2=w2, @@ -152,24 +152,8 @@ def select_experts_and_quant_input( num_expert_group=n_group, scoring_func=scoring_func, ) - w13_weight, w13_scale = w13.weight, w13.weight_scale - if use_sm100_mega_moe(self.quant_method): - from deep_gemm.utils import per_token_cast_to_fp8 - - qinput_tensor = per_token_cast_to_fp8( - hidden_states, - use_ue8m0=True, - gran_k=self.quant_method.block_size, - use_packed_ue8m0=True, - ) - return topk_weights, topk_idx.to(torch.long), qinput_tensor - - block_size_k = 0 - if w13_weight.ndim == 3: - block_size_k = w13_weight.shape[2] // w13_scale.shape[2] - assert block_size_k == 128, "block_size_k must be 128" - qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w13_weight.dtype) - return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale) + qinput_tensor = quantize_fused_experts_input(hidden_states, w13, self.quant_method) + return topk_weights, topk_idx.to(torch.long), qinput_tensor def dispatch( self, diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 82c14e5639..cb2e370cb9 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -160,7 +160,30 @@ def mega_moe_impl( return output -def do_fused_experts( +def quantize_fused_experts_input( + hidden_states: torch.Tensor, + w13: Any, + quant_method: Any, +): + check_ep_expert_dtype(quant_method) + if use_sm100_mega_moe(quant_method): + from deep_gemm.utils import per_token_cast_to_fp8 + + return per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=quant_method.block_size, + use_packed_ue8m0=True, + ) + + block_size_k = 0 + if w13.weight.ndim == 3: + block_size_k = w13.weight.shape[2] // w13.weight_scale.shape[2] + assert block_size_k == 128, "block_size_k must be 128" + return per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w13.weight.dtype) + + +def fused_experts( hidden_states: torch.Tensor, w13: Any, w2: Any, diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index a439323895..cd534d53ec 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -26,22 +26,15 @@ def __init__(self, network_config, quant_type="none", custom_cfg_path=None, expe self.network_config_ = network_config self._parse_custom_cfg(custom_cfg_path) self._parse_network_config(network_config) - self._apply_custom_expert_dtype(expert_dtype) - def _apply_custom_expert_dtype(self, expert_dtype): - if expert_dtype is None: - return - quant_type = self._get_expert_quant_type(expert_dtype, "--expert_dtype") - for layer_num in range(self.layer_num): - self.quant_cfg[layer_num]["fused_moe"] = quant_type - logger.info(f"select fused_moe quant way from --expert_dtype=`{expert_dtype}`: {quant_type}") - - def _get_expert_quant_type(self, expert_dtype, source): + def _get_expert_quant_type(self, expert_dtype): quant_type = EXPERT_DTYPE_TO_QUANT_TYPE.get(expert_dtype) if quant_type is None: - raise ValueError(f"unsupported {source} `{expert_dtype}`; expected one of {list(SUPPORTED_EXPERT_DTYPES)}") + raise ValueError( + f"unsupported expert_dtype `{expert_dtype}`; expected one of {list(SUPPORTED_EXPERT_DTYPES)}" + ) if expert_dtype == "fp4" and not is_sm100_gpu(): - raise RuntimeError(f"{source} `fp4` requires an SM100 GPU; please use `fp8` on non-SM100 GPUs.") + raise RuntimeError("expert_dtype `fp4` requires an SM100 GPU; please use `fp8` on non-SM100 GPUs.") return quant_type def _parse_network_config(self, network_config): @@ -72,12 +65,16 @@ def _mapping_quant_method(self): # fp8 量化下,部分 MoE 模型(如 DeepSeek-V4),可以单独声明 expert 权重精度, # 按其值给 fused_moe 选用对应的 deepgemm 量化方法。 - expert_dtype = None if self.expert_dtype is not None else self.network_config_.get("expert_dtype", None) - if expert_dtype is not None: - target = self._get_expert_quant_type(expert_dtype, "network config expert_dtype") - for layer_num in range(self.layer_num): + expert_dtype = self.expert_dtype or self.network_config_.get("expert_dtype", None) + if expert_dtype is None: + return + target = self._get_expert_quant_type(expert_dtype) + for layer_num in range(self.layer_num): + if self.expert_dtype is not None: + self.quant_cfg[layer_num]["fused_moe"] = target + else: self.quant_cfg[layer_num].setdefault("fused_moe", target) - logger.info(f"select fused_moe quant way from expert_dtype=`{expert_dtype}`: {target}") + logger.info(f"select fused_moe quant way from expert_dtype=`{expert_dtype}`: {target}") elif self.hf_quantization_method == "awq": self.quant_type = "awq" if is_awq_marlin_compatible(self.hf_quantization_config): From 2ac36306e852d6188a3e73ab87e9bebd0c88b743 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Fri, 29 May 2026 12:17:24 +0800 Subject: [PATCH 09/13] feat: add --model_name to benchmark_qps.py --- test/benchmark/service/benchmark_qps.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/benchmark/service/benchmark_qps.py b/test/benchmark/service/benchmark_qps.py index 3249ebcbda..a9083091ea 100644 --- a/test/benchmark/service/benchmark_qps.py +++ b/test/benchmark/service/benchmark_qps.py @@ -401,6 +401,12 @@ def main(): ) parser.add_argument("--num_clients", type=int, default=100) parser.add_argument("--tokenizer_path", type=str, default=None) + parser.add_argument( + "--model_name", + type=str, + default=None, + help="Model name passed to the server. Defaults to --tokenizer_path.", + ) parser.add_argument("--data_path", type=str, default=None) parser.add_argument("--input_num", type=int, default=2000) parser.add_argument("--input_qps", type=float, default=30.0) @@ -436,7 +442,7 @@ def main(): return assert args.tokenizer_path is not None - model_name.append(normalize_model_name(args.tokenizer_path)) + model_name.append(args.model_name if args.model_name is not None else normalize_model_name(args.tokenizer_path)) seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path) From d3a0dd8b0d008c14a6f3e861768121f001be762d Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Fri, 29 May 2026 15:40:49 +0800 Subject: [PATCH 10/13] feat: add hadamard_transform kernel --- .../layer_infer/transformer_layer_infer.py | 2 +- .../triton_kernel/hadamard_transform.py | 57 ++++++++++++ lightllm/utils/backend_validator.py | 12 ++- .../triton_kernel/test_hadamard_transform.py | 92 +++++++++++++++++++ 4 files changed, 157 insertions(+), 6 deletions(-) create mode 100644 lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py create mode 100644 unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 899531448b..7f7377f319 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -251,7 +251,7 @@ def _get_indices( @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 - from sgl_kernel import hadamard_transform + from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform hidden_size = x.size(-1) assert (hidden_size & (hidden_size - 1)) == 0, "Hidden size must be a power of 2 for Hadamard transform." diff --git a/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py b/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py new file mode 100644 index 0000000000..d5e571913c --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py @@ -0,0 +1,57 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _butterfly_stage(x, GROUPS: tl.constexpr, STEP: tl.constexpr, BLOCK_N: tl.constexpr): + x_grouped = tl.reshape(x, (GROUPS, 2, STEP)) + x_grouped = tl.permute(x_grouped, (0, 2, 1)) + left, right = tl.split(x_grouped) + x_pair = tl.join(left + right, left - right) + x_pair = tl.permute(x_pair, (0, 2, 1)) + return tl.reshape(x_pair, (BLOCK_N,)) + + +@triton.jit +def _hadamard_transform_kernel( + X, + Y, + scale: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + offsets = tl.arange(0, BLOCK_N) + x = tl.load(X + row * BLOCK_N + offsets).to(tl.float32) + + x = _butterfly_stage(x, 64, 1, BLOCK_N) + x = _butterfly_stage(x, 32, 2, BLOCK_N) + x = _butterfly_stage(x, 16, 4, BLOCK_N) + x = _butterfly_stage(x, 8, 8, BLOCK_N) + x = _butterfly_stage(x, 4, 16, BLOCK_N) + x = _butterfly_stage(x, 2, 32, BLOCK_N) + x = _butterfly_stage(x, 1, 64, BLOCK_N) + + tl.store(Y + row * BLOCK_N + offsets, x * scale) + + +def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + assert x.is_cuda, "hadamard_transform only supports CUDA tensors" + assert x.dtype == torch.bfloat16, "hadamard_transform expects bfloat16 input" + + original_shape = x.shape + hidden_size = x.size(-1) + assert hidden_size == 128, "DeepSeek-V3.2 Hadamard transform expects hidden size 128" + + x = x.contiguous() + out = torch.empty_like(x) + rows = x.numel() // hidden_size + _hadamard_transform_kernel[(rows,)]( + x, + out, + scale, + BLOCK_N=hidden_size, + num_warps=4, + ) + + return out.view(original_shape) diff --git a/lightllm/utils/backend_validator.py b/lightllm/utils/backend_validator.py index 6c5fe90309..ab5c0a88a1 100644 --- a/lightllm/utils/backend_validator.py +++ b/lightllm/utils/backend_validator.py @@ -196,12 +196,15 @@ def _validate_flashmla_sparse(): except Exception as e: return False, f"sgl_kernel.flash_mla import failed: {type(e).__name__}: {e}" - batch, heads, seq, dim = 1, 64, 128, 512 + 64 + batch, heads, seq = 1, 64, 128 + kv_lora_rank = 512 + qk_rope_head_dim = 64 + qk_dim = kv_lora_rank + qk_rope_head_dim dtype = torch.bfloat16 device = "cuda" - q = torch.randn(batch * seq, heads, dim, dtype=dtype, device=device) - kv = torch.zeros(batch * seq, 1, dim, dtype=dtype, device=device) + q = torch.randn(batch * seq, heads, qk_dim, dtype=dtype, device=device) + kv = torch.zeros(batch * seq, 1, qk_dim, dtype=dtype, device=device) index_topk = 128 topk_indices = torch.zeros(batch * seq, index_topk, dtype=torch.int32, device=device) @@ -210,8 +213,7 @@ def _validate_flashmla_sparse(): topk_indices = topk_indices.view(batch * seq, 1, index_topk) - softmax_scale = 1.0 / (dim ** 0.5) - kv_lora_rank = dim + softmax_scale = 1.0 / (qk_dim ** 0.5) try: mla_out, _, _ = flash_mla_sparse_fwd( diff --git a/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py b/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py new file mode 100644 index 0000000000..7ea3f31a08 --- /dev/null +++ b/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py @@ -0,0 +1,92 @@ +import pytest +import torch + +from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform + + +TP = 8 +INDEX_N_HEADS = 64 +INDEX_HEAD_DIM = 128 +TP_INDEX_N_HEADS = INDEX_N_HEADS // TP +SCALE = INDEX_HEAD_DIM ** -0.5 + + +def _get_sgl_kernel_hadamard_transform(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for hadamard_transform comparison") + try: + from sgl_kernel import hadamard_transform as sgl_hadamard_transform + except ImportError: + pytest.skip("sgl_kernel.hadamard_transform is not available") + return sgl_hadamard_transform + + +def _bench(fn, x, warmup=30, iters=300): + for _ in range(warmup): + fn(x, scale=SCALE) + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + y = fn(x, scale=SCALE) + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / iters, y + + +@pytest.mark.parametrize("tokens", [1, 16, 128, 512, 1024, 2048, 4096, 8192, 16384]) +def test_hadamard_transform_matches_sgl_kernel_deepseek_v32_shapes(tokens): + sgl_hadamard_transform = _get_sgl_kernel_hadamard_transform() + + q = torch.randn(tokens, TP_INDEX_N_HEADS, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + k = torch.randn(tokens, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + + q_expected = sgl_hadamard_transform(q, scale=SCALE) + q_actual = hadamard_transform(q, scale=SCALE) + k_expected = sgl_hadamard_transform(k, scale=SCALE) + k_actual = hadamard_transform(k, scale=SCALE) + torch.cuda.synchronize() + + assert torch.equal(q_actual, q_expected) + assert torch.equal(k_actual, k_expected) + + +def test_hadamard_transform_perf_report_deepseek_v32_shapes(): + sgl_hadamard_transform = _get_sgl_kernel_hadamard_transform() + + print( + "\nDeepSeek-V3.2 per-rank shapes with tp=8:" + "\n q: [tokens, 8, 128]" + "\n k: [tokens, 128]" + "\n\ntokens | q_diff | k_diff | sgl_q ms | tri_q ms | sgl_k ms | tri_k ms | tri(q+k) ms | slowdown q+k" + ) + + for tokens in [1, 16, 128, 512, 1024, 2048, 4096, 8192, 16384]: + q = torch.randn(tokens, TP_INDEX_N_HEADS, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + k = torch.randn(tokens, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + + q_expected = sgl_hadamard_transform(q, scale=SCALE) + q_actual = hadamard_transform(q, scale=SCALE) + k_expected = sgl_hadamard_transform(k, scale=SCALE) + k_actual = hadamard_transform(k, scale=SCALE) + torch.cuda.synchronize() + + q_diff = (q_expected.float() - q_actual.float()).abs().max().item() + k_diff = (k_expected.float() - k_actual.float()).abs().max().item() + sgl_q_ms, _ = _bench(sgl_hadamard_transform, q) + tri_q_ms, _ = _bench(hadamard_transform, q) + sgl_k_ms, _ = _bench(sgl_hadamard_transform, k) + tri_k_ms, _ = _bench(hadamard_transform, k) + sgl_sum_ms = sgl_q_ms + sgl_k_ms + tri_sum_ms = tri_q_ms + tri_k_ms + + print( + f"{tokens:6d} | {q_diff:6.1g} | {k_diff:6.1g} | " + f"{sgl_q_ms:8.4f} | {tri_q_ms:8.4f} | {sgl_k_ms:8.4f} | {tri_k_ms:8.4f} | " + f"{tri_sum_ms:11.4f} | {tri_sum_ms / sgl_sum_ms:10.2f}x" + ) + + assert q_diff == 0 + assert k_diff == 0 From 2ee7939bede4c004fb6fb936f4457d844ceecac3 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Fri, 29 May 2026 16:25:33 +0800 Subject: [PATCH 11/13] fix --- .../layer_infer/transformer_layer_infer.py | 11 +++++++++-- .../deepseek3_2/triton_kernel/extract_indexer_ks.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 7f7377f319..d6eaebe2fd 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -227,7 +227,15 @@ def _get_indices( import deep_gemm - logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) + logits = deep_gemm.fp8_mqa_logits( + q_fp8, + (k_fp8_, k_scale_), + weights.squeeze(-1), + ks, + ke, + clean_logits=False, + max_seqlen_k=infer_state.max_kv_seq_len, + ) from sgl_kernel import fast_topk_v2 @@ -235,7 +243,6 @@ def _get_indices( score=logits, lengths=lengths, topk=self.index_topk, - row_starts=ks, ) b_topk_index = torch.where(b_topk_index != -1, b_topk_index + ks.view(-1, 1), -1) # 将 topk index 转化为 mem index diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py index d0f8b45f81..f02fc30942 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -112,4 +112,4 @@ def extract_indexer_ks( num_stages=1, ) - return O_fp8, O_scale + return O_fp8, O_scale.squeeze(-1) From 5ef3796b84ad87174a92e981681115a034b0a269 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 29 May 2026 08:58:05 +0000 Subject: [PATCH 12/13] update hadamard --- .../triton_kernel/hadamard_transform.py | 73 ++++++++++++------- requirements.txt | 2 +- .../triton_kernel/test_hadamard_transform.py | 17 +---- 3 files changed, 53 insertions(+), 39 deletions(-) diff --git a/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py b/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py index d5e571913c..eabf703f56 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py +++ b/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py @@ -1,57 +1,80 @@ +import functools + import torch import triton import triton.language as tl @triton.jit -def _butterfly_stage(x, GROUPS: tl.constexpr, STEP: tl.constexpr, BLOCK_N: tl.constexpr): - x_grouped = tl.reshape(x, (GROUPS, 2, STEP)) - x_grouped = tl.permute(x_grouped, (0, 2, 1)) +def _butterfly_stage(x, GROUPS: tl.constexpr, STEP: tl.constexpr, BLOCK_R: tl.constexpr, BLOCK_N: tl.constexpr): + x_grouped = tl.reshape(x, (BLOCK_R, GROUPS, 2, STEP)) + x_grouped = tl.permute(x_grouped, (0, 1, 3, 2)) left, right = tl.split(x_grouped) x_pair = tl.join(left + right, left - right) - x_pair = tl.permute(x_pair, (0, 2, 1)) - return tl.reshape(x_pair, (BLOCK_N,)) + x_pair = tl.permute(x_pair, (0, 1, 3, 2)) + return tl.reshape(x_pair, (BLOCK_R, BLOCK_N)) @triton.jit def _hadamard_transform_kernel( X, Y, + n_rows, scale: tl.constexpr, + BLOCK_R: tl.constexpr, BLOCK_N: tl.constexpr, ): - row = tl.program_id(0) - offsets = tl.arange(0, BLOCK_N) - x = tl.load(X + row * BLOCK_N + offsets).to(tl.float32) + pid = tl.program_id(0) + rows = pid * BLOCK_R + tl.arange(0, BLOCK_R) + mask = rows[:, None] < n_rows + offsets = rows[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + x = tl.load(X + offsets, mask=mask, other=0.0).to(tl.float32) - x = _butterfly_stage(x, 64, 1, BLOCK_N) - x = _butterfly_stage(x, 32, 2, BLOCK_N) - x = _butterfly_stage(x, 16, 4, BLOCK_N) - x = _butterfly_stage(x, 8, 8, BLOCK_N) - x = _butterfly_stage(x, 4, 16, BLOCK_N) - x = _butterfly_stage(x, 2, 32, BLOCK_N) - x = _butterfly_stage(x, 1, 64, BLOCK_N) + x = _butterfly_stage(x, 64, 1, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 32, 2, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 16, 4, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 8, 8, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 4, 16, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 2, 32, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 1, 64, BLOCK_R, BLOCK_N) - tl.store(Y + row * BLOCK_N + offsets, x * scale) + tl.store(Y + offsets, x * scale, mask=mask) -def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: - assert x.is_cuda, "hadamard_transform only supports CUDA tensors" - assert x.dtype == torch.bfloat16, "hadamard_transform expects bfloat16 input" +@functools.lru_cache(maxsize=None) +def _target_programs(device_index: int) -> int: + return torch.cuda.get_device_properties(device_index).multi_processor_count * 2 + +def _pick_block_r(rows: int, device_index: int) -> int: + block_r = triton.next_power_of_2(max(1, rows // _target_programs(device_index))) + return max(1, min(128, block_r)) + + +def _hadamard_transform_triton(x: torch.Tensor, scale: float) -> torch.Tensor: original_shape = x.shape hidden_size = x.size(-1) - assert hidden_size == 128, "DeepSeek-V3.2 Hadamard transform expects hidden size 128" - - x = x.contiguous() - out = torch.empty_like(x) + if not x.is_contiguous(): + x = x.contiguous() rows = x.numel() // hidden_size - _hadamard_transform_kernel[(rows,)]( + out = torch.empty_like(x) + BLOCK_R = _pick_block_r(rows, x.device.index) + grid = (triton.cdiv(rows, BLOCK_R),) + _hadamard_transform_kernel[grid]( x, out, + rows, scale, + BLOCK_R=BLOCK_R, BLOCK_N=hidden_size, num_warps=4, ) - return out.view(original_shape) + + +def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + assert x.is_cuda, "hadamard_transform only supports CUDA tensors" + assert x.dtype == torch.bfloat16, "hadamard_transform expects bfloat16 input" + assert x.size(-1) == 128, "DeepSeek-V3.2 Hadamard transform expects hidden size 128" + + return _hadamard_transform_triton(x, scale) diff --git a/requirements.txt b/requirements.txt index f2f5bcf504..f124ce76f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,7 +61,7 @@ sortedcontainers==2.4.0 toolz==0.12.0 torch==2.11.0 tqdm==4.65.0 -transformers==4.57.1 +transformers==5.8.0 tokenizers==0.22.1 urllib3==1.26.16 uvicorn==0.19.0 diff --git a/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py b/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py index 7ea3f31a08..8a54d6d9fd 100644 --- a/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py +++ b/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py @@ -1,5 +1,6 @@ import pytest import torch +import triton from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform @@ -21,19 +22,9 @@ def _get_sgl_kernel_hadamard_transform(): return sgl_hadamard_transform -def _bench(fn, x, warmup=30, iters=300): - for _ in range(warmup): - fn(x, scale=SCALE) - torch.cuda.synchronize() - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(iters): - y = fn(x, scale=SCALE) - end.record() - torch.cuda.synchronize() - return start.elapsed_time(end) / iters, y +def _bench(fn, x): + ms = triton.testing.do_bench_cudagraph(lambda: fn(x, scale=SCALE), return_mode="median") + return ms, fn(x, scale=SCALE) @pytest.mark.parametrize("tokens", [1, 16, 128, 512, 1024, 2048, 4096, 8192, 16384]) From 8ce80b8566392595a69cec38a8fdaf043fbea7db Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Sun, 31 May 2026 21:28:15 +0800 Subject: [PATCH 13/13] fix: fix a random bug in the benchmark_multiturn.py --- test/benchmark/service/benchmark_multiturn.py | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/test/benchmark/service/benchmark_multiturn.py b/test/benchmark/service/benchmark_multiturn.py index 7019654c38..897d125077 100644 --- a/test/benchmark/service/benchmark_multiturn.py +++ b/test/benchmark/service/benchmark_multiturn.py @@ -5,8 +5,9 @@ "sessions". Each session starts from a prompt of ~start_input_len tokens (with a per-session random prefix so different sessions don't share KV cache) and keeps issuing streaming requests turn by turn. After every -turn the model's generated text plus a dynamically sampled number of new -tokens are appended to the prompt, simulating the user's next message. +turn, deterministic synthetic assistant tokens plus a dynamically sampled +number of new user tokens are appended to the prompt. This keeps the exact +request stream reproducible for a fixed seed. A session stops when the next prompt would exceed max_input_len, or after max_turns turns. @@ -178,19 +179,28 @@ def append_turn_input( tokenizer, prompt: str, prompt_token_len: int, - generated_text: str, + assistant_token_count: int, turn_input_increment: int, rng: random.Random, ) -> Tuple[str, int]: - """Append the model's generated text plus a fresh random user turn - to the prompt. Returns (new_prompt, new_prompt_token_len).""" + """Append deterministic synthetic assistant/user text to the prompt. + + The benchmark measures server output, but the next request must not depend + on that output; otherwise repeated runs with the same seed can diverge. + """ + if assistant_token_count > 0: + assistant_ids = gen_random_token_ids(tokenizer, assistant_token_count, rng) + assistant_text = decode_ids(tokenizer, assistant_ids) + else: + assistant_text = "" + if turn_input_increment > 0: - new_ids = gen_random_token_ids(tokenizer, turn_input_increment, rng) - new_text = decode_ids(tokenizer, new_ids) + user_ids = gen_random_token_ids(tokenizer, turn_input_increment, rng) + user_text = decode_ids(tokenizer, user_ids) else: - new_text = "" + user_text = "" - appended_text = generated_text + new_text + appended_text = assistant_text + user_text new_prompt = prompt + appended_text if not appended_text: return new_prompt, prompt_token_len @@ -408,8 +418,8 @@ def run_session( prompt, prompt_len = append_turn_input( tokenizer, prompt, - result["prompt_tokens"] or prompt_len, - result["generated_text"], + prompt_len, + turn_output_len, turn_input_len, rng, ) @@ -631,7 +641,7 @@ def main() -> None: "--turn_input_increment", type=int, default=2048, - help="Maximum new 'user' tokens sampled after each turn, on top " "of the model's generated text.", + help="Maximum new 'user' tokens sampled after each turn, on top of deterministic synthetic assistant tokens.", ) parser.add_argument( "--min_turn_input_increment", type=int, default=512, help="Minimum new 'user' tokens sampled after each turn."