From 7c7bd619e9ddf0234a9afbee22dde8f268f9c8e3 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Wed, 3 Jun 2026 09:20:12 +0000 Subject: [PATCH 01/30] one pass --- .../common/kv_cache_mem_manager/__init__.py | 2 + .../common/kv_cache_mem_manager/allocator.py | 12 +- .../deepseek4_mem_manager.py | 203 +++++++++ lightllm/common/req_manager.py | 106 ++++- lightllm/models/__init__.py | 1 + lightllm/models/deepseek_v4/__init__.py | 0 lightllm/models/deepseek_v4/infer_struct.py | 27 ++ .../deepseek_v4/layer_infer/__init__.py | 0 .../deepseek_v4/layer_infer/attention.py | 55 +++ .../deepseek_v4/layer_infer/compressor.py | 156 +++++++ .../layer_infer/hyper_connection.py | 58 +++ .../layer_infer/post_layer_infer.py | 19 + .../layer_infer/pre_layer_infer.py | 22 + .../layer_infer/transformer_layer_infer.py | 270 ++++++++++++ .../deepseek_v4/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 37 ++ .../layer_weights/transformer_layer_weight.py | 398 ++++++++++++++++++ lightllm/models/deepseek_v4/mem_manager.py | 12 + lightllm/models/deepseek_v4/model.py | 121 ++++++ .../deepseek_v4/triton_kernel/__init__.py | 0 .../triton_kernel/quant_convert.py | 93 ++++ .../deepseek_v4/triton_kernel/rotary_emb.py | 26 ++ .../server/router/model_infer/infer_batch.py | 2 + 23 files changed, 1614 insertions(+), 6 deletions(-) create mode 100644 lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py create mode 100644 lightllm/models/deepseek_v4/__init__.py create mode 100644 lightllm/models/deepseek_v4/infer_struct.py create mode 100644 lightllm/models/deepseek_v4/layer_infer/__init__.py create mode 100644 lightllm/models/deepseek_v4/layer_infer/attention.py create mode 100644 lightllm/models/deepseek_v4/layer_infer/compressor.py create mode 100644 lightllm/models/deepseek_v4/layer_infer/hyper_connection.py create mode 100644 lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py create mode 100644 
lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py create mode 100644 lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/deepseek_v4/layer_weights/__init__.py create mode 100644 lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/deepseek_v4/mem_manager.py create mode 100644 lightllm/models/deepseek_v4/model.py create mode 100644 lightllm/models/deepseek_v4/triton_kernel/__init__.py create mode 100644 lightllm/models/deepseek_v4/triton_kernel/quant_convert.py create mode 100644 lightllm/models/deepseek_v4/triton_kernel/rotary_emb.py diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 05544e149a..95f7e8ab76 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -4,6 +4,7 @@ from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager from .deepseek3_2mem_manager import Deepseek3_2MemoryManager +from .deepseek4_mem_manager import DeepseekV4MemoryManager from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager @@ -17,6 +18,7 @@ "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", "Deepseek3_2MemoryManager", + "DeepseekV4MemoryManager", "FP8PerTokenGroupQuantDeepseek3_2MemoryManager", "FP8StaticPerHeadQuantMemManager", "FP8StaticPerTensorQuantMemManager", diff --git a/lightllm/common/kv_cache_mem_manager/allocator.py b/lightllm/common/kv_cache_mem_manager/allocator.py index 850c158778..0179ed2714 100644 --- a/lightllm/common/kv_cache_mem_manager/allocator.py +++ 
b/lightllm/common/kv_cache_mem_manager/allocator.py @@ -3,13 +3,13 @@ from lightllm.utils.dist_utils import get_current_rank_in_node from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.log_utils import init_logger -from typing import Union, List +from typing import Union, List, Optional logger = init_logger(__name__) class KvCacheAllocator: - def __init__(self, size: int) -> None: + def __init__(self, size: int, shared_name: Optional[str] = None) -> None: self.size = size self.mem_state = torch.arange( 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True @@ -26,9 +26,11 @@ def __init__(self, size: int) -> None: rank_in_node = get_current_rank_in_node() # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) + # shared_name 为 None 时使用主 kv 池的默认名(router 调度据此估算);DeepSeek-V4 的压缩子池等 + # 需要各自独立的计数器,传入区别于主池的唯一名,避免多个 allocator 写同一个共享计数器。 + if shared_name is None: + shared_name = f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + self.shared_can_use_token_num = SharedInt(shared_name) self.shared_can_use_token_num.set_value(self.can_use_mem_size) return diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py new file mode 100644 index 0000000000..900b551cec --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py @@ -0,0 +1,203 @@ +import torch +from typing import Dict, List, Optional +from .deepseek2_mem_manager import Deepseek2MemoryManager +from .allocator import KvCacheAllocator +from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class _SubKvPool: + """DeepSeek-V4 压缩分支(c4 / c128)使用的轻量子池。 + + 
一个独立的 KvCacheAllocator + 一块压缩 latent buffer,可选附带一块与 latent 1:1 的 + indexer-K buffer(仅 c4/CSA 层用)。刻意不继承 MemoryManager —— pd/shm/kv_move 等机制 + 对压缩池暂不需要,保持最小。布局与主 MLA latent 池一致(每槽多预留 1 行作 padding 哨兵)。 + """ + + def __init__( + self, + size: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + layer_num: int, + indexer_head_dim: int = 0, + shared_name: Optional[str] = None, + device: str = "cuda", + ): + self.size = size + self.dtype = dtype + self.head_num = head_num + self.head_dim = head_dim + self.layer_num = layer_num + self.indexer_head_dim = indexer_head_dim + + self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device=device) + if indexer_head_dim > 0: + self.index_k_buffer = torch.empty((layer_num, size + 1, indexer_head_dim), dtype=dtype, device=device) + else: + self.index_k_buffer = None + + self.allocator = KvCacheAllocator(size, shared_name=shared_name) + self.HOLD_TOKEN_MEMINDEX = size + + def alloc(self, need_size) -> torch.Tensor: + return self.allocator.alloc(need_size) + + def free(self, free_index) -> None: + self.allocator.free(free_index) + + def free_all(self) -> None: + self.allocator.free_all() + + def get_kv_buffer(self, layer_index: int) -> torch.Tensor: + return self.kv_buffer[layer_index] + + def get_index_k_buffer(self, layer_index: int) -> torch.Tensor: + assert self.index_k_buffer is not None, "this sub pool has no indexer-K buffer" + return self.index_k_buffer[layer_index] + + +class DeepseekV4MemoryManager(Deepseek2MemoryManager): + """DeepSeek-V4 KV 管理(锁定决策: SWA 全历史 + 不分页)。 + + - dense/SWA latent: 继承 Deepseek2 的单张量 MLA latent ``kv_buffer``(每 token 一槽,所有层 + 共享层轴,head_num==1)。SWA 分支靠 layer_infer 传 ``AttControl(use_sliding_window)`` + attn_sink + 读最近窗口;dense 槽为纯 latent,不挂 indexer-K(与 V3.2 区别)。 + - c4_pool / c128_pool: 两个独立 ``_SubKvPool``(window 粒度,1-token 分配)。c4 池附带 indexer-K。 + - 容量: 用闭式 ``get_cell_size()``(= 每个 dense token 在所有池上的总字节)让基类 ``profile_size`` + 直接得到 full_token = dense 池大小,再按 
1/4、1/128 派生压缩池大小。 + - compressor 递归状态不在这里,放 DeepseekV4ReqManager(后续步骤)。 + """ + + # dense 写入沿用 Deepseek2MemOperator(拆 nope/rope);压缩写入算子随 layer_infer 一并补。 + # operator_class 继承自 Deepseek2MemoryManager(= Deepseek2MemOperator)。 + + def __init__( + self, + size, + dtype, + head_num, + head_dim, + layer_num, + compress_rates: List[int], + indexer_head_dim: int = 128, + always_copy=False, + mem_fraction=0.9, + ): + assert head_num == 1, "DeepSeek-V4 是 MLA(MQA),dense latent 的 head_num 必须为 1" + assert ( + len(compress_rates) == layer_num + ), f"compress_rates 长度 {len(compress_rates)} 必须等于 layer_num {layer_num}" + assert all(r in (0, 4, 128) for r in compress_rates), "compress_rates 取值只能是 0/4/128" + + self.compress_rates = list(compress_rates) + self.n_c4 = sum(1 for r in self.compress_rates if r == 4) + self.n_c128 = sum(1 for r in self.compress_rates if r == 128) + self.indexer_head_dim = indexer_head_dim + + # 全局层号 -> 各压缩池内的压实层号(同 qwen3next 的层号压实手法) + self.layer_to_c4_idx: Dict[int, int] = {} + self.layer_to_c128_idx: Dict[int, int] = {} + c4 = c128 = 0 + for lid, r in enumerate(self.compress_rates): + if r == 4: + self.layer_to_c4_idx[lid] = c4 + c4 += 1 + elif r == 128: + self.layer_to_c128_idx[lid] = c128 + c128 += 1 + + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + + def get_cell_size(self): + # 返回“每个 dense(full) token 在所有池上的总字节”。基类 profile_size 用 + # size = available_bytes / get_cell_size(),于是直接得到 full_token = dense 池大小。 + elem = torch._utils._element_size(self.dtype) + latent_bytes = self.head_num * self.head_dim * elem # 每 token 每层 dense latent + dense = latent_bytes * self.layer_num # SWA 全历史: 所有层 + c4 = latent_bytes * self.n_c4 / 4 # c4 压缩 latent + c128 = latent_bytes * self.n_c128 / 128 # c128 压缩 latent + indexer = self.indexer_head_dim * elem * self.n_c4 / 4 # c4 indexer-K + return dense + c4 + c128 + indexer + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + # dense/SWA latent(继承 Deepseek2: 
[layer_num, size+1, head_num, head_dim]) + super()._init_buffers(size, dtype, head_num, head_dim, layer_num) + self._init_compressed_pools(size, dtype, head_num, head_dim) + + def _init_compressed_pools(self, size, dtype, head_num, head_dim): + rank_in_node = get_current_rank_in_node() + server = get_unique_server_name() + + self.c4_size = (size + 4 - 1) // 4 + self.c128_size = (size + 128 - 1) // 128 + + self.c4_pool: Optional[_SubKvPool] = None + self.c128_pool: Optional[_SubKvPool] = None + if self.n_c4 > 0: + self.c4_pool = _SubKvPool( + size=self.c4_size, + dtype=dtype, + head_num=head_num, + head_dim=head_dim, + layer_num=self.n_c4, + indexer_head_dim=self.indexer_head_dim, + shared_name=f"{server}_dsv4_c4_can_use_token_num_{rank_in_node}", + ) + if self.n_c128 > 0: + self.c128_pool = _SubKvPool( + size=self.c128_size, + dtype=dtype, + head_num=head_num, + head_dim=head_dim, + layer_num=self.n_c128, + indexer_head_dim=0, + shared_name=f"{server}_dsv4_c128_can_use_token_num_{rank_in_node}", + ) + + logger.info( + f"DeepseekV4MemoryManager pools: dense={size} " + f"c4={self.c4_size}(L={self.n_c4}) c128={self.c128_size}(L={self.n_c128}) " + f"indexer_head_dim={self.indexer_head_dim}" + ) + + # dense latent 读取沿用父类 get_att_input_params。 + + def _pool_and_local_layer(self, layer_index: int): + r = self.compress_rates[layer_index] + if r == 4: + return self.c4_pool, self.layer_to_c4_idx[layer_index] + if r == 128: + return self.c128_pool, self.layer_to_c128_idx[layer_index] + raise AssertionError(f"layer {layer_index} (rate {r}) 不是压缩层,没有压缩池") + + def get_compressed_kv_buffer(self, layer_index: int) -> torch.Tensor: + pool, local_layer = self._pool_and_local_layer(layer_index) + return pool.get_kv_buffer(local_layer) + + def get_compressed_indexer_k_buffer(self, layer_index: int) -> torch.Tensor: + assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 indexer-K" + return self.c4_pool.get_index_k_buffer(self.layer_to_c4_idx[layer_index]) + + def alloc_c4(self, 
need_size) -> torch.Tensor: + return self.c4_pool.alloc(need_size) + + def alloc_c128(self, need_size) -> torch.Tensor: + return self.c128_pool.alloc(need_size) + + def free_c4(self, free_index) -> None: + self.c4_pool.free(free_index) + + def free_c128(self, free_index) -> None: + self.c128_pool.free(free_index) + + def free_all(self): + super().free_all() + if self.c4_pool is not None: + self.c4_pool.free_all() + if self.c128_pool is not None: + self.c128_pool.free_all() diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 01e9c4ad35..c8197401c1 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -3,7 +3,7 @@ from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig from lightllm.utils.log_utils import init_logger -from .kv_cache_mem_manager import MemoryManager +from .kv_cache_mem_manager import MemoryManager, DeepseekV4MemoryManager from typing import List, Optional, TYPE_CHECKING from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter @@ -299,3 +299,107 @@ def copy_small_page_buffer_to_linear_att_state( self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state self.req_to_ssm_state.buffer[:, dest_req_idx, ...] 
= ssm_state return + + +class DeepseekV4ReqManager(ReqManager): + """DeepSeek-V4 的请求级管理(锁定决策: SWA 全历史 + 不分页)。 + + 在基类 ReqManager 之上补三类 V4 专有的 per-request 结构(均从 mem_manager 读取 n_c4/n_c128/ + layer_to_*_idx/head_dim 等,避免重复配置): + + * ``req_to_c4_indexs`` / ``req_to_c128_indexs`` —— (req, 窗口下标) -> 压缩池槽位。 + 窗口下标 = position // compress_rate;窗口关闭时由 layer-infer 写入,attention 读取前 + n_windows 列即该 req 的全部压缩条目槽。未填充列为 0(不会被读到,语义同 req_to_token_indexs)。 + * ``req_to_c4_state`` / ``req_to_c128_state`` / ``req_to_c4_indexer_state`` —— compressor 的 + “在途窗口”累加状态(per req、per 压缩层),fp32。形状为 + ``(kv_or_score, coff * ratio, coff * dim)``; c4 因 Ca/Cb overlap 取 ``coff=2``, + c128 取 ``coff=1``。score 初始化为 ``-inf``,与官方 reference compressor 的 + ``kv_state``/``score_state`` 对齐。 + * entry_count 不另存:= position // compress_rate,可由序列长度推出。 + """ + + def __init__(self, max_request_num, max_sequence_length, mem_manager: DeepseekV4MemoryManager): + super().__init__(max_request_num, max_sequence_length, mem_manager) + assert isinstance(mem_manager, DeepseekV4MemoryManager) + self.n_c4 = mem_manager.n_c4 + self.n_c128 = mem_manager.n_c128 + head_dim = mem_manager.head_dim + indexer_head_dim = mem_manager.indexer_head_dim + + # (req, 窗口) -> 压缩槽。列数取 ceil(max_seq / ratio) 留足余量。 + c4_windows = (max_sequence_length + 4 - 1) // 4 + c128_windows = (max_sequence_length + 128 - 1) // 128 + self.req_to_c4_indexs = torch.zeros((max_request_num + 1, c4_windows), dtype=torch.int32, device="cuda") + self.req_to_c128_indexs = torch.zeros((max_request_num + 1, c128_windows), dtype=torch.int32, device="cuda") + + # compressor 在途窗口累加状态(fp32): [kv_or_score, coff * ratio, coff * dim]. 
+ state_dtype = torch.float32 + self.req_to_c4_state = LayerCache( + size=max_request_num + 1, + dtype=state_dtype, + shape=(2, 8, 2 * head_dim), + layer_num=self.n_c4, + device="cuda", + ) + self.req_to_c128_state = LayerCache( + size=max_request_num + 1, + dtype=state_dtype, + shape=(2, 128, head_dim), + layer_num=self.n_c128, + device="cuda", + ) + self.req_to_c4_indexer_state = LayerCache( + size=max_request_num + 1, + dtype=state_dtype, + shape=(2, 8, 2 * indexer_head_dim), + layer_num=self.n_c4, + device="cuda", + ) + self._init_all_score_state() + return + + def _init_all_score_state(self): + if self.n_c4 > 0: + self.req_to_c4_state.buffer[:, :, 1, ...].fill_(float("-inf")) + self.req_to_c4_indexer_state.buffer[:, :, 1, ...].fill_(float("-inf")) + if self.n_c128 > 0: + self.req_to_c128_state.buffer[:, :, 1, ...].fill_(float("-inf")) + return + + def _reset_compress_cache_req(self, cache: LayerCache, req_idx: int): + if cache.layer_num == 0: + return + cache.buffer[:, req_idx, 0, ...].fill_(0) + cache.buffer[:, req_idx, 1, ...].fill_(float("-inf")) + return + + def init_compress_state(self, req_idx: int): + """新请求开始时重置其 compressor 在途状态(对应 mamba 的 init_linear_att_state)。""" + if self.n_c4 > 0: + self._reset_compress_cache_req(self.req_to_c4_state, req_idx) + self._reset_compress_cache_req(self.req_to_c4_indexer_state, req_idx) + if self.n_c128 > 0: + self._reset_compress_cache_req(self.req_to_c128_state, req_idx) + return + + def get_c4_compress_state(self, layer_index: int) -> torch.Tensor: + local = self.mem_manager.layer_to_c4_idx[layer_index] + return self.req_to_c4_state.buffer[local] + + def get_c128_compress_state(self, layer_index: int) -> torch.Tensor: + local = self.mem_manager.layer_to_c128_idx[layer_index] + return self.req_to_c128_state.buffer[local] + + def get_c4_indexer_compress_state(self, layer_index: int) -> torch.Tensor: + local = self.mem_manager.layer_to_c4_idx[layer_index] + return self.req_to_c4_indexer_state.buffer[local] + + def 
free(self, free_req_indexes, free_token_index, free_c4_index=None, free_c128_index=None): + """释放 dense 槽(基类)+ 压缩槽。压缩槽由调用方(infer batch)从 req_to_c*_indexs 收集后传入, + 与基类用 free_token_index 传 dense 槽的方式一致。""" + super().free(free_req_indexes, free_token_index) + if free_c4_index is not None and len(free_c4_index) > 0: + self.mem_manager.free_c4(free_c4_index) + if free_c128_index is not None and len(free_c128_index) > 0: + self.mem_manager.free_c128(free_c128_index) + return diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index f619b1d88f..3d376d160d 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -20,6 +20,7 @@ from lightllm.models.phi3.model import Phi3TpPartModel from lightllm.models.deepseek2.model import Deepseek2TpPartModel from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel +from lightllm.models.deepseek_v4.model import DeepseekV4TpPartModel from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel from lightllm.models.internvl.model import ( InternVLLlamaTpPartModel, diff --git a/lightllm/models/deepseek_v4/__init__.py b/lightllm/models/deepseek_v4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek_v4/infer_struct.py b/lightllm/models/deepseek_v4/infer_struct.py new file mode 100644 index 0000000000..6bc402cd28 --- /dev/null +++ b/lightllm/models/deepseek_v4/infer_struct.py @@ -0,0 +1,27 @@ +import torch +from lightllm.common.basemodel import InferStateInfo + + +class DeepseekV4InferStateInfo(InferStateInfo): + """Per-token interleaved-rope cos/sin for the two rope variants (sliding / compressed), following + the gemma4 two-variant convention (_cos_cached_* -> position_cos_*). 
Also exposes the full compressed + cos/sin tables, which the KV compressor indexes at window positions (not per-token).""" + + def __init__(self): + super().__init__() + self.position_cos_sliding = None + self.position_sin_sliding = None + self.position_cos_compress = None + self.position_sin_compress = None + self.cos_compress_table = None + self.sin_compress_table = None + + def init_some_extra_state(self, model): + super().init_some_extra_state(model) # sets position_ids, b_q_seq_len, b_q_start_loc (prefill) + pos = self.position_ids + self.position_cos_sliding = torch.index_select(model._cos_cached_sliding, 0, pos) # [T, rope_dim//2] + self.position_sin_sliding = torch.index_select(model._sin_cached_sliding, 0, pos) + self.position_cos_compress = torch.index_select(model._cos_cached_compress, 0, pos) + self.position_sin_compress = torch.index_select(model._sin_cached_compress, 0, pos) + self.cos_compress_table = model._cos_cached_compress + self.sin_compress_table = model._sin_cached_compress diff --git a/lightllm/models/deepseek_v4/layer_infer/__init__.py b/lightllm/models/deepseek_v4/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek_v4/layer_infer/attention.py b/lightllm/models/deepseek_v4/layer_infer/attention.py new file mode 100644 index 0000000000..a25a2aa3d1 --- /dev/null +++ b/lightllm/models/deepseek_v4/layer_infer/attention.py @@ -0,0 +1,55 @@ +import torch +import torch.nn.functional as F + +# DeepSeek-V4 attention: MLA with a single shared KV head (head_dim=512), per-head learnable attention +# sink, and a candidate set = sliding-window tokens (size `window`) ++ compressed KV entries. Pure-torch +# transcription of the bundled reference (inference/model.py Attention.forward + kernel.py sparse_attn). +# Correctness-first prefill path. head_dim=512 > 256 so FlashAttention is unusable anyway; a fused +# triton sparse-gather kernel is a perf follow-up. 
+ + +def torch_sparse_attn(q, kv, attn_sink, topk_idxs, scale): + """Gather-then-softmax attention with a per-head sink, matching reference kernel.sparse_attn. + + q:[b,m,h,d], kv:[b,n,d] (single KV head shared over h), attn_sink:[h] (fp32), + topk_idxs:[b,m,K] int (-1 = invalid/skip). Returns o:[b,m,h,d]. + """ + b, m, h, d = q.shape + n = kv.shape[1] + K = topk_idxs.shape[-1] + idx = topk_idxs.clamp(min=0).long() # [b,m,K] + keys = torch.gather(kv.unsqueeze(1).expand(b, m, n, d), 2, idx.unsqueeze(-1).expand(b, m, K, d)) # [b,m,K,d] + qf, kf = q.float(), keys.float() + scores = torch.einsum("bmhd,bmkd->bmhk", qf, kf) * scale # [b,m,h,K] + valid = (topk_idxs != -1).unsqueeze(2) # [b,m,1,K] + scores = scores.masked_fill(~valid, float("-inf")) + mx = scores.amax(dim=-1, keepdim=True) # [b,m,h,1] + mx = torch.nan_to_num(mx, neginf=0.0) + ex = (scores - mx).exp() # [b,m,h,K] + denom = ex.sum(-1) + (attn_sink.view(1, 1, h) - mx.squeeze(-1)).exp() # [b,m,h] + o = torch.einsum("bmhk,bmkd->bmhd", ex, kf) / denom.unsqueeze(-1) + return o.to(q.dtype) + + +def build_prefill_topk_idxs(seqlen, window, ratio, n_window, device): + """Per-query candidate indices into [window_kv (n_window tokens) ++ compressed_kv (ncomp entries)]. + + Returns int32 [seqlen, window + ncomp] with -1 for invalid. Window part indexes the per-token KV + (here stored as tokens 0..seqlen-1, so n_window == seqlen); compressed part is offset by n_window. + For prompts where ncomp <= index_topk the indexer is a no-op, so all causally-valid compressed + entries are attended (matches the reference for short context). + """ + t = torch.arange(seqlen, device=device) + # sliding window: query t attends tokens [max(0, t-window+1) .. 
t] + j = torch.arange(n_window, device=device) + win = j.unsqueeze(0).expand(seqlen, n_window).clone() # [s, n_window] + win_valid = (j.unsqueeze(0) <= t.unsqueeze(1)) & (j.unsqueeze(0) > (t.unsqueeze(1) - window)) + win = torch.where(win_valid, win, torch.full_like(win, -1)) + if ratio: + ncomp = seqlen // ratio + c = torch.arange(ncomp, device=device) + comp_valid = c.unsqueeze(0) < ((t.unsqueeze(1) + 1) // ratio) # [s, ncomp] + comp_idx = (c.unsqueeze(0) + n_window).expand(seqlen, ncomp) + comp = torch.where(comp_valid, comp_idx, torch.full((seqlen, ncomp), -1, device=device, dtype=torch.long)) + return torch.cat([win, comp], dim=1).int() + return win.int() diff --git a/lightllm/models/deepseek_v4/layer_infer/compressor.py b/lightllm/models/deepseek_v4/layer_infer/compressor.py new file mode 100644 index 0000000000..902de113db --- /dev/null +++ b/lightllm/models/deepseek_v4/layer_infer/compressor.py @@ -0,0 +1,156 @@ +import torch +import torch.nn.functional as F +from ..triton_kernel.rotary_emb import apply_rotary_emb + +# KV compressor: pools every `ratio` consecutive tokens into one compressed KV entry via gated +# (softmax) pooling + a learned absolute-position bias (ape), RMSNorm, and rope on the trailing +# rope_dim. ratio==4 uses overlapping windows (two-series Ca/Cb scheme). Pure-torch transcription of +# the bundled reference inference/model.py Compressor.forward for the prefill (start_pos==0) path. +# NOTE: the reference also applies an FP8/FP4 QAT activation sim to the compressed entry; omitted here +# for the correctness-first prefill path (negligible vs argmax; revisit if e2e diverges). 
+ + +def _overlap_transform(tensor, ratio, d, value): + # tensor: [nwin, ratio, 2*d] -> [nwin, 2*ratio, d]; slots [ratio:]=Cb(current), [:ratio]=Ca(previous window) + nwin = tensor.shape[0] + out = tensor.new_full((nwin, 2 * ratio, d), value) + out[:, ratio:] = tensor[:, :, d:] + out[1:, :ratio] = tensor[:-1, :, :d] + return out + + +def _rmsnorm(x, weight, eps): + xf = x.float() + xf = xf * torch.rsqrt(xf.square().mean(-1, keepdim=True) + eps) + return (xf * weight.float()).to(x.dtype) + + +def compress_prefill(x, wkv_w, wgate_w, norm_w, ape, ratio, head_dim, rope_dim, cos_table, sin_table, eps): + """x:[s,dim] (one request, start_pos=0) -> compressed kv [nwin, head_dim] (rope applied to last rope_dim). + + nwin = s // ratio (remainder tokens are decode-state, handled in the decode path). wkv_w/wgate_w: + [coff*head_dim, dim]; norm_w:[head_dim]; ape:[ratio, coff*head_dim]; cos_table/sin_table: compress rope tables. + """ + overlap = ratio == 4 + coff = 2 if overlap else 1 + d = head_dim + s = x.shape[0] + nwin = s // ratio + if nwin == 0: + # fewer than `ratio` tokens -> no completed window -> no compressed entry (matches reference) + return x.new_zeros(0, head_dim) + cutoff = nwin * ratio + xf = x.float() + kv = F.linear(xf, wkv_w.float())[:cutoff].view(nwin, ratio, coff * d) + score = F.linear(xf, wgate_w.float())[:cutoff].view(nwin, ratio, coff * d) + ape.float() + if overlap: + kv = _overlap_transform(kv, ratio, d, 0.0) + score = _overlap_transform(score, ratio, d, float("-inf")) + kv = (kv * torch.softmax(score, dim=1)).sum(dim=1) # [nwin, d] fp32 + kv = _rmsnorm(kv.to(x.dtype), norm_w, eps) # [nwin, d] + pos = torch.arange(nwin, device=x.device) * ratio + kv_rope = apply_rotary_emb(kv[:, -rope_dim:], cos_table[pos], sin_table[pos]) # cos/sin: [nwin, rope_dim//2] + return torch.cat([kv[:, :-rope_dim], kv_rope], dim=1) + + +def new_compressor_state(ratio, head_dim, device, dtype=torch.float32): + """Per-request compressor running state (matches reference 
Compressor.kv_state/score_state).""" + coff = 2 if ratio == 4 else 1 + kv_state = torch.zeros(coff * ratio, coff * head_dim, device=device, dtype=dtype) + score_state = torch.full((coff * ratio, coff * head_dim), float("-inf"), device=device, dtype=dtype) + return kv_state, score_state + + +def _finish_entry(kv, norm_w, ape_unused, rope_dim, cos_table, sin_table, position, eps, dtype): + kv = _rmsnorm(kv.to(dtype), norm_w, eps) # [d] + cos = cos_table[position : position + 1] # [1, rope_dim//2] + sin = sin_table[position : position + 1] + kv_rope = apply_rotary_emb(kv[-rope_dim:].unsqueeze(0), cos, sin)[0] + return torch.cat([kv[:-rope_dim], kv_rope], dim=0) + + +def compressor_prefill_state(x, wkv_w, wgate_w, norm_w, ape, ratio, head_dim, rope_dim, cos_table, sin_table, eps): + """Faithful reference start_pos==0 path (incl. remainder). Returns (entries[ncomp,d], kv_state, score_state). + + entries have rope applied; kv_state/score_state carry the partial window for the decode path. + """ + overlap = ratio == 4 + coff = 2 if overlap else 1 + d = head_dim + s = x.shape[0] + dtype = x.dtype + xf = x.float() + kv = F.linear(xf, wkv_w.float()) # [s, coff*d] + score = F.linear(xf, wgate_w.float()) # [s, coff*d] + ape = ape.float() + kv_state, score_state = new_compressor_state(ratio, head_dim, x.device) + should_compress = s >= ratio + remainder = s % ratio + cutoff = s - remainder + offset = ratio if overlap else 0 + if overlap and cutoff >= ratio: + kv_state[:ratio] = kv[cutoff - ratio : cutoff] + score_state[:ratio] = score[cutoff - ratio : cutoff] + ape + if remainder > 0: + kv_state[offset : offset + remainder] = kv[cutoff:] + score_state[offset : offset + remainder] = score[cutoff:] + ape[:remainder] + kv = kv[:cutoff] + score = score[:cutoff] + if not should_compress: + return x.new_zeros(0, head_dim), kv_state, score_state + nwin = cutoff // ratio + kvw = kv.view(nwin, ratio, coff * d) + scw = score.view(nwin, ratio, coff * d) + ape + if overlap: + kvw = 
_overlap_transform(kvw, ratio, d, 0.0) + scw = _overlap_transform(scw, ratio, d, float("-inf")) + comp = (kvw * torch.softmax(scw, dim=1)).sum(dim=1) # [nwin, d] fp32 + comp = _rmsnorm(comp.to(dtype), norm_w, eps) + pos = torch.arange(nwin, device=x.device) * ratio + comp_rope = apply_rotary_emb(comp[:, -rope_dim:], cos_table[pos], sin_table[pos]) + comp = torch.cat([comp[:, :-rope_dim], comp_rope], dim=1) + return comp, kv_state, score_state + + +def compressor_decode_step( + x_new, + wkv_w, + wgate_w, + norm_w, + ape, + ratio, + head_dim, + rope_dim, + cos_table, + sin_table, + eps, + kv_state, + score_state, + start_pos, +): + """Faithful reference start_pos>0 path for one new token. Mutates kv_state/score_state in place. + Returns the new compressed entry [d] (rope applied) when a window completes, else None.""" + overlap = ratio == 4 + d = head_dim + dtype = x_new.dtype + xf = x_new.float().view(-1) # [dim] + kv = F.linear(xf, wkv_w.float()) # [coff*d] + score = F.linear(xf, wgate_w.float()) + ape.float()[start_pos % ratio] # [coff*d] + should_compress = (start_pos + 1) % ratio == 0 + if overlap: + kv_state[ratio + start_pos % ratio] = kv + score_state[ratio + start_pos % ratio] = score + if should_compress: + kv_cat = torch.cat([kv_state[:ratio, :d], kv_state[ratio:, d:]], dim=0) # [2*ratio, d] + sc_cat = torch.cat([score_state[:ratio, :d], score_state[ratio:, d:]], dim=0) + entry = (kv_cat * torch.softmax(sc_cat, dim=0)).sum(dim=0) # [d] + kv_state[:ratio] = kv_state[ratio:] + score_state[:ratio] = score_state[ratio:] + else: + kv_state[start_pos % ratio] = kv + score_state[start_pos % ratio] = score + if should_compress: + entry = (kv_state * torch.softmax(score_state, dim=0)).sum(dim=0) # [d] + if not should_compress: + return None + return _finish_entry(entry, norm_w, ape, rope_dim, cos_table, sin_table, start_pos + 1 - ratio, eps, dtype) diff --git a/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py 
b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py new file mode 100644 index 0000000000..75f540725b --- /dev/null +++ b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py @@ -0,0 +1,58 @@ +import torch +import torch.nn.functional as F + +# Manifold-constrained Hyper-Connections (mHC). Replaces the plain residual add: the hidden state is +# carried as ``hc_mult`` parallel streams. Each sub-layer (attn / ffn) collapses the streams to one +# vector (hc_pre), runs the sub-layer, then re-expands into the streams via learned post/comb weights +# (hc_post). A doubly-stochastic (Sinkhorn-normalized) ``comb`` matrix mixes the residual streams. +# Pure-torch transcription of the bundled reference inference/model.py (Block.hc_pre/hc_post, +# ParallelHead.hc_head) + inference/kernel.py (hc_split_sinkhorn). All math in fp32, as in the reference. + + +def hc_split_sinkhorn(mixes, hc_scale, hc_base, hc_mult, sinkhorn_iters, eps): + """mixes:[N, (2+hc)*hc] fp32 -> pre[N,hc], post[N,hc], comb[N,hc,hc] (doubly stochastic).""" + hc = hc_mult + pre = torch.sigmoid(mixes[:, :hc] * hc_scale[0] + hc_base[:hc]) + eps + post = 2.0 * torch.sigmoid(mixes[:, hc : 2 * hc] * hc_scale[1] + hc_base[hc : 2 * hc]) + comb = mixes[:, 2 * hc :].view(-1, hc, hc) * hc_scale[2] + hc_base[2 * hc :].view(hc, hc) + # comb = softmax(comb, dim=-1) + eps + comb = torch.softmax(comb, dim=-1) + eps + # one column normalization, then (iters-1) of (row, column) + comb = comb / (comb.sum(dim=-2, keepdim=True) + eps) + for _ in range(sinkhorn_iters - 1): + comb = comb / (comb.sum(dim=-1, keepdim=True) + eps) + comb = comb / (comb.sum(dim=-2, keepdim=True) + eps) + return pre, post, comb + + +def hc_pre(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, eps, sinkhorn_iters): + """streams:[N, hc*dim] -> (collapsed[N,dim], post[N,hc], comb[N,hc,hc]).""" + dtype = streams.dtype + x = streams.float() # [N, hc*dim] + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + eps) + mixes = F.linear(x, 
hc_fn) * rsqrt # [N, (2+hc)*hc] + pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, hc_mult, sinkhorn_iters, eps) + streams3 = x.view(-1, hc_mult, dim) + collapsed = torch.sum(pre.unsqueeze(-1) * streams3, dim=1) # [N, dim] + return collapsed.to(dtype), post, comb + + +def hc_post(x, residual, post, comb, hc_mult, dim): + """x:[N,dim] sub-layer output, residual:[N, hc*dim] -> [N, hc*dim].""" + res = residual.float().view(-1, hc_mult, dim) # [N, hc, dim] + xf = x.float() + # post: [N,hc] -> [N,hc,dim]; comb mixes residual streams: out[i] = post[i]*x + sum_j comb[i,j]*res[j] + y = post.unsqueeze(-1) * xf.unsqueeze(-2) + torch.einsum("nij,njd->nid", comb, res) + return y.reshape(-1, hc_mult * dim).to(x.dtype) + + +def hc_head(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, eps): + """Final stream collapse before the lm_head. streams:[N, hc*dim] -> [N, dim] (sigmoid gate, no sinkhorn).""" + dtype = streams.dtype + x = streams.float() + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + eps) + mixes = F.linear(x, hc_fn) * rsqrt # [N, hc] + pre = torch.sigmoid(mixes * hc_scale + hc_base) + eps # [N, hc] + streams3 = x.view(-1, hc_mult, dim) + collapsed = torch.sum(pre.unsqueeze(-1) * streams3, dim=1) + return collapsed.to(dtype) diff --git a/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py new file mode 100644 index 0000000000..87951e7360 --- /dev/null +++ b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py @@ -0,0 +1,19 @@ +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from .hyper_connection import hc_head + + +class DeepseekV4PostLayerInfer(LlamaPostLayerInfer): + """Collapse the hc_mult residual streams (hc_head) to [T, hidden], then final norm + lm_head.""" + + def token_forward(self, input_embdings, infer_state, layer_weight): + cfg = layer_weight.network_config_ + collapsed = hc_head( + input_embdings, + 
class DeepseekV4PreLayerInfer(LlamaPreLayerInfer):
    """Embed the tokens, then replicate each embedding into ``hc_mult`` residual streams.

    Both forward entry points return a [T, hc_mult*hidden] tensor whose rows are the token
    embedding repeated ``hc_mult`` times back to back.
    """

    def _embed_and_expand(self, input_ids, infer_state, layer_weight):
        tokens = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor)  # [T, hidden]
        if self.tp_world_size_ > 1:
            # Sum the per-rank embedding results across the tensor-parallel group.
            all_reduce(tokens, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        n_streams = layer_weight.network_config_["hc_mult"]
        _n_tok, _width = tokens.shape  # keeps the requirement that the embedding is 2-D
        # Tile along the feature axis: every stream starts as a copy of the embedding.
        return tokens.repeat(1, n_streams)

    def context_forward(self, input_ids, infer_state, layer_weight):
        """Prefill entry point; identical to decode for this layer."""
        return self._embed_and_expand(input_ids, infer_state, layer_weight)

    def token_forward(self, input_ids, infer_state, layer_weight):
        """Decode entry point; identical to prefill for this layer."""
        return self._embed_and_expand(input_ids, infer_state, layer_weight)
lightllm.common.basemodel import TransformerLayerInferTpl +from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.envs_utils import get_env_start_args +from .hyper_connection import hc_pre, hc_post +from ..triton_kernel.rotary_emb import apply_rotary_emb +from .compressor import compressor_prefill_state, compressor_decode_step +from .attention import torch_sparse_attn +from ..triton_kernel.quant_convert import dequant_fp4_group_to_bf16 + + +class DeepseekV4TransformerLayerInfer(TransformerLayerInferTpl): + """One V4 decoder layer: HC(attn) then HC(ffn). Correctness-first pure-torch. + + The residual is carried as ``hc_mult`` streams flattened to [T, hc_mult*hidden]; each sub-layer + collapses (hc_pre), computes, and re-expands (hc_post). Attention is MLA over a sliding window + + compressed KV with a per-head sink (torch_sparse_attn); the MoE reuses lightllm's deepgemm FP8 + grouped GEMM driven by V4's custom router (sqrtsoftplus + hash/topk + bias-for-selection). + + Per-request decode state (window KV history + compressed KV + compressor running state) is kept in + a dict keyed by request id. NOTE: correctness-first — this should move into the KV mem manager for + production memory management / request eviction. 
def _hc_block(self, streams, infer_state, lw, attn_fn):
    """Run one decoder layer: hyper-connection-wrapped attention, then the MoE FFN.

    streams: [T, hc_mult*hidden] residual streams. ``attn_fn`` is the attention
    implementation to use (prefill or decode) and is called as
    ``attn_fn(normed, infer_state, lw)``. Returns the updated streams, same layout.
    """
    # (mixing projection, scale, base, pre-norm, sub-layer) for each of the two stages.
    stages = (
        (lw.hc_attn_fn_, lw.hc_attn_scale_, lw.hc_attn_base_, lw.attn_norm_, attn_fn),
        (lw.hc_ffn_fn_, lw.hc_ffn_scale_, lw.hc_ffn_base_, lw.ffn_norm_, self._moe_ffn),
    )
    for mix_fn, mix_scale, mix_base, pre_norm, sublayer in stages:
        # Collapse the streams to a single [T, hidden] vector for the sub-layer.
        collapsed, post, comb = hc_pre(
            streams,
            mix_fn.weight,
            mix_scale.weight,
            mix_base.weight,
            self.hc_mult,
            self.hidden,
            self.hc_eps,
            self.sinkhorn_iters,
        )
        sub_out = sublayer(pre_norm(collapsed, eps=self.eps_), infer_state, lw)
        # The streams that entered this stage act as the residual for the re-expansion.
        streams = hc_post(sub_out, streams, post, comb, self.hc_mult, self.hidden)
    return streams
def context_forward(self, streams, infer_state, lw): + return self._hc_block(streams, infer_state, lw, self._attention_prefill) + + def token_forward(self, streams, infer_state, lw): + return self._hc_block(streams, infer_state, lw, self._attention_decode) + + # ------------------------------------------------------------------ shared projections + def _qkv(self, x, cos_tok, sin_tok, lw): + T = x.shape[0] + qa = lw.q_norm_(lw.wq_a_.mm(x), eps=self.eps_) + q = lw.wq_b_.mm(qa).view(T, self.tp_q_heads, self.head_dim).float() + q = (q * torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps_)).to(x.dtype) + q = torch.cat( + [ + q[..., : -self.rope_dim], + apply_rotary_emb(q[..., -self.rope_dim :], cos_tok.unsqueeze(1), sin_tok.unsqueeze(1)), + ], + dim=-1, + ) + kv = lw.kv_norm_(lw.wkv_.mm(x), eps=self.eps_) + kv = torch.cat([kv[:, : -self.rope_dim], apply_rotary_emb(kv[:, -self.rope_dim :], cos_tok, sin_tok)], dim=1) + return q, kv + + def _out_proj(self, o, infer_state, lw): + # o: [T, tp_q_heads, head_dim] -> inverse rope -> grouped low-rank O -> [T, hidden] + T = o.shape[0] + o = o.reshape(T, self.tp_groups, -1).transpose(0, 1).contiguous() # [groups, T, per_group_in] + o = lw.wo_a_.bmm(o).transpose(0, 1).reshape(T, -1) # [T, groups*o_lora] + o = lw.wo_b_.mm(o) + if self.tp_world_size_ > 1: + all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def _inv_rope(self, o, cos_tok, sin_tok): + return torch.cat( + [ + o[..., : -self.rope_dim], + apply_rotary_emb(o[..., -self.rope_dim :], cos_tok.unsqueeze(1), sin_tok.unsqueeze(1), inverse=True), + ], + dim=-1, + ) + + # ------------------------------------------------------------------ attention (prefill) + def _attention_prefill(self, x, infer_state, lw): + T = x.shape[0] + if self.compress_ratio: + cos_tok, sin_tok = infer_state.position_cos_compress, infer_state.position_sin_compress + else: + cos_tok, sin_tok = infer_state.position_cos_sliding, 
infer_state.position_sin_sliding + q, kv = self._qkv(x, cos_tok, sin_tok, lw) + sink = lw.attn_sink_.weight + o = x.new_empty(T, self.tp_q_heads, self.head_dim) + b_req = infer_state.b_req_idx.tolist() + starts = infer_state.b_q_start_loc.tolist() + lens = infer_state.b_q_seq_len.tolist() + for req, st, ln in zip(b_req, starts, lens): + q_r, kv_r, x_r = q[st : st + ln], kv[st : st + ln], x[st : st + ln] + kv_all, n_window, ncomp = self._gather_prefill(x_r, kv_r, req, lw, infer_state) + ti = self._topk_idxs_prefill(ln, n_window, ncomp, x.device) + o[st : st + ln] = torch_sparse_attn(q_r.unsqueeze(0), kv_all.unsqueeze(0), sink, ti, self.softmax_scale)[0] + return self._out_proj(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw) + + def _gather_prefill(self, x_r, kv_r, req, lw, infer_state): + ln = kv_r.shape[0] + if self.compress_ratio: + comp, ks, ss = compressor_prefill_state( + x_r, + lw.compressor_wkv_.mm_param.weight, + lw.compressor_wgate_.mm_param.weight, + lw.compressor_norm_.weight, + lw.compressor_ape_.weight, + self.compress_ratio, + self.head_dim, + self.rope_dim, + infer_state.cos_compress_table, + infer_state.sin_compress_table, + self.eps_, + ) + self._state[req] = {"kv_hist": kv_r.detach(), "comp_kv": comp.detach(), "cstate_kv": ks, "cstate_score": ss} + return torch.cat([kv_r, comp], dim=0), ln, comp.shape[0] + self._state[req] = {"kv_hist": kv_r.detach()} + return kv_r, ln, 0 + + def _topk_idxs_prefill(self, seqlen, n_window, ncomp, device): + t = torch.arange(seqlen, device=device) + j = torch.arange(n_window, device=device) + win = torch.where( + (j.unsqueeze(0) <= t.unsqueeze(1)) & (j.unsqueeze(0) > (t.unsqueeze(1) - self.window)), + j.unsqueeze(0).expand(seqlen, n_window), + torch.full((seqlen, n_window), -1, device=device, dtype=torch.long), + ) + if ncomp: + c = torch.arange(ncomp, device=device) + comp = torch.where( + c.unsqueeze(0) < ((t.unsqueeze(1) + 1) // self.compress_ratio), + (c.unsqueeze(0) + n_window).expand(seqlen, ncomp), + 
def _attention_decode(self, x, infer_state, lw):
    """Attention for one decode step (one new token per request).

    x: [B, hidden] normalized hidden states. For every request the new KV row is appended
    to its stored history, the compressor state is advanced on compressing layers, and the
    query attends to the sliding window plus all compressed rows of that request.
    Returns the output-projected attention result, [B, hidden].

    Raises KeyError if a request has no entry in ``self._state`` (i.e. it was never
    prefilled through this layer).
    """
    B = x.shape[0]  # one new token per request
    if self.compress_ratio:
        cos_tok, sin_tok = infer_state.position_cos_compress, infer_state.position_sin_compress
    else:
        cos_tok, sin_tok = infer_state.position_cos_sliding, infer_state.position_sin_sliding
    q, kv = self._qkv(x, cos_tok, sin_tok, lw)  # [B, heads, hd], [B, hd]
    sink = lw.attn_sink_.weight
    b_req = infer_state.b_req_idx.tolist()
    seqlens = infer_state.b_seq_len.tolist()
    o = x.new_empty(B, self.tp_q_heads, self.head_dim)
    for i, (req, seq) in enumerate(zip(b_req, seqlens)):
        stt = self._state[req]
        # Decode only ever reads the trailing ``window`` rows of the history, so store just
        # those. Previously the history grew by one row per generated token for the whole
        # lifetime of the request. The slice expression is the one the read used before, so
        # the attended rows are unchanged.
        win_kv = torch.cat([stt["kv_hist"], kv[i : i + 1]], dim=0)[-self.window :]
        stt["kv_hist"] = win_kv
        if self.compress_ratio:
            # Position of the new token within its request.
            start_pos = seq - 1
            e = compressor_decode_step(
                x[i],
                lw.compressor_wkv_.mm_param.weight,
                lw.compressor_wgate_.mm_param.weight,
                lw.compressor_norm_.weight,
                lw.compressor_ape_.weight,
                self.compress_ratio,
                self.head_dim,
                self.rope_dim,
                infer_state.cos_compress_table,
                infer_state.sin_compress_table,
                self.eps_,
                stt["cstate_kv"],
                stt["cstate_score"],
                start_pos,
            )
            # The compressor returns None on steps that do not produce a new compressed row.
            if e is not None:
                stt["comp_kv"] = torch.cat([stt["comp_kv"], e.unsqueeze(0)], dim=0)
            kv_all = torch.cat([win_kv, stt["comp_kv"]], dim=0)
        else:
            kv_all = win_kv
        # Every gathered row is visible to the single query token.
        ti = torch.arange(kv_all.shape[0], device=x.device).view(1, 1, -1).int()
        o[i] = torch_sparse_attn(
            q[i].view(1, 1, self.tp_q_heads, self.head_dim), kv_all.unsqueeze(0), sink, ti, self.softmax_scale
        )[0, 0]
    return self._out_proj(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw)
_fp4_experts(self, x, weights, indices, lw): + experts = lw.experts_ + out = torch.zeros(x.shape, device=x.device, dtype=torch.float32) + counts = torch.bincount(indices.reshape(-1), minlength=experts.n_routed_experts) + for expert_id in torch.nonzero(counts, as_tuple=False).flatten().tolist(): + token_idx, top_idx = torch.where(indices == expert_id) + if token_idx.numel() == 0: + continue + x_i = x[token_idx] + w1 = dequant_fp4_group_to_bf16(experts.w1[expert_id], experts.w1_scale[expert_id]) + w3 = dequant_fp4_group_to_bf16(experts.w3[expert_id], experts.w3_scale[expert_id]) + gate = F.linear(x_i, w1).float().clamp(max=self.swiglu_limit) + up = F.linear(x_i, w3).float().clamp(min=-self.swiglu_limit, max=self.swiglu_limit) + hidden = F.silu(gate) * up + hidden.mul_(weights[token_idx, top_idx].unsqueeze(-1)) + w2 = dequant_fp4_group_to_bf16(experts.w2[expert_id], experts.w2_scale[expert_id]) + out.index_add_(0, token_idx, F.linear(hidden.to(x.dtype), w2).float()) + return out.to(x.dtype) + + def _moe_ffn(self, x, infer_state, lw): + gw = lw.gate_weight_.mm_param.weight + scores = F.softplus(F.linear(x.float(), gw.float())).sqrt() # sqrtsoftplus + if self.is_hash: + indices = lw.gate_tid2eid_.weight[infer_state.input_ids.long()] + else: + indices = (scores + lw.gate_bias_.weight.unsqueeze(0)).topk(self.topk, dim=-1)[1] + weights = scores.gather(1, indices) + weights = (weights / (weights.sum(-1, keepdim=True) + 1e-20) * self.route_scale).to(torch.float32) + routed = self._fp4_experts(x, weights, indices.long(), lw) + g = lw.shared_gate_.mm(x).float().clamp(max=self.swiglu_limit) + u = lw.shared_up_.mm(x).float().clamp(min=-self.swiglu_limit, max=self.swiglu_limit) + shared = lw.shared_down_.mm((F.silu(g) * u).to(x.dtype)) + if self.enable_ep_moe and getattr(lw.experts_, "is_ep", False): + if self.tp_world_size_ > 1: + all_reduce(shared, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return routed + shared + out = routed + shared + if 
self.tp_world_size_ > 1: + all_reduce(out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return out diff --git a/lightllm/models/deepseek_v4/layer_weights/__init__.py b/lightllm/models/deepseek_v4/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..54f29ce574 --- /dev/null +++ b/lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,37 @@ +import torch +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + RMSNormWeight, + ParameterWeight, +) + + +class DeepseekV4PreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + + hidden = network_config["hidden_size"] + vocab = network_config["vocab_size"] + hc_mult = network_config["hc_mult"] + + # embeddings / lm_head / final norm (bf16, vocab tensor-parallel). V4 has no `model.` prefix + # and does not tie embeddings (tie_word_embeddings=false). 
+ self.wte_weight_ = EmbeddingWeight( + dim=hidden, vocab_size=vocab, weight_name="embed.weight", data_type=self.data_type_ + ) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden, vocab_size=vocab, weight_name="head.weight", data_type=self.data_type_ + ) + self.final_norm_weight_ = RMSNormWeight(dim=hidden, weight_name="norm.weight", data_type=self.data_type_) + + # final hyper-connection head (collapses the hc_mult residual streams before the lm_head) + self.hc_head_fn_ = ParameterWeight( + weight_name="hc_head_fn", data_type=torch.float32, weight_shape=(hc_mult, hc_mult * hidden) + ) + self.hc_head_base_ = ParameterWeight( + weight_name="hc_head_base", data_type=torch.float32, weight_shape=(hc_mult,) + ) + self.hc_head_scale_ = ParameterWeight(weight_name="hc_head_scale", data_type=torch.float32, weight_shape=(1,)) + return diff --git a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..7c12f714db --- /dev/null +++ b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py @@ -0,0 +1,398 @@ +import torch +from lightllm.common.basemodel import TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + ROWBMMWeight, + RMSNormWeight, + ParameterWeight, + TpAttSinkWeight, +) +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl +from lightllm.common.quantization.registry import QUANTMETHODS +from ..triton_kernel.quant_convert import dequant_fp8_block_to_bf16 + + +class DeepseekV4FP4ExpertsWeight(BaseWeightTpl): + def __init__(self, weight_prefix, n_routed_experts, hidden_size, moe_intermediate_size, data_type): + super().__init__(data_type=data_type) + self.weight_prefix = weight_prefix + self.n_routed_experts = n_routed_experts + self.hidden_size = hidden_size + self.moe_intermediate_size = moe_intermediate_size + 
self.split_inter_size = moe_intermediate_size // self.tp_world_size_ + self.local_expert_ids = list(range(n_routed_experts)) + self.expert_idx_to_local_idx = {expert_idx: expert_idx for expert_idx in self.local_expert_ids} + self._create_weight() + + def _create_weight(self): + device = f"cuda:{self.device_id_}" + n = self.n_routed_experts + h = self.hidden_size + inter = self.split_inter_size + self.w1 = torch.empty((n, inter, h // 2), dtype=torch.int8, device=device) + self.w3 = torch.empty((n, inter, h // 2), dtype=torch.int8, device=device) + self.w2 = torch.empty((n, h, inter // 2), dtype=torch.int8, device=device) + self.w1_scale = torch.empty((n, inter, h // 32), dtype=torch.float8_e8m0fnu, device=device) + self.w3_scale = torch.empty((n, inter, h // 32), dtype=torch.float8_e8m0fnu, device=device) + self.w2_scale = torch.empty((n, h, inter // 32), dtype=torch.float8_e8m0fnu, device=device) + self.load_ok = { + name: [False] * n + for name in ("w1", "w1_scale", "w2", "w2_scale", "w3", "w3_scale") + } + + def _copy_expert_weight(self, dst, weight, expert_idx, name, is_down=False): + if is_down: + start = self.tp_rank_ * self.split_inter_size // 2 + end = (self.tp_rank_ + 1) * self.split_inter_size // 2 + src = weight[:, start:end] + else: + start = self.tp_rank_ * self.split_inter_size + end = (self.tp_rank_ + 1) * self.split_inter_size + src = weight[start:end, :] + dst[expert_idx].copy_(src) + self.load_ok[name][expert_idx] = True + + def _copy_expert_scale(self, dst, scale, expert_idx, name, is_down=False): + if is_down: + start = self.tp_rank_ * self.split_inter_size // 32 + end = (self.tp_rank_ + 1) * self.split_inter_size // 32 + src = scale[:, start:end] + else: + start = self.tp_rank_ * self.split_inter_size + end = (self.tp_rank_ + 1) * self.split_inter_size + src = scale[start:end, :] + dst[expert_idx].copy_(src) + self.load_ok[name][expert_idx] = True + + def load_hf_weights(self, weights): + for expert_idx in self.local_expert_ids: + prefix = 
f"{self.weight_prefix}.{expert_idx}" + w1 = f"{prefix}.w1.weight" + w1_scale = f"{prefix}.w1.scale" + w2 = f"{prefix}.w2.weight" + w2_scale = f"{prefix}.w2.scale" + w3 = f"{prefix}.w3.weight" + w3_scale = f"{prefix}.w3.scale" + if w1 in weights: + self._copy_expert_weight(self.w1, weights[w1], expert_idx, "w1") + if w1_scale in weights: + self._copy_expert_scale(self.w1_scale, weights[w1_scale], expert_idx, "w1_scale") + if w3 in weights: + self._copy_expert_weight(self.w3, weights[w3], expert_idx, "w3") + if w3_scale in weights: + self._copy_expert_scale(self.w3_scale, weights[w3_scale], expert_idx, "w3_scale") + if w2 in weights: + self._copy_expert_weight(self.w2, weights[w2], expert_idx, "w2", is_down=True) + if w2_scale in weights: + self._copy_expert_scale(self.w2_scale, weights[w2_scale], expert_idx, "w2_scale", is_down=True) + + def verify_load(self): + return all(all(ok_list) for ok_list in self.load_ok.values()) + + +class DeepseekV4TransformerLayerWeight(TransformerLayerWeight): + """Per-layer weights for DeepSeek-V4-Flash. + + The checkpoint stores most linears in FP8 (e4m3 + block-128 ue8m0 scale) and the routed + experts in FP4 (int8-packed e2m1 + group-32 ue8m0 scale). Hopper does not use the SM100 + MegaMoE path here, so routed experts are kept in packed FP4 and temporarily de-quantized only + for selected experts in the correctness-first torch MoE path. 
def _parse_config(self):
    """Copy the hyper-parameters this layer needs out of ``network_config_``.

    Also derives the per-layer flags (compressor / indexer / hash routing), checks that the
    head and group counts divide evenly across tensor-parallel ranks, and sets the
    checkpoint key prefix ``layers.<layer_num>``.
    """
    cfg = self.network_config_
    layer_idx = self.layer_num_
    tp = self.tp_world_size_

    self.fp8_quant = QUANTMETHODS.get("deepgemm-fp8w8a8-b128")

    # attribute name -> config key, read in this order
    plain_fields = (
        ("hidden", "hidden_size"),
        ("n_heads", "num_attention_heads"),
        ("head_dim", "head_dim"),
        ("rope_dim", "qk_rope_head_dim"),
        ("q_lora_rank", "q_lora_rank"),
        ("o_lora_rank", "o_lora_rank"),
        ("o_groups", "o_groups"),
        ("index_n_heads", "index_n_heads"),
        ("index_head_dim", "index_head_dim"),
        ("n_routed_experts", "n_routed_experts"),
        ("moe_inter", "moe_intermediate_size"),
        ("num_hash_layers", "num_hash_layers"),
        ("vocab_size", "vocab_size"),
        ("hc_mult", "hc_mult"),
    )
    for attr, key in plain_fields:
        setattr(self, attr, cfg[key])

    # Width of the hyper-connection mixing projection: hc pre + hc post + hc*hc comb logits.
    self.mix_hc = (2 + self.hc_mult) * self.hc_mult

    ratio = cfg["compress_ratios"][layer_idx]
    self.compress_ratio = ratio
    self.has_compressor = ratio != 0
    self.has_indexer = ratio == 4
    self.is_hash = layer_idx < self.num_hash_layers

    for count in (self.n_heads, self.o_groups, self.index_n_heads):
        assert count % tp == 0

    self.prefix = f"layers.{layer_idx}"
def _init_compressor(self, prefix, head_dim, ratio):
    """Declare the KV-compressor weights stored under ``prefix``.

    The two projections (``wkv`` and ``wgate``) are unquantized and replicated on every
    rank (``tp_rank=0, tp_world_size=1``). Their output width, and the width of the ``ape``
    parameter, is doubled when ``ratio == 4``.
    """
    coff = 1
    if ratio == 4:
        coff = 2
    width = coff * head_dim

    def _replicated_linear(leaf):
        # One replicated, unquantized hidden -> width projection named ``<prefix>.<leaf>.weight``.
        return ROWMMWeight(
            in_dim=self.hidden,
            out_dims=[width],
            weight_names=f"{prefix}.{leaf}.weight",
            data_type=self.data_type_,
            quant_method=None,
            tp_rank=0,
            tp_world_size=1,
        )

    self.compressor_wkv_ = _replicated_linear("wkv")
    self.compressor_wgate_ = _replicated_linear("wgate")
    self.compressor_norm_ = RMSNormWeight(
        dim=head_dim, weight_name=f"{prefix}.norm.weight", data_type=self.data_type_
    )
    self.compressor_ape_ = ParameterWeight(
        weight_name=f"{prefix}.ape", data_type=torch.float32, weight_shape=(ratio, width)
    )
def _init_moe(self):
    """Declare the MoE weights: router, routing table or bias, shared expert, routed experts."""
    ffn = f"{self.prefix}.ffn"

    # Router gate: unquantized and replicated on every rank.
    self.gate_weight_ = ROWMMWeight(
        in_dim=self.hidden,
        out_dims=[self.n_routed_experts],
        weight_names=f"{ffn}.gate.weight",
        data_type=self.data_type_,
        quant_method=None,
        tp_rank=0,
        tp_world_size=1,
    )

    # Hash layers carry a token-id -> expert-id table; the others carry a selection bias.
    if self.is_hash:
        self.gate_tid2eid_ = ParameterWeight(
            weight_name=f"{ffn}.gate.tid2eid",
            data_type=torch.int64,
            weight_shape=(self.vocab_size, self.network_config_["num_experts_per_tok"]),
        )
    else:
        self.gate_bias_ = ParameterWeight(
            weight_name=f"{ffn}.gate.bias", data_type=torch.float32, weight_shape=(self.n_routed_experts,)
        )

    # Shared expert, built with the FP8 quant method: w1 = gate, w3 = up, w2 = down.
    shared = f"{ffn}.shared_experts"
    fp8_kwargs = dict(data_type=self.data_type_, quant_method=self.fp8_quant)
    self.shared_gate_ = ROWMMWeight(
        in_dim=self.hidden, out_dims=[self.moe_inter], weight_names=f"{shared}.w1.weight", **fp8_kwargs
    )
    self.shared_up_ = ROWMMWeight(
        in_dim=self.hidden, out_dims=[self.moe_inter], weight_names=f"{shared}.w3.weight", **fp8_kwargs
    )
    self.shared_down_ = COLMMWeight(
        in_dim=self.moe_inter, out_dims=[self.hidden], weight_names=f"{shared}.w2.weight", **fp8_kwargs
    )

    # Routed experts stay in their packed FP4 form.
    self.experts_ = DeepseekV4FP4ExpertsWeight(
        weight_prefix=f"{ffn}.experts",
        n_routed_experts=self.n_routed_experts,
        hidden_size=self.hidden,
        moe_intermediate_size=self.moe_inter,
        data_type=self.data_type_,
    )
def _init_hyper_connection(self): + for which in ["attn", "ffn"]: + setattr( + self, + f"hc_{which}_fn_", + ParameterWeight( + weight_name=f"{self.prefix}.hc_{which}_fn", + data_type=torch.float32, + weight_shape=(self.mix_hc, self.hc_mult * self.hidden), + ), + ) + setattr( + self, + f"hc_{which}_base_", + ParameterWeight( + weight_name=f"{self.prefix}.hc_{which}_base", data_type=torch.float32, weight_shape=(self.mix_hc,) + ), + ) + setattr( + self, + f"hc_{which}_scale_", + ParameterWeight( + weight_name=f"{self.prefix}.hc_{which}_scale", data_type=torch.float32, weight_shape=(3,) + ), + ) + + # ------------------------------------------------------------------ loading + def load_hf_weights(self, weights): + self._dequant_in_place(weights) + return super().load_hf_weights(weights) + + def _direct_fp8_weight_names(self): + names = set() + for attr_name in dir(self): + attr = getattr(self, attr_name, None) + quant_method = getattr(attr, "quant_method", None) + if getattr(quant_method, "method_name", None) == "deepgemm-fp8w8a8-b128": + names.update(getattr(attr, "weight_names", [])) + return names + + def _dequant_in_place(self, weights): + p = self.prefix + "." + direct_fp8_names = self._direct_fp8_weight_names() + # Convert every (weight, scale) pair belonging to this layer. Existing FP8 matmul + # weights stay quantized; bmm-only weights are expanded; routed FP4 experts stay packed. + for k in [k for k in list(weights.keys()) if k.startswith(p) and k.endswith(".weight")]: + scale_k = k[: -len(".weight")] + ".scale" + if scale_k not in weights: + continue + w, s = weights[k], weights[scale_k] + if w.dtype == torch.int8: # FP4 routed experts stay packed for DeepseekV4FP4ExpertsWeight. 
+ continue + elif k in direct_fp8_names: # FP8 e4m3, block-128 scale, run by DeepGEMM directly + weights[k.replace("weight", "weight_scale_inv")] = s.to(torch.float32) + del weights[scale_k] + else: # FP8 e4m3 for no-quant paths such as ROWBMMWeight + weights[k] = dequant_fp8_block_to_bf16(w, s).to(self.data_type_) + del weights[scale_k] + # grouped-O: reshape [groups*o_lora, in] -> [groups, in, o_lora] for the batched matmul + woa = f"{self.prefix}.attn.wo_a.weight" + if woa in weights and weights[woa].dim() == 2: + w = weights[woa] + per_group_in = self.n_heads * self.head_dim // self.o_groups + weights[woa] = w.view(self.o_groups, self.o_lora_rank, per_group_in).transpose(1, 2).contiguous() + return diff --git a/lightllm/models/deepseek_v4/mem_manager.py b/lightllm/models/deepseek_v4/mem_manager.py new file mode 100644 index 0000000000..288d433380 --- /dev/null +++ b/lightllm/models/deepseek_v4/mem_manager.py @@ -0,0 +1,12 @@ +from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager + + +class DeepseekV4MemoryManager(Deepseek2MemoryManager): + """Stores the per-token MLA KV (head_num=1, head_dim=512), reusing the deepseek2 layout/operator. + + The prefill path computes attention in-layer from the request's hidden states, so it does not read + this buffer. The decode/incremental path (M6) will add the sliding-window ring + compressed-KV + + per-request compressor-state buffers here. 
+ """ + + pass diff --git a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py new file mode 100644 index 0000000000..02c71f01b2 --- /dev/null +++ b/lightllm/models/deepseek_v4/model.py @@ -0,0 +1,121 @@ +import torch +from lightllm.models.registry import ModelRegistry +from lightllm.models.llama.model import LlamaTpPartModel +from lightllm.common.req_manager import ReqManager, DeepseekV4ReqManager +from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager +from lightllm.common.basemodel.attention.triton.fp import TritonAttBackend +from lightllm.models.deepseek_v4.layer_weights.pre_and_post_layer_weight import DeepseekV4PreAndPostLayerWeight +from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import DeepseekV4TransformerLayerWeight +from lightllm.models.deepseek_v4.layer_infer.pre_layer_infer import DeepseekV4PreLayerInfer +from lightllm.models.deepseek_v4.layer_infer.post_layer_infer import DeepseekV4PostLayerInfer +from lightllm.models.deepseek_v4.layer_infer.transformer_layer_infer import DeepseekV4TransformerLayerInfer +from lightllm.models.deepseek_v4.infer_struct import DeepseekV4InferStateInfo +from lightllm.models.llama.yarn_rotary_utils import find_correction_range, linear_ramp_mask +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num +from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager + +logger = init_logger(__name__) + + +@ModelRegistry("deepseek_v4") +class DeepseekV4TpPartModel(LlamaTpPartModel): + + pre_and_post_weight_class = DeepseekV4PreAndPostLayerWeight + transformer_weight_class = DeepseekV4TransformerLayerWeight + + pre_layer_infer_class = DeepseekV4PreLayerInfer + post_layer_infer_class = DeepseekV4PostLayerInfer + transformer_layer_infer_class = DeepseekV4TransformerLayerInfer + + infer_state_class = DeepseekV4InferStateInfo + + def _verify_params(self): + assert self.load_way == "HF", "only support HF 
format weights" + assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + assert self.config["o_groups"] % self.tp_world_size_ == 0 + assert self.config["index_n_heads"] % self.tp_world_size_ == 0 + return + + def _init_some_value(self): + super()._init_some_value() + self.head_dim_ = self.config["head_dim"] + return + + def _init_req_manager(self): + create_max_seq_len = 0 + if self.batch_max_tokens is not None: + create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens) + if self.max_seq_length is not None: + create_max_seq_len = max(create_max_seq_len, self.max_seq_length) + + self._dsv4_req_manager_seq_len = create_max_seq_len + self.req_manager = ReqManager(self.max_req_num, create_max_seq_len, None) + return + + def _get_compress_rates(self, layer_num): + rates = list(self.config.get("compress_ratios", [])) + if len(rates) < layer_num: + rates.extend([0] * (layer_num - len(rates))) + return rates[:layer_num] + + def _init_mem_manager(self): + layer_num = self.config["n_layer"] + get_added_mtp_kv_layer_num() + self.mem_manager = DeepseekV4MemoryManager( + self.max_total_token_num, + dtype=self.data_type, + head_num=1, + head_dim=self.config["head_dim"], + layer_num=layer_num, + compress_rates=self._get_compress_rates(layer_num), + indexer_head_dim=self.config["index_head_dim"], + mem_fraction=self.mem_fraction, + ) + self.req_manager = DeepseekV4ReqManager( + self.max_req_num, self._dsv4_req_manager_seq_len, self.mem_manager + ) + return + + def _init_att_backend(self): + self.prefill_att_backend = TritonAttBackend(model=self) + self.decode_att_backend = TritonAttBackend(model=self) + return + + def _init_custom(self): + self._init_to_get_rotary() + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) + return + + def _init_to_get_rotary(self): + # 
Interleaved (GPT-J) rope. Build real cos/sin tables (_cos_cached_*/_sin_cached_*) following the + # gemma4 two-variant convention; the infer-struct slices them into position_cos_*/position_sin_* + # and apply_rotary_emb (interleaved, NOT the NeoX rotary_emb_fwd) applies them. Sliding-window + # layers use base rope_theta (no YaRN); compressed (CSA/HCA) layers use compress_rope_theta with + # YaRN. Tables kept fp32 for accuracy (the apply upcasts anyway). + cfg = self.config + rs = cfg.get("rope_scaling", {}) or {} + dim = cfg["qk_rope_head_dim"] + beta_fast = rs.get("beta_fast", 32) + beta_slow = rs.get("beta_slow", 1) + max_seq = max(int(self.max_seq_length), int(cfg.get("max_position_embeddings", 8192))) + max_seq = min(max_seq, 1 << 18) # cap table size (256K) for correctness-first + + def build(base, factor, orig_max): + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim)) + if orig_max > 0: + low, high = find_correction_range(beta_fast, beta_slow, dim, base, orig_max) + smooth = 1 - linear_ramp_mask(low, high, dim // 2).cuda() + freqs = freqs / factor * (1 - smooth) + freqs * smooth + f = torch.outer(torch.arange(max_seq, dtype=torch.float32, device="cuda"), freqs) # [max_seq, dim//2] + return f.cos(), f.sin() + + self._cos_cached_sliding, self._sin_cached_sliding = build(cfg["rope_theta"], rs.get("factor", 1.0), 0) + self._cos_cached_compress, self._sin_cached_compress = build( + cfg["compress_rope_theta"], rs.get("factor", 16), rs.get("original_max_position_embeddings", 65536) + ) + return diff --git a/lightllm/models/deepseek_v4/triton_kernel/__init__.py b/lightllm/models/deepseek_v4/triton_kernel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py b/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py new file mode 100644 index 0000000000..c7d2d59ec6 --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py @@ 
-0,0 +1,93 @@ +import torch + +# DeepSeek-V4-Flash ships weights in two quantized formats: +# * non-expert linears: FP8 e4m3 with block-[128,128] scales stored as float8_e8m0fnu (ue8m0) +# * routed experts: FP4 e2m1 packed 2-per-byte (stored as int8) with group-32 ue8m0 scales +# Hopper (H200) has no native SM100 MegaMoE path. Non-expert FP8 weights can run directly through +# DeepGEMM. Routed FP4 experts are converted blockwise to FP8, avoiding a full bf16 expansion. + +# OCP E2M1 magnitude table for the 3 low bits (sign = bit 3). torch.float4_e2m1fn_x2 packs two +# such codes per byte, low nibble = lower (even) logical index. +_E2M1_MAG = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + + +def e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: + """float8_e8m0fnu encodes 2**(byte-127); torch decodes it correctly on .to(float32).""" + return scale.to(torch.float32) + + +def dequant_fp8_block_to_bf16(weight_e4m3: torch.Tensor, scale_e8m0: torch.Tensor, block_size: int = 128): + """De-quantize an FP8 e4m3 weight [out, in] with block-[bs,bs] ue8m0 scale to bf16.""" + from lightllm.models.deepseek2.triton_kernel.weight_dequant import weight_dequant + + w = weight_e4m3.cuda().contiguous() + s = e8m0_to_fp32(scale_e8m0).cuda().contiguous() + # weight_dequant runs with torch default dtype for the output; force bf16 result. + return weight_dequant(w, s, block_size) + + +def cast_e2m1fn_to_e4m3fn(weight_int8: torch.Tensor, scale_e8m0: torch.Tensor): + """Cast packed FP4 e2m1 expert weights to FP8 e4m3 with block-128 fp32 scales. + + This follows the DeepSeek-V4 reference converter, but returns the scale in fp32 because + LightLLM's DeepGEMM FP8 weight pack stores block scales as fp32. 
+ """ + assert weight_int8.dtype == torch.int8 + assert weight_int8.ndim == 2 + out_dim, packed_in = weight_int8.shape + in_dim = packed_in * 2 + fp8_block_size = 128 + fp4_block_size = 32 + assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0 + assert scale_e8m0.shape[0] == out_dim + assert scale_e8m0.shape[1] == in_dim // fp4_block_size + + table = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], + dtype=torch.float32, + device=weight_int8.device, + ) + packed = weight_int8.view(torch.uint8) + low = packed & 0x0F + high = (packed >> 4) & 0x0F + vals = torch.stack([table[low.long()], table[high.long()]], dim=-1).reshape(out_dim, in_dim) + + # 6.0 * 2**6 fits in e4m3fn (384 < 448), while 6.0 * 2**7 would overflow. + max_offset_bits = 6 + block_out = out_dim // fp8_block_size + block_in = in_dim // fp8_block_size + + vals = vals.view(block_out, fp8_block_size, block_in, fp8_block_size).transpose(1, 2) + scale = scale_e8m0.float().view(block_out, fp8_block_size, block_in, -1).transpose(1, 2).flatten(2) + block_scale = scale.amax(dim=-1, keepdim=True) / (2**max_offset_bits) + offset = scale / block_scale + offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1) + vals = (vals * offset).transpose(1, 2).reshape(out_dim, in_dim) + block_scale = block_scale.squeeze(-1).to(torch.float8_e8m0fnu).to(torch.float32) + return vals.to(torch.float8_e4m3fn), block_scale + + +def dequant_fp4_group_to_bf16(weight_int8: torch.Tensor, scale_e8m0: torch.Tensor, group_size: int = 32): + """De-quantize an int8-packed FP4 e2m1 weight to bf16. + + weight_int8: [out, in // 2] int8 (two e2m1 codes per byte, low nibble = even index). + scale_e8m0: [out, in // group_size] ue8m0 (one scale per group_size logical elements along K). + returns: [out, in] bf16. 
+ """ + w = weight_int8.cuda() + out, packed_in = w.shape + in_dim = packed_in * 2 + b = w.to(torch.int32).bitwise_and(0xFF) + lut = torch.tensor(_E2M1_MAG, dtype=torch.float32, device=w.device) + + def _decode(nib: torch.Tensor) -> torch.Tensor: + mag = lut[nib.bitwise_and(0x7)] + neg = nib.bitwise_and(0x8).bool() + return torch.where(neg, -mag, mag) + + lo = _decode(b.bitwise_and(0xF)) + hi = _decode(b.bitwise_right_shift(4).bitwise_and(0xF)) + vals = torch.stack([lo, hi], dim=-1).reshape(out, in_dim) # [out, in] + s = e8m0_to_fp32(scale_e8m0).cuda() # [out, in//group_size] + s = s.repeat_interleave(group_size, dim=1)[:, :in_dim] + return (vals * s).to(torch.bfloat16) diff --git a/lightllm/models/deepseek_v4/triton_kernel/rotary_emb.py b/lightllm/models/deepseek_v4/triton_kernel/rotary_emb.py new file mode 100644 index 0000000000..cb50977446 --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/rotary_emb.py @@ -0,0 +1,26 @@ +import torch + +# Interleaved (GPT-J) rotary application for DeepSeek-V4. Unlike llama/gemma's NeoX-style +# rotary_emb_fwd (rotate-half: pairs channel i with i+d/2 over a real cos/sin table), V4 rotates +# adjacent pairs (x0,x1),(x2,x3),... — a different channel pairing — so it cannot reuse +# rotary_emb_fwd, but it consumes the same real cos/sin tables (built in model.py:_init_to_get_rotary +# as _cos_cached_*/_sin_cached_*, gemma4-style). Correctness-first pure-torch; a fused triton port is +# a perf follow-up. + + +def apply_rotary_emb(x, cos, sin, inverse=False): + """Apply interleaved rope to the LAST dim of x (size = 2*cos.size(-1)). + + x: [..., rope_dim] (real). cos/sin: [..., rope_dim//2], broadcastable to x's paired view. + For x of shape [N, H, rope_dim], pass cos/sin [N, 1, rope_dim//2]; for [N, rope_dim] pass [N, rope_dim//2]. + Returns a new tensor of x's dtype (not in-place). inverse=True applies the conjugate rotation. 
+ """ + dtype = x.dtype + x = x.float().reshape(*x.shape[:-1], -1, 2) + x0, x1 = x[..., 0], x[..., 1] + cos = cos.float() + sin = sin.float() + if inverse: + sin = -sin + out = torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1) + return out.flatten(-2).to(dtype) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index f0ec69b2c1..5e90c9b34a 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -590,6 +590,8 @@ def _init_all_state(self): self.cur_output_len = 0 g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self) + if hasattr(g_infer_context.req_manager, "init_compress_state"): + g_infer_context.req_manager.init_compress_state(req_idx=self.req_idx) self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list() # token healing mode 才被使用的管理对象 From d790ad2407f75308fa30cf73f91dfbd928dd4306 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 5 Jun 2026 01:30:51 +0000 Subject: [PATCH 02/30] Optimization --- lightllm/__init__.py | 27 + .../deepseek4_mem_manager.py | 639 ++++++++++++++++-- .../kv_cache_mem_manager/operator/__init__.py | 1 + .../kv_cache_mem_manager/operator/deepseek.py | 19 +- lightllm/common/quantization/__init__.py | 3 + lightllm/common/req_manager.py | 279 +++++++- lightllm/models/deepseek3_2/model.py | 71 +- .../deepseek_v4/layer_infer/attention.py | 123 +++- .../deepseek_v4/layer_infer/compressor.py | 327 ++++++++- .../layer_infer/hyper_connection.py | 86 ++- .../layer_infer/transformer_layer_infer.py | 636 +++++++++++++++-- .../layer_weights/transformer_layer_weight.py | 155 ++++- lightllm/models/deepseek_v4/model.py | 144 +++- lightllm/server/api_start.py | 53 +- .../server/router/model_infer/infer_batch.py | 76 ++- lightllm/server/tokenizer.py | 5 + 16 files changed, 2295 insertions(+), 349 deletions(-) diff --git a/lightllm/__init__.py 
b/lightllm/__init__.py index e9ba6f3041..8e515afb70 100644 --- a/lightllm/__init__.py +++ b/lightllm/__init__.py @@ -1,4 +1,31 @@ from lightllm.utils.device_utils import is_musa + +def _patch_mp_resource_tracker_for_semaphore(): + from multiprocessing import resource_tracker + + if getattr(resource_tracker, "_lightllm_ignore_semaphore", False): + return + + orig_register = resource_tracker.register + orig_unregister = resource_tracker.unregister + + def register(name, rtype): + if rtype == "semaphore": + return + return orig_register(name, rtype) + + def unregister(name, rtype): + if rtype == "semaphore": + return + return orig_unregister(name, rtype) + + resource_tracker.register = register + resource_tracker.unregister = unregister + resource_tracker._lightllm_ignore_semaphore = True + + +_patch_mp_resource_tracker_for_semaphore() + if is_musa(): import torchada # noqa: F401 diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py index 900b551cec..739e0bd51a 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py @@ -1,44 +1,208 @@ import torch +import torch.distributed as dist from typing import Dict, List, Optional from .deepseek2_mem_manager import Deepseek2MemoryManager +from .operator import DeepseekV4MemOperator from .allocator import KvCacheAllocator -from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.log_utils import init_logger +from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory logger = init_logger(__name__) -class _SubKvPool: - """DeepSeek-V4 压缩分支(c4 / c128)使用的轻量子池。 - - 一个独立的 KvCacheAllocator + 一块压缩 latent buffer,可选附带一块与 latent 1:1 的 - indexer-K buffer(仅 c4/CSA 
层用)。刻意不继承 MemoryManager —— pd/shm/kv_move 等机制 - 对压缩池暂不需要,保持最小。布局与主 MLA latent 池一致(每槽多预留 1 行作 padding 哨兵)。 +DSV4_MLA_NOPE_DIM = 448 +DSV4_MLA_ROPE_DIM = 64 +DSV4_MLA_HEAD_DIM = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM +DSV4_MLA_QUANT_GROUP_SIZE = 64 +DSV4_MLA_SCALE_BYTES = DSV4_MLA_NOPE_DIM // DSV4_MLA_QUANT_GROUP_SIZE + 1 +DSV4_MLA_BYTES_PER_TOKEN = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM * 2 + DSV4_MLA_SCALE_BYTES +DSV4_INDEXER_HEAD_DIM = 128 +DSV4_INDEXER_BYTES_PER_TOKEN = DSV4_INDEXER_HEAD_DIM + 4 +DSV4_FP8_E4M3_MAX = 448.0 +DSV4_FP8_SCALE_MIN = 1e-4 +DSV4_MLA_DATA_BYTES_PER_TOKEN = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM * 2 +DSV4_MLA_SCALE_TAIL_BYTES = DSV4_MLA_SCALE_BYTES +DSV4_MLA_PAGE_ALIGN_BYTES = DSV4_MLA_DATA_BYTES_PER_TOKEN +DSV4_SWA_PAGE_SIZE = 128 +DSV4_C4_PAGE_SIZE = 64 +DSV4_C128_PAGE_SIZE = 2 +DSV4_PROFILE_MAX_FULL_TOKENS = 2_000_000 + + +def _ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + +class _PageSlabMlaPool: + """SGLang-compatible fp8_ds_mla page-slab storage with token-slot addressing. + + The public loc is still a LightLLM token slot. 
Internally each page stores all + 576B NoPE+RoPE payloads first and the 8B scale records at the page tail: + data_offset = page * bytes_per_page + token_in_page * 576 + scale_offset = page * bytes_per_page + page_size * 576 + token_in_page * 8 """ def __init__( self, size: int, - dtype: torch.dtype, - head_num: int, - head_dim: int, + page_size: int, layer_num: int, - indexer_head_dim: int = 0, - shared_name: Optional[str] = None, device: str = "cuda", ): self.size = size - self.dtype = dtype - self.head_num = head_num - self.head_dim = head_dim + self.page_size = page_size self.layer_num = layer_num - self.indexer_head_dim = indexer_head_dim + self.dtype = torch.uint8 + self.data_bytes_per_token = DSV4_MLA_DATA_BYTES_PER_TOKEN + self.scale_bytes_per_token = DSV4_MLA_SCALE_TAIL_BYTES + self.bytes_per_token = DSV4_MLA_BYTES_PER_TOKEN + self.num_pages = _ceil_div(size + 1, page_size) + self.bytes_per_page = ( + _ceil_div(page_size * self.bytes_per_token, DSV4_MLA_PAGE_ALIGN_BYTES) * DSV4_MLA_PAGE_ALIGN_BYTES + ) + self.scale_offset_in_page = page_size * self.data_bytes_per_token + self.kv_buffer = torch.zeros( + (layer_num, self.num_pages, self.bytes_per_page), + dtype=torch.uint8, + device=device, + ) + self.HOLD_TOKEN_MEMINDEX = size + + def _loc_offsets(self, loc: torch.Tensor): + loc = loc.long() + page = torch.div(loc, self.page_size, rounding_mode="floor") + token = loc % self.page_size + page_base = page * self.bytes_per_page + data_offsets = page_base + token * self.data_bytes_per_token + scale_offsets = page_base + self.scale_offset_in_page + token * self.scale_bytes_per_token + return data_offsets, scale_offsets + + def write(self, layer_index: int, loc: torch.Tensor, packed: torch.Tensor) -> None: + if loc.numel() == 0: + return + loc = loc.long() + packed = packed.reshape(-1, DSV4_MLA_BYTES_PER_TOKEN).contiguous() + flat = self.kv_buffer[layer_index].view(-1) + data_offsets, scale_offsets = self._loc_offsets(loc) + + data = packed[:, : 
self.data_bytes_per_token].contiguous() + scale = packed[:, self.data_bytes_per_token : self.bytes_per_token].contiguous() + data_range = torch.arange(self.data_bytes_per_token, device=loc.device) + scale_range = torch.arange(self.scale_bytes_per_token, device=loc.device) + flat[data_offsets.unsqueeze(1) + data_range.unsqueeze(0)] = data + flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)] = scale + return + + def read(self, layer_index: int, loc: torch.Tensor) -> torch.Tensor: + loc = loc.long() + if loc.numel() == 0: + return torch.empty((0, DSV4_MLA_BYTES_PER_TOKEN), dtype=torch.uint8, device=self.kv_buffer.device) + flat = self.kv_buffer[layer_index].view(-1) + data_offsets, scale_offsets = self._loc_offsets(loc) + data_range = torch.arange(self.data_bytes_per_token, device=loc.device) + scale_range = torch.arange(self.scale_bytes_per_token, device=loc.device) + data = flat[data_offsets.unsqueeze(1) + data_range.unsqueeze(0)] + scale = flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)] + return torch.cat([data, scale], dim=1).contiguous() + + def get_layer_buffer(self, layer_index: int) -> torch.Tensor: + return self.kv_buffer[layer_index] + - self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device=device) - if indexer_head_dim > 0: - self.index_k_buffer = torch.empty((layer_num, size + 1, indexer_head_dim), dtype=dtype, device=device) +class _PageSlabIndexerPool: + """C4 indexer-K storage: page tail stores per-token fp32 scales.""" + + def __init__( + self, + size: int, + page_size: int, + layer_num: int, + device: str = "cuda", + ): + self.size = size + self.page_size = page_size + self.layer_num = layer_num + self.head_dim = DSV4_INDEXER_HEAD_DIM + self.scale_bytes = 4 + self.bytes_per_token = DSV4_INDEXER_BYTES_PER_TOKEN + self.num_pages = _ceil_div(size + 1, page_size) + self.bytes_per_page = page_size * self.bytes_per_token + self.scale_offset_in_page = page_size * self.head_dim + self.index_k_buffer = 
torch.zeros( + (layer_num, self.num_pages, self.bytes_per_page), + dtype=torch.uint8, + device=device, + ) + self.HOLD_TOKEN_MEMINDEX = size + + def _loc_offsets(self, loc: torch.Tensor): + loc = loc.long() + page = torch.div(loc, self.page_size, rounding_mode="floor") + token = loc % self.page_size + page_base = page * self.bytes_per_page + k_offsets = page_base + token * self.head_dim + scale_offsets = page_base + self.scale_offset_in_page + token * self.scale_bytes + return k_offsets, scale_offsets + + def write(self, layer_index: int, loc: torch.Tensor, packed: torch.Tensor) -> None: + if loc.numel() == 0: + return + loc = loc.long() + packed = packed.reshape(-1, self.bytes_per_token).contiguous() + flat = self.index_k_buffer[layer_index].view(-1) + k_offsets, scale_offsets = self._loc_offsets(loc) + k_range = torch.arange(self.head_dim, device=loc.device) + scale_range = torch.arange(self.scale_bytes, device=loc.device) + flat[k_offsets.unsqueeze(1) + k_range.unsqueeze(0)] = packed[:, : self.head_dim] + flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)] = packed[:, self.head_dim :] + return + + def read(self, layer_index: int, loc: torch.Tensor) -> torch.Tensor: + loc = loc.long() + if loc.numel() == 0: + return torch.empty((0, self.bytes_per_token), dtype=torch.uint8, device=self.index_k_buffer.device) + flat = self.index_k_buffer[layer_index].view(-1) + k_offsets, scale_offsets = self._loc_offsets(loc) + k_range = torch.arange(self.head_dim, device=loc.device) + scale_range = torch.arange(self.scale_bytes, device=loc.device) + k = flat[k_offsets.unsqueeze(1) + k_range.unsqueeze(0)] + scale = flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)] + return torch.cat([k, scale], dim=1).contiguous() + + def get_layer_buffer(self, layer_index: int) -> torch.Tensor: + return self.index_k_buffer[layer_index] + + +class _SubKvPool: + """Compressed c4/c128 KV pool with token-slot allocator and page-slab backing.""" + + def __init__( + self, + size: int, 
+ page_size: int, + layer_num: int, + with_indexer: bool = False, + shared_name: Optional[str] = None, + device: str = "cuda", + ): + self.size = size + self.dtype = torch.uint8 + self.layer_num = layer_num + self.page_size = page_size + self.mla_pool = _PageSlabMlaPool(size=size, page_size=page_size, layer_num=layer_num, device=device) + self.kv_buffer = self.mla_pool.kv_buffer + if with_indexer: + self.indexer_pool = _PageSlabIndexerPool( + size=size, + page_size=page_size, + layer_num=layer_num, + device=device, + ) + self.index_k_buffer = self.indexer_pool.index_k_buffer else: + self.indexer_pool = None self.index_k_buffer = None self.allocator = KvCacheAllocator(size, shared_name=shared_name) @@ -54,27 +218,51 @@ def free_all(self) -> None: self.allocator.free_all() def get_kv_buffer(self, layer_index: int) -> torch.Tensor: - return self.kv_buffer[layer_index] + return self.mla_pool.get_layer_buffer(layer_index) def get_index_k_buffer(self, layer_index: int) -> torch.Tensor: - assert self.index_k_buffer is not None, "this sub pool has no indexer-K buffer" - return self.index_k_buffer[layer_index] + assert self.indexer_pool is not None, "this sub pool has no indexer-K buffer" + return self.indexer_pool.get_layer_buffer(layer_index) + + def write_kv(self, layer_index: int, slots: torch.Tensor, packed: torch.Tensor) -> None: + self.mla_pool.write(layer_index, slots, packed) + + def read_kv(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor: + return self.mla_pool.read(layer_index, slots) + + def write_indexer_k(self, layer_index: int, slots: torch.Tensor, packed: torch.Tensor) -> None: + assert self.indexer_pool is not None + self.indexer_pool.write(layer_index, slots, packed) + + def read_indexer_k(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor: + assert self.indexer_pool is not None + return self.indexer_pool.read(layer_index, slots) class DeepseekV4MemoryManager(Deepseek2MemoryManager): - """DeepSeek-V4 KV 管理(锁定决策: SWA 全历史 + 不分页)。 - - - 
dense/SWA latent: 继承 Deepseek2 的单张量 MLA latent ``kv_buffer``(每 token 一槽,所有层 - 共享层轴,head_num==1)。SWA 分支靠 layer_infer 传 ``AttControl(use_sliding_window)`` + attn_sink - 读最近窗口;dense 槽为纯 latent,不挂 indexer-K(与 V3.2 区别)。 - - c4_pool / c128_pool: 两个独立 ``_SubKvPool``(window 粒度,1-token 分配)。c4 池附带 indexer-K。 - - 容量: 用闭式 ``get_cell_size()``(= 每个 dense token 在所有池上的总字节)让基类 ``profile_size`` - 直接得到 full_token = dense 池大小,再按 1/4、1/128 派生压缩池大小。 - - compressor 递归状态不在这里,放 DeepseekV4ReqManager(后续步骤)。 + """DeepSeek-V4 token-slot KV 管理(584B packed cache + bf16 workspace)。 + + - dense/SWA latent: 主 ``kv_buffer`` 仍是 LightLLM 的 token-slot cache,不分页;物理格式改为 + SGLang/vLLM 的 ``fp8_ds_mla``: 448B NoPE fp8 + 64*2B RoPE bf16 + 7B scale + 1B pad = 584B。 + - c4_pool / c128_pool: 两个独立 ``_SubKvPool``(window 粒度,1-token 分配),compressed KV 同样 + 存 584B packed。c4 池附带 132B/token 的 packed indexer-K。 + - 读取时先用 torch reference dequant/gather 回 bf16 workspace,供现有 vLLM sparse FlashMLA wrapper + 消费;下一步可把这些 pack/dequant helper 替换成 fused/triton 版本。 + - 容量: 用闭式 ``get_cell_size()``(= 每个 dense token 在所有池上的 packed 总字节)让基类 + ``profile_size`` 直接得到 full_token = dense 池大小,再按 1/4、1/128 派生压缩池大小。 + - compressor 递归状态放 DeepseekV4ReqManager。 """ - # dense 写入沿用 Deepseek2MemOperator(拆 nope/rope);压缩写入算子随 layer_infer 一并补。 - # operator_class 继承自 Deepseek2MemoryManager(= Deepseek2MemOperator)。 + operator_class = DeepseekV4MemOperator + + mla_nope_dim = DSV4_MLA_NOPE_DIM + mla_rope_dim = DSV4_MLA_ROPE_DIM + mla_head_dim = DSV4_MLA_HEAD_DIM + mla_quant_group_size = DSV4_MLA_QUANT_GROUP_SIZE + mla_scale_bytes = DSV4_MLA_SCALE_BYTES + mla_bytes_per_token = DSV4_MLA_BYTES_PER_TOKEN + indexer_head_dim_default = DSV4_INDEXER_HEAD_DIM + indexer_bytes_per_token = DSV4_INDEXER_BYTES_PER_TOKEN def __init__( self, @@ -85,19 +273,27 @@ def __init__( layer_num, compress_rates: List[int], indexer_head_dim: int = 128, + max_request_num: Optional[int] = None, + sliding_window: Optional[int] = None, always_copy=False, mem_fraction=0.9, ): assert 
head_num == 1, "DeepSeek-V4 是 MLA(MQA),dense latent 的 head_num 必须为 1" + assert head_dim == self.mla_head_dim, f"DeepSeek-V4 packed KV 期望 head_dim={self.mla_head_dim}" assert ( - len(compress_rates) == layer_num - ), f"compress_rates 长度 {len(compress_rates)} 必须等于 layer_num {layer_num}" + indexer_head_dim == self.indexer_head_dim_default + ), f"DeepSeek-V4 packed indexer-K 期望 indexer_head_dim={self.indexer_head_dim_default}" + assert len(compress_rates) == layer_num, f"compress_rates 长度 {len(compress_rates)} 必须等于 layer_num {layer_num}" assert all(r in (0, 4, 128) for r in compress_rates), "compress_rates 取值只能是 0/4/128" self.compress_rates = list(compress_rates) self.n_c4 = sum(1 for r in self.compress_rates if r == 4) self.n_c128 = sum(1 for r in self.compress_rates if r == 128) self.indexer_head_dim = indexer_head_dim + self.prefill_dtype = dtype + self.cache_dtype = torch.uint8 + self.max_request_num = max_request_num + self.sliding_window = sliding_window # 全局层号 -> 各压缩池内的压实层号(同 qwen3next 的层号压实手法) self.layer_to_c4_idx: Dict[int, int] = {} @@ -113,23 +309,111 @@ def __init__( super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + def _planned_swa_size(self, full_size: int) -> int: + if self.max_request_num is None or self.sliding_window is None: + return full_size + window_cap = max(1, int(self.max_request_num) * int(self.sliding_window)) + return max(1, min(full_size, window_cap)) + + def _dense_cell_size(self): + return self.head_num * self.mla_bytes_per_token * self.layer_num + + def _compressed_cell_size(self): + latent_bytes = self.head_num * self.mla_bytes_per_token + c4 = latent_bytes * self.n_c4 / 4 + c128 = latent_bytes * self.n_c128 / 128 + indexer = self.indexer_bytes_per_token * self.n_c4 / 4 + return c4 + c128 + indexer + + def profile_size(self, mem_fraction): + if self.size is not None: + return + + torch.cuda.empty_cache() + world_size = dist.get_world_size() + available_memory = get_available_gpu_memory(world_size) 
- get_total_gpu_memory() * (1 - mem_fraction) + available_bytes = available_memory * 1024 ** 3 + dense_cell = self._dense_cell_size() + compressed_cell = self._compressed_cell_size() + + if self.max_request_num is not None and self.sliding_window is not None and compressed_cell > 0: + swa_cap = max(1, int(self.max_request_num) * int(self.sliding_window)) + full_cell = dense_cell + compressed_cell + bytes_until_swa_cap = full_cell * swa_cap + if available_bytes <= bytes_until_swa_cap: + self.size = max(1, int(available_bytes / full_cell)) + else: + self.size = max(1, int((available_bytes - dense_cell * swa_cap) / compressed_cell)) + else: + self.size = max(1, int(available_bytes / (dense_cell + compressed_cell))) + + if world_size > 1: + tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}") + dist.all_reduce(tensor, op=dist.ReduceOp.MIN) + self.size = tensor.item() + + if self.size > DSV4_PROFILE_MAX_FULL_TOKENS: + logger.info( + f"DeepseekV4MemoryManager cap profiled max_total_token_num from " + f"{self.size} to {DSV4_PROFILE_MAX_FULL_TOKENS} to keep runtime headroom" + ) + self.size = DSV4_PROFILE_MAX_FULL_TOKENS + + logger.info( + f"{str(available_memory)} GB space is available after load the model weight\n" + f"{str((dense_cell + compressed_cell) / 1024 ** 2)} MB is the conservative size of one token kv cache\n" + f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n" + ) + return + def get_cell_size(self): - # 返回“每个 dense(full) token 在所有池上的总字节”。基类 profile_size 用 - # size = available_bytes / get_cell_size(),于是直接得到 full_token = dense 池大小。 - elem = torch._utils._element_size(self.dtype) - latent_bytes = self.head_num * self.head_dim * elem # 每 token 每层 dense latent - dense = latent_bytes * self.layer_num # SWA 全历史: 所有层 - c4 = latent_bytes * self.n_c4 / 4 # c4 压缩 latent - c128 = latent_bytes * self.n_c128 / 128 # c128 压缩 latent - indexer = self.indexer_head_dim * elem * self.n_c4 / 4 # c4 
indexer-K - return dense + c4 + c128 + indexer + dense = self._dense_cell_size() + compressed = self._compressed_cell_size() + if self.size is None: + return dense + compressed + swa_ratio = self._planned_swa_size(self.size) / max(1, self.size) + return dense * swa_ratio + compressed def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - # dense/SWA latent(继承 Deepseek2: [layer_num, size+1, head_num, head_dim]) - super()._init_buffers(size, dtype, head_num, head_dim, layer_num) - self._init_compressed_pools(size, dtype, head_num, head_dim) + self.swa_size = self._planned_swa_size(size) + self.swa_pool = _PageSlabMlaPool( + size=self.swa_size, + page_size=DSV4_SWA_PAGE_SIZE, + layer_num=layer_num, + device="cuda", + ) + self.kv_buffer = self.swa_pool.kv_buffer + self._init_swa_mapping(size) + self._init_compressed_pools(size, head_num) + + def _init_swa_mapping(self, size): + rank_in_node = get_current_rank_in_node() + server = get_unique_server_name() + self.swa_allocator = KvCacheAllocator( + self.swa_size, + shared_name=f"{server}_dsv4_swa_can_use_token_num_{rank_in_node}", + ) + self.full_to_swa_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda") + self.full_to_swa_indexs[size] = self.swa_pool.HOLD_TOKEN_MEMINDEX + if self.max_request_num is None or self.sliding_window is None: + self.req_to_swa_indexs = None + self.req_to_swa_full_indexs = None + return + + self.req_to_swa_indexs = torch.full( + (self.max_request_num + 1, self.sliding_window), + self.swa_pool.HOLD_TOKEN_MEMINDEX, + dtype=torch.int32, + device="cuda", + ) + self.req_to_swa_full_indexs = torch.full( + (self.max_request_num + 1, self.sliding_window), + -1, + dtype=torch.int32, + device="cuda", + ) - def _init_compressed_pools(self, size, dtype, head_num, head_dim): + def _init_compressed_pools(self, size, head_num): rank_in_node = get_current_rank_in_node() server = get_unique_server_name() @@ -141,31 +425,254 @@ def _init_compressed_pools(self, size, dtype, 
head_num, head_dim): if self.n_c4 > 0: self.c4_pool = _SubKvPool( size=self.c4_size, - dtype=dtype, - head_num=head_num, - head_dim=head_dim, + page_size=DSV4_C4_PAGE_SIZE, layer_num=self.n_c4, - indexer_head_dim=self.indexer_head_dim, + with_indexer=True, shared_name=f"{server}_dsv4_c4_can_use_token_num_{rank_in_node}", ) if self.n_c128 > 0: self.c128_pool = _SubKvPool( size=self.c128_size, - dtype=dtype, - head_num=head_num, - head_dim=head_dim, + page_size=DSV4_C128_PAGE_SIZE, layer_num=self.n_c128, - indexer_head_dim=0, + with_indexer=False, shared_name=f"{server}_dsv4_c128_can_use_token_num_{rank_in_node}", ) logger.info( - f"DeepseekV4MemoryManager pools: dense={size} " + f"DeepseekV4MemoryManager pools: full_tokens={size} swa={self.swa_size} " f"c4={self.c4_size}(L={self.n_c4}) c128={self.c128_size}(L={self.n_c128}) " - f"indexer_head_dim={self.indexer_head_dim}" + f"packed_kv_bytes={self.mla_bytes_per_token} indexer_bytes={self.indexer_bytes_per_token}" ) - # dense latent 读取沿用父类 get_att_input_params。 + def get_att_input_params(self, layer_index: int): + return self.swa_pool.get_layer_buffer(layer_index) + + def _pack_mla_kv(self, kv: torch.Tensor) -> torch.Tensor: + kv = kv.reshape(-1, self.mla_head_dim) + out = torch.empty((kv.shape[0], self.mla_bytes_per_token), dtype=torch.uint8, device=kv.device) + nope = kv[:, : self.mla_nope_dim].float().reshape(-1, self.mla_scale_bytes - 1, self.mla_quant_group_size) + scale = torch.clamp(nope.abs().amax(dim=-1) / DSV4_FP8_E4M3_MAX, min=DSV4_FP8_SCALE_MIN) + scale_exp = torch.ceil(torch.log2(scale)).to(torch.int32) + scale = torch.exp2(scale_exp.float()) + nope_fp8 = torch.clamp(nope / scale.unsqueeze(-1), -DSV4_FP8_E4M3_MAX, DSV4_FP8_E4M3_MAX).to( + torch.float8_e4m3fn + ) + out[:, : self.mla_nope_dim].copy_(nope_fp8.reshape(-1, self.mla_nope_dim).view(dtype=torch.uint8)) + rope_start = self.mla_nope_dim + rope_end = rope_start + self.mla_rope_dim * 2 + rope = kv[:, self.mla_nope_dim : 
self.mla_head_dim].contiguous().to(torch.bfloat16) + out[:, rope_start:rope_end].copy_(rope.view(dtype=torch.uint8).reshape(-1, self.mla_rope_dim * 2)) + scale_start = rope_end + scale_end = scale_start + self.mla_scale_bytes - 1 + out[:, scale_start:scale_end].copy_((scale_exp + 127).to(torch.uint8)) + out[:, scale_end].zero_() + return out + + def _unpack_mla_kv(self, packed: torch.Tensor) -> torch.Tensor: + packed = packed.reshape(-1, self.mla_bytes_per_token) + if packed.shape[0] == 0: + return torch.empty((0, self.mla_head_dim), dtype=self.dtype, device=packed.device) + nope_fp8 = packed[:, : self.mla_nope_dim].view(dtype=torch.float8_e4m3fn).float() + nope_fp8 = nope_fp8.reshape(-1, self.mla_scale_bytes - 1, self.mla_quant_group_size) + rope_start = self.mla_nope_dim + rope_end = rope_start + self.mla_rope_dim * 2 + scale_start = rope_end + scale_end = scale_start + self.mla_scale_bytes - 1 + scale_exp = packed[:, scale_start:scale_end].to(torch.int32) - 127 + scale = torch.exp2(scale_exp.float()) + nope = (nope_fp8 * scale.reshape(-1, self.mla_scale_bytes - 1, 1)).reshape(-1, self.mla_nope_dim) + rope = packed[:, rope_start:rope_end].view(dtype=torch.bfloat16) + return torch.cat([nope.to(self.dtype), rope.to(self.dtype)], dim=-1) + + def _pack_indexer_k(self, indexer_k: torch.Tensor) -> torch.Tensor: + indexer_k = indexer_k.reshape(-1, self.indexer_head_dim) + out = torch.empty( + (indexer_k.shape[0], self.indexer_bytes_per_token), + dtype=torch.uint8, + device=indexer_k.device, + ) + k_float = indexer_k.float() + scale = torch.clamp( + k_float.abs().amax(dim=-1, keepdim=True) / DSV4_FP8_E4M3_MAX, + min=DSV4_FP8_SCALE_MIN, + ) + k_fp8 = torch.clamp(k_float / scale, -DSV4_FP8_E4M3_MAX, DSV4_FP8_E4M3_MAX).to(torch.float8_e4m3fn) + out[:, : self.indexer_head_dim].copy_(k_fp8.view(dtype=torch.uint8)) + out[:, self.indexer_head_dim : self.indexer_bytes_per_token].copy_(scale.view(dtype=torch.uint8).reshape(-1, 4)) + return out + + def _unpack_indexer_k(self, 
packed: torch.Tensor) -> torch.Tensor: + packed = packed.reshape(-1, self.indexer_bytes_per_token) + if packed.shape[0] == 0: + return torch.empty((0, self.indexer_head_dim), dtype=self.dtype, device=packed.device) + k_fp8 = packed[:, : self.indexer_head_dim].view(dtype=torch.float8_e4m3fn).float() + scale = packed[:, self.indexer_head_dim : self.indexer_bytes_per_token].view(dtype=torch.float32) + return (k_fp8 * scale).to(self.dtype) + + def _identity_swa_slots(self, full_slots: torch.Tensor) -> torch.Tensor: + full_slots = full_slots.long() + valid = full_slots != self.HOLD_TOKEN_MEMINDEX + if valid.any() and int(full_slots[valid].max().item()) >= self.swa_size: + raise RuntimeError( + "DeepSeek-V4 SWA cache needs req_idx/positions for full token slots outside the SWA pool" + ) + swa_slots = torch.where( + valid, + full_slots, + torch.full_like(full_slots, self.swa_pool.HOLD_TOKEN_MEMINDEX), + ) + if valid.any(): + self.full_to_swa_indexs[full_slots[valid]] = swa_slots[valid].to(torch.int32) + return swa_slots + + def ensure_swa_slots(self, req_idx: int, positions: torch.Tensor, full_slots: torch.Tensor) -> torch.Tensor: + full_slots = full_slots.long().reshape(-1) + if full_slots.numel() == 0: + return full_slots + if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None: + return self._identity_swa_slots(full_slots) + + positions = positions.long().reshape(-1) + assert positions.numel() == full_slots.numel() + req_idx = int(req_idx) + out = torch.empty_like(full_slots, dtype=torch.long) + for i, (pos, full) in enumerate(zip(positions.tolist(), full_slots.tolist())): + if full == self.HOLD_TOKEN_MEMINDEX: + out[i] = self.swa_pool.HOLD_TOKEN_MEMINDEX + continue + + ring_pos = pos % self.sliding_window + old_swa = int(self.req_to_swa_indexs[req_idx, ring_pos].item()) + old_full = int(self.req_to_swa_full_indexs[req_idx, ring_pos].item()) + if old_full == full and old_swa != self.swa_pool.HOLD_TOKEN_MEMINDEX: + swa = old_swa + elif old_swa != 
self.swa_pool.HOLD_TOKEN_MEMINDEX: + if old_full >= 0: + self.full_to_swa_indexs[old_full] = -1 + swa = old_swa + else: + swa = int(self.swa_allocator.alloc(1)[0].item()) + + self.req_to_swa_indexs[req_idx, ring_pos] = swa + self.req_to_swa_full_indexs[req_idx, ring_pos] = full + self.full_to_swa_indexs[full] = swa + out[i] = swa + return out + + def _swa_slots_from_full(self, full_slots: torch.Tensor) -> torch.Tensor: + full_slots = full_slots.long().reshape(-1) + if full_slots.numel() == 0: + return full_slots + mapped = self.full_to_swa_indexs[full_slots].long() + missing = mapped < 0 + if missing.any(): + if self.req_to_swa_indexs is not None: + bad = int(full_slots[missing][0].item()) + raise RuntimeError(f"DeepSeek-V4 dense KV for full token slot {bad} has been evicted from SWA cache") + fallback = full_slots[missing] + fallback_valid = fallback < self.swa_size + if fallback_valid.all(): + mapped[missing] = fallback + self.full_to_swa_indexs[fallback] = fallback.to(torch.int32) + else: + bad = int(fallback[~fallback_valid][0].item()) + raise RuntimeError(f"DeepSeek-V4 dense KV for full token slot {bad} has been evicted from SWA cache") + return mapped + + def free_swa_for_req(self, req_idx: int) -> None: + if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None: + return + req_idx = int(req_idx) + slots = self.req_to_swa_indexs[req_idx] + full_slots = self.req_to_swa_full_indexs[req_idx] + valid_swa = slots != self.swa_pool.HOLD_TOKEN_MEMINDEX + if valid_swa.any(): + free_slots = torch.unique(slots[valid_swa]).detach().cpu() + self.swa_allocator.free(free_slots) + valid_full = full_slots >= 0 + if valid_full.any(): + self.full_to_swa_indexs[full_slots[valid_full].long()] = -1 + self.req_to_swa_indexs[req_idx].fill_(self.swa_pool.HOLD_TOKEN_MEMINDEX) + self.req_to_swa_full_indexs[req_idx].fill_(-1) + self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX + + def _keep_last_swa_writes(self, swa_slots: 
torch.Tensor, packed: torch.Tensor): + """Drop duplicate SWA writes generated by long prefill ring reuse.""" + if swa_slots.numel() <= 1: + return swa_slots, packed + + slots_cpu = swa_slots.detach().cpu().tolist() + seen = set() + keep = [] + hold = self.swa_pool.HOLD_TOKEN_MEMINDEX + for i in range(len(slots_cpu) - 1, -1, -1): + slot = int(slots_cpu[i]) + if slot == hold or slot in seen: + continue + seen.add(slot) + keep.append(i) + keep.reverse() + if len(keep) == len(slots_cpu): + return swa_slots, packed + if not keep: + return swa_slots[:0], packed[:0] + keep_index = torch.tensor(keep, dtype=torch.long, device=swa_slots.device) + return swa_slots.index_select(0, keep_index), packed.index_select(0, keep_index) + + def pack_mla_kv_to_cache( + self, + layer_index: int, + mem_index: torch.Tensor, + kv: torch.Tensor, + req_idx: Optional[int] = None, + positions: Optional[torch.Tensor] = None, + ): + if kv.shape[0] == 0: + return + packed = self._pack_mla_kv(kv) + if req_idx is None or positions is None: + swa_slots = self._identity_swa_slots(mem_index).to(kv.device) + else: + swa_slots = self.ensure_swa_slots(req_idx, positions, mem_index).to(kv.device) + swa_slots, packed = self._keep_last_swa_writes(swa_slots, packed) + if swa_slots.numel() == 0: + return + self.swa_pool.write(layer_index, swa_slots, packed) + + def pack_compressed_kv_to_cache(self, layer_index: int, slots: torch.Tensor, comp: torch.Tensor): + if comp.shape[0] == 0: + return + pool, local_layer = self._pool_and_local_layer(layer_index) + pool.write_kv(local_layer, slots.to(comp.device), self._pack_mla_kv(comp)) + + def pack_c4_indexer_k_to_cache(self, layer_index: int, slots: torch.Tensor, indexer_k: torch.Tensor): + if indexer_k.shape[0] == 0: + return + pool, local_layer = self._pool_and_local_layer(layer_index) + pool.write_indexer_k(local_layer, slots.to(indexer_k.device), self._pack_indexer_k(indexer_k)) + + def gather_mla_kv(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor: + 
if slots.numel() == 0: + return torch.empty((0, self.mla_head_dim), dtype=self.dtype, device=self.kv_buffer.device) + swa_slots = self._swa_slots_from_full(slots).to(self.kv_buffer.device) + return self._unpack_mla_kv(self.swa_pool.read(layer_index, swa_slots)) + + def gather_compressed_kv(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor: + if slots.numel() == 0: + return torch.empty((0, self.mla_head_dim), dtype=self.dtype, device=self.kv_buffer.device) + pool, local_layer = self._pool_and_local_layer(layer_index) + return self._unpack_mla_kv(pool.read_kv(local_layer, slots.to(self.kv_buffer.device))) + + def gather_c4_indexer_k(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor: + if slots.numel() == 0: + return torch.empty( + (0, self.indexer_head_dim), + dtype=self.dtype, + device=self.kv_buffer.device, + ) + pool, local_layer = self._pool_and_local_layer(layer_index) + return self._unpack_indexer_k(pool.read_indexer_k(local_layer, slots.to(self.kv_buffer.device))) def _pool_and_local_layer(self, layer_index: int): r = self.compress_rates[layer_index] @@ -197,7 +704,21 @@ def free_c128(self, free_index) -> None: def free_all(self): super().free_all() + if hasattr(self, "swa_allocator"): + self.swa_allocator.free_all() + if hasattr(self, "full_to_swa_indexs"): + self.full_to_swa_indexs.fill_(-1) + self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX + if getattr(self, "req_to_swa_indexs", None) is not None: + self.req_to_swa_indexs.fill_(self.swa_pool.HOLD_TOKEN_MEMINDEX) + self.req_to_swa_full_indexs.fill_(-1) if self.c4_pool is not None: self.c4_pool.free_all() if self.c128_pool is not None: self.c128_pool.free_all() + + def alloc_kv_move_buffer(self, max_req_total_len): + raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented") + + def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: + raise NotImplementedError("DeepSeek-V4 packed/composite paged KV 
transfer is not implemented") diff --git a/lightllm/common/kv_cache_mem_manager/operator/__init__.py b/lightllm/common/kv_cache_mem_manager/operator/__init__.py index 85c37ad39b..442c2e300e 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/operator/__init__.py @@ -5,6 +5,7 @@ from .deepseek import ( Deepseek2MemOperator, Deepseek3_2MemOperator, + DeepseekV4MemOperator, FP8PerTokenGroupQuantDeepseek3_2MemOperator, ) from .fp8_quant import ( diff --git a/lightllm/common/kv_cache_mem_manager/operator/deepseek.py b/lightllm/common/kv_cache_mem_manager/operator/deepseek.py index 6e05b96e10..0725ce9b93 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/deepseek.py +++ b/lightllm/common/kv_cache_mem_manager/operator/deepseek.py @@ -8,7 +8,9 @@ class Deepseek2MemOperator(NormalMemOperator): def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): - from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager + from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import ( + Deepseek2MemoryManager, + ) mem_manager: Deepseek2MemoryManager = self.mem_manager @@ -30,7 +32,9 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: class Deepseek3_2MemOperator(Deepseek2MemOperator): def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): - from lightllm.common.kv_cache_mem_manager.deepseek3_2mem_manager import Deepseek3_2MemoryManager + from lightllm.common.kv_cache_mem_manager.deepseek3_2mem_manager import ( + Deepseek3_2MemoryManager, + ) mem_manager: Deepseek3_2MemoryManager = self.mem_manager from ...basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv @@ -78,3 +82,14 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: o_rope, ) return + + +class DeepseekV4MemOperator(BaseMemManagerOperator): + def 
copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import ( + DeepseekV4MemoryManager, + ) + + mem_manager: DeepseekV4MemoryManager = self.mem_manager + mem_manager.pack_mla_kv_to_cache(layer_index, mem_index, kv) + return diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index cd534d53ec..1c5a9c09d3 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -68,6 +68,9 @@ def _mapping_quant_method(self): expert_dtype = self.expert_dtype or self.network_config_.get("expert_dtype", None) if expert_dtype is None: return + if expert_dtype == "fp4" and self.network_config_.get("model_type") == "deepseek_v4" and not is_sm100_gpu(): + logger.info("skip generic fused_moe quant mapping for DeepSeek-V4 fp4 experts on non-SM100 GPUs") + return target = self._get_expert_quant_type(expert_dtype) for layer_num in range(self.layer_num): if self.expert_dtype is not None: diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index c8197401c1..1cdea03381 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -6,12 +6,16 @@ from .kv_cache_mem_manager import MemoryManager, DeepseekV4MemoryManager from typing import List, Optional, TYPE_CHECKING from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter -from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter +from lightllm.common.basemodel.triton_kernel.gen_sampling_params import ( + update_req_to_token_id_counter, +) from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args from lightllm.utils.config_utils import get_vocab_size from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.common.linear_att_cache_manager.layer_cache import LayerCache -from 
lightllm.common.linear_att_cache_manager.linear_att_buffer_manager import LinearAttCacheManager +from lightllm.common.linear_att_cache_manager.linear_att_buffer_manager import ( + LinearAttCacheManager, +) if TYPE_CHECKING: from lightllm.server.router.model_infer.infer_batch import InferReq @@ -131,11 +135,13 @@ def __init__(self, max_request_num): ) elif self.penalty_counter_mode == "pin_mem_counter": self.req_to_out_token_id_counter = torch.zeros( - (max_request_num + 1, self.vocab_size), dtype=torch.int32, device="cpu", pin_memory=True + (max_request_num + 1, self.vocab_size), + dtype=torch.int32, + device="cpu", + pin_memory=True, ) def init_req_sampling_params(self, req: "InferReq"): - shm_param = req.sampling_param.shm_param self.req_to_next_token_ids[req.req_idx][0:1].fill_(req.get_last_gen_token()) self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty) @@ -165,14 +171,18 @@ def init_req_sampling_params(self, req: "InferReq"): dtype=torch.int32, ).cuda(non_blocking=True) token_id_counter( - prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx] + prompt_ids=prompt_ids, + out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx], ) torch.cuda.current_stream().synchronize() return def update_reqs_out_token_counter_gpu( - self, b_req_idx: torch.Tensor, next_token_ids: torch.Tensor, mask: torch.Tensor = None + self, + b_req_idx: torch.Tensor, + next_token_ids: torch.Tensor, + mask: torch.Tensor = None, ): if self.penalty_counter_mode not in ["gpu_counter", "pin_mem_counter"]: return @@ -188,7 +198,10 @@ def update_reqs_out_token_counter_gpu( return def update_reqs_token_counter( - self, req_objs: List["InferReq"], next_token_ids: List[int], accept_mark: Optional[List[List[bool]]] = None + self, + req_objs: List["InferReq"], + next_token_ids: List[int], + accept_mark: Optional[List[List[bool]]] = None, ): if self.penalty_counter_mode != "cpu_counter": return @@ -230,7 +243,13 @@ def 
gen_cpu_out_token_counter_sampling_params(self, req_objs: List["InferReq"]): class ReqManagerForMamba(ReqManager): - def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_config: LinearAttCacheConfig): + def __init__( + self, + max_request_num, + max_sequence_length, + mem_manager, + linear_config: LinearAttCacheConfig, + ): super().__init__(max_request_num, max_sequence_length, mem_manager) self.mtp_step = get_env_start_args().mtp_step self.big_page_token_num = ( @@ -275,7 +294,6 @@ def get_mamba_cache(self, layer_idx_in_all: int): return conv_states, ssm_states def copy_big_page_buffer_to_linear_att_state(self, big_page_buffer_idx: int, req: "InferReq"): - from .linear_att_cache_manager import LinearAttCacheManager big_page_buffers: LinearAttCacheManager = self.mem_manager.linear_att_big_page_buffers @@ -304,8 +322,9 @@ def copy_small_page_buffer_to_linear_att_state( class DeepseekV4ReqManager(ReqManager): """DeepSeek-V4 的请求级管理(锁定决策: SWA 全历史 + 不分页)。 - 在基类 ReqManager 之上补三类 V4 专有的 per-request 结构(均从 mem_manager 读取 n_c4/n_c128/ - layer_to_*_idx/head_dim 等,避免重复配置): + 在基类 ReqManager 之上补三类 V4 专有的 per-request 结构。该对象在 mem manager profile 前创建, + 所以初始化只依赖 config 派生出的 compress_rates/head_dim/indexer_head_dim;真实 mem_manager + 会在 `_init_mem_manager()` 后通过 `bind_mem_manager()` 接入。 * ``req_to_c4_indexs`` / ``req_to_c128_indexs`` —— (req, 窗口下标) -> 压缩池槽位。 窗口下标 = position // compress_rate;窗口关闭时由 layer-infer 写入,attention 读取前 @@ -318,19 +337,48 @@ class DeepseekV4ReqManager(ReqManager): * entry_count 不另存:= position // compress_rate,可由序列长度推出。 """ - def __init__(self, max_request_num, max_sequence_length, mem_manager: DeepseekV4MemoryManager): + def __init__( + self, + max_request_num, + max_sequence_length, + mem_manager: Optional[DeepseekV4MemoryManager] = None, + compress_rates: Optional[List[int]] = None, + head_dim: Optional[int] = None, + indexer_head_dim: Optional[int] = None, + ): super().__init__(max_request_num, max_sequence_length, mem_manager) - 
assert isinstance(mem_manager, DeepseekV4MemoryManager) - self.n_c4 = mem_manager.n_c4 - self.n_c128 = mem_manager.n_c128 - head_dim = mem_manager.head_dim - indexer_head_dim = mem_manager.indexer_head_dim + if mem_manager is not None: + assert isinstance(mem_manager, DeepseekV4MemoryManager) + compress_rates = mem_manager.compress_rates + head_dim = mem_manager.head_dim + indexer_head_dim = mem_manager.indexer_head_dim + assert compress_rates is not None, "DeepSeek-V4 req manager requires compress_rates" + assert head_dim is not None, "DeepSeek-V4 req manager requires head_dim" + assert indexer_head_dim is not None, "DeepSeek-V4 req manager requires indexer_head_dim" + + self.compress_rates = list(compress_rates) + self.n_c4 = sum(1 for r in self.compress_rates if r == 4) + self.n_c128 = sum(1 for r in self.compress_rates if r == 128) + self.head_dim = head_dim + self.indexer_head_dim = indexer_head_dim + self.layer_to_c4_idx = {} + self.layer_to_c128_idx = {} + c4 = c128 = 0 + for lid, r in enumerate(self.compress_rates): + if r == 4: + self.layer_to_c4_idx[lid] = c4 + c4 += 1 + elif r == 128: + self.layer_to_c128_idx[lid] = c128 + c128 += 1 # (req, 窗口) -> 压缩槽。列数取 ceil(max_seq / ratio) 留足余量。 c4_windows = (max_sequence_length + 4 - 1) // 4 c128_windows = (max_sequence_length + 128 - 1) // 128 self.req_to_c4_indexs = torch.zeros((max_request_num + 1, c4_windows), dtype=torch.int32, device="cuda") self.req_to_c128_indexs = torch.zeros((max_request_num + 1, c128_windows), dtype=torch.int32, device="cuda") + self._c4_entry_counts = [0 for _ in range(max_request_num + 1)] + self._c128_entry_counts = [0 for _ in range(max_request_num + 1)] # compressor 在途窗口累加状态(fp32): [kv_or_score, coff * ratio, coff * dim]. 
state_dtype = torch.float32 @@ -355,9 +403,39 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: DeepseekV4 layer_num=self.n_c4, device="cuda", ) + self.req_to_c4_state_pool = LayerCache( + size=max_request_num + 1, + dtype=state_dtype, + shape=(1, 8, 4 * head_dim), + layer_num=self.n_c4, + device="cuda", + ) + self.req_to_c128_state_pool = LayerCache( + size=max_request_num + 1, + dtype=state_dtype, + shape=(1, 128, 2 * head_dim), + layer_num=self.n_c128, + device="cuda", + ) + self.req_to_c4_indexer_state_pool = LayerCache( + size=max_request_num + 1, + dtype=state_dtype, + shape=(1, 8, 4 * indexer_head_dim), + layer_num=self.n_c4, + device="cuda", + ) + self._runtime_states = [{} for _ in range(max_request_num + 1)] self._init_all_score_state() return + def bind_mem_manager(self, mem_manager: DeepseekV4MemoryManager): + assert isinstance(mem_manager, DeepseekV4MemoryManager) + assert self.compress_rates == mem_manager.compress_rates + assert self.head_dim == mem_manager.head_dim + assert self.indexer_head_dim == mem_manager.indexer_head_dim + self.mem_manager = mem_manager + return + def _init_all_score_state(self): if self.n_c4 > 0: self.req_to_c4_state.buffer[:, :, 1, ...].fill_(float("-inf")) @@ -373,33 +451,184 @@ def _reset_compress_cache_req(self, cache: LayerCache, req_idx: int): cache.buffer[:, req_idx, 1, ...].fill_(float("-inf")) return + def _reset_state_pool_req(self, cache: LayerCache, req_idx: int): + if cache.layer_num == 0: + return + cache.buffer[:, req_idx, ...].fill_(0) + return + def init_compress_state(self, req_idx: int): """新请求开始时重置其 compressor 在途状态(对应 mamba 的 init_linear_att_state)。""" + self.clear_runtime_state(req_idx) + c4, c128 = self.pop_compress_indices_for_req(req_idx) + self.free_compress_indices(free_c4_index=c4, free_c128_index=c128) if self.n_c4 > 0: self._reset_compress_cache_req(self.req_to_c4_state, req_idx) self._reset_compress_cache_req(self.req_to_c4_indexer_state, req_idx) + 
self._reset_state_pool_req(self.req_to_c4_state_pool, req_idx) + self._reset_state_pool_req(self.req_to_c4_indexer_state_pool, req_idx) if self.n_c128 > 0: self._reset_compress_cache_req(self.req_to_c128_state, req_idx) + self._reset_state_pool_req(self.req_to_c128_state_pool, req_idx) return + def _ensure_compress_slots(self, req_idx: int, ratio: int, entry_start: int, entry_count: int) -> torch.Tensor: + if entry_count == 0: + return torch.empty((0,), dtype=torch.int32, device="cuda") + assert entry_start >= 0 and entry_count >= 0 + assert self.mem_manager is not None, "DeepSeek-V4 mem manager is not bound yet" + if ratio == 4: + table = self.req_to_c4_indexs + counts = self._c4_entry_counts + alloc = self.mem_manager.alloc_c4 + elif ratio == 128: + table = self.req_to_c128_indexs + counts = self._c128_entry_counts + alloc = self.mem_manager.alloc_c128 + else: + raise AssertionError(f"invalid DeepSeek-V4 compress ratio {ratio}") + + required_count = entry_start + entry_count + assert required_count <= table.shape[1], ( + f"DeepSeek-V4 compressed slot table overflow: req={req_idx} " + f"ratio={ratio} required={required_count} capacity={table.shape[1]}" + ) + old_count = counts[req_idx] + if required_count > old_count: + new_slots_cpu = alloc(required_count - old_count) + table[req_idx, old_count:required_count] = new_slots_cpu.cuda(non_blocking=True) + counts[req_idx] = required_count + return table[req_idx, entry_start:required_count] + + def ensure_c4_slots(self, req_idx: int, entry_start: int, entry_count: int) -> torch.Tensor: + return self._ensure_compress_slots(req_idx, 4, entry_start, entry_count) + + def ensure_c128_slots(self, req_idx: int, entry_start: int, entry_count: int) -> torch.Tensor: + return self._ensure_compress_slots(req_idx, 128, entry_start, entry_count) + + def ensure_compress_slots(self, layer_index: int, req_idx: int, entry_start: int, entry_count: int) -> torch.Tensor: + ratio = self.compress_rates[layer_index] + if ratio == 4: + return 
self.ensure_c4_slots(req_idx, entry_start, entry_count) + if ratio == 128: + return self.ensure_c128_slots(req_idx, entry_start, entry_count) + raise AssertionError(f"layer {layer_index} is not a compressed attention layer") + + def pop_compress_indices_for_req(self, req_idx: int): + c4_count = self._c4_entry_counts[req_idx] + if c4_count > 0: + c4 = self.req_to_c4_indexs[req_idx, :c4_count].clone() + self.req_to_c4_indexs[req_idx, :c4_count].fill_(0) + self._c4_entry_counts[req_idx] = 0 + else: + c4 = None + + c128_count = self._c128_entry_counts[req_idx] + if c128_count > 0: + c128 = self.req_to_c128_indexs[req_idx, :c128_count].clone() + self.req_to_c128_indexs[req_idx, :c128_count].fill_(0) + self._c128_entry_counts[req_idx] = 0 + else: + c128 = None + return c4, c128 + + def free_compress_indices(self, free_c4_index=None, free_c128_index=None): + if free_c4_index is not None and len(free_c4_index) > 0: + self.mem_manager.free_c4(free_c4_index) + if free_c128_index is not None and len(free_c128_index) > 0: + self.mem_manager.free_c128(free_c128_index) + return + + def alloc(self): + req_idx = super().alloc() + if req_idx is not None: + self.init_compress_state(req_idx) + return req_idx + + def clear_runtime_state(self, req_idx: int): + self._runtime_states[req_idx].clear() + if self.mem_manager is not None and hasattr(self.mem_manager, "free_swa_for_req"): + self.mem_manager.free_swa_for_req(req_idx) + return + + def set_runtime_state(self, req_idx: int, layer_index: int, state: dict): + self._runtime_states[req_idx][layer_index] = state + return + + def get_runtime_state(self, req_idx: int, layer_index: int): + return self._runtime_states[req_idx][layer_index] + + def get_compress_state_for_req(self, layer_index: int, req_idx: int): + if self.compress_rates[layer_index] == 4: + state = self.get_c4_compress_state(layer_index) + elif self.compress_rates[layer_index] == 128: + state = self.get_c128_compress_state(layer_index) + else: + raise 
AssertionError(f"layer {layer_index} is not a compressed attention layer") + return state[req_idx, 0], state[req_idx, 1] + + def get_compress_state_pool_for_req(self, layer_index: int, req_idx: int): + if self.compress_rates[layer_index] == 4: + cache = self.req_to_c4_state_pool + local = self.layer_to_c4_idx[layer_index] + elif self.compress_rates[layer_index] == 128: + cache = self.req_to_c128_state_pool + local = self.layer_to_c128_idx[layer_index] + else: + raise AssertionError(f"layer {layer_index} is not a compressed attention layer") + return cache.buffer[local, req_idx] + def get_c4_compress_state(self, layer_index: int) -> torch.Tensor: - local = self.mem_manager.layer_to_c4_idx[layer_index] + local = self.layer_to_c4_idx[layer_index] return self.req_to_c4_state.buffer[local] def get_c128_compress_state(self, layer_index: int) -> torch.Tensor: - local = self.mem_manager.layer_to_c128_idx[layer_index] + local = self.layer_to_c128_idx[layer_index] return self.req_to_c128_state.buffer[local] def get_c4_indexer_compress_state(self, layer_index: int) -> torch.Tensor: - local = self.mem_manager.layer_to_c4_idx[layer_index] + local = self.layer_to_c4_idx[layer_index] return self.req_to_c4_indexer_state.buffer[local] - def free(self, free_req_indexes, free_token_index, free_c4_index=None, free_c128_index=None): + def get_c4_indexer_state_pool_for_req(self, layer_index: int, req_idx: int) -> torch.Tensor: + local = self.layer_to_c4_idx[layer_index] + return self.req_to_c4_indexer_state_pool.buffer[local, req_idx] + + def free( + self, + free_req_indexes, + free_token_index, + free_c4_index=None, + free_c128_index=None, + ): """释放 dense 槽(基类)+ 压缩槽。压缩槽由调用方(infer batch)从 req_to_c*_indexs 收集后传入, 与基类用 free_token_index 传 dense 槽的方式一致。""" + for req_index in free_req_indexes: + self.clear_runtime_state(req_index) super().free(free_req_indexes, free_token_index) - if free_c4_index is not None and len(free_c4_index) > 0: - self.mem_manager.free_c4(free_c4_index) - if 
free_c128_index is not None and len(free_c128_index) > 0: - self.mem_manager.free_c128(free_c128_index) + self.free_compress_indices(free_c4_index=free_c4_index, free_c128_index=free_c128_index) + return + + def free_req(self, free_req_index: int): + self.clear_runtime_state(free_req_index) + c4, c128 = self.pop_compress_indices_for_req(free_req_index) + self.free_compress_indices(free_c4_index=c4, free_c128_index=c128) + return super().free_req(free_req_index) + + def free_all(self): + super().free_all() + self._runtime_states = [{} for _ in range(self.max_request_num + 1)] + self._c4_entry_counts = [0 for _ in range(self.max_request_num + 1)] + self._c128_entry_counts = [0 for _ in range(self.max_request_num + 1)] + if self.n_c4 > 0: + self.req_to_c4_indexs.fill_(0) + self.req_to_c4_state.buffer.fill_(0) + self.req_to_c4_indexer_state.buffer.fill_(0) + self.req_to_c4_state_pool.buffer.fill_(0) + self.req_to_c4_indexer_state_pool.buffer.fill_(0) + if self.n_c128 > 0: + self.req_to_c128_indexs.fill_(0) + self.req_to_c128_state.buffer.fill_(0) + self.req_to_c128_state_pool.buffer.fill_(0) + self._init_all_score_state() return diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index 5831044311..cd33386666 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -1,14 +1,20 @@ import copy from lightllm.models.registry import ModelRegistry from lightllm.models.deepseek2.model import Deepseek2TpPartModel -from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight -from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer -from lightllm.common.basemodel.attention import get_nsa_prefill_att_backend_class, get_nsa_decode_att_backend_class +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import ( + Deepseek3_2TransformerLayerWeight, +) +from 
lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import ( + Deepseek3_2TransformerLayerInfer, +) +from lightllm.common.basemodel.attention import ( + get_nsa_prefill_att_backend_class, + get_nsa_decode_att_backend_class, +) @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): - # weight class transformer_weight_class = Deepseek3_2TransformerLayerWeight @@ -21,24 +27,11 @@ def _init_att_backend(self): return -class DeepSeekV32Tokenizer: - """Tokenizer wrapper for DeepSeek-V3.2 that uses the Python-based - encoding_dsv32 module instead of Jinja chat templates. - - DeepSeek-V3.2's tokenizer_config.json does not ship with a Jinja chat - template, so ``apply_chat_template`` would fail without either a manually - supplied ``--chat_template`` file or this wrapper. - """ - +class DeepSeekChatTokenizerBase: def __init__(self, tokenizer): self.tokenizer = tokenizer - # Cache added vocabulary for performance (HuggingFace can be slow). self._added_vocab = None - # ------------------------------------------------------------------ - # Attribute delegation – everything not overridden goes to the inner - # tokenizer so that encode/decode/vocab_size/eos_token_id/… all work. - # ------------------------------------------------------------------ def __getattr__(self, name): return getattr(self.tokenizer, name) @@ -47,9 +40,9 @@ def get_added_vocab(self): self._added_vocab = self.tokenizer.get_added_vocab() return self._added_vocab - # ------------------------------------------------------------------ - # Core override: route apply_chat_template through encode_messages. 
- # ------------------------------------------------------------------ + def _encode_messages(self, msgs, thinking_mode, kwargs): + raise NotImplementedError("subclass must provide DeepSeek encode_messages") + def apply_chat_template( self, conversation=None, @@ -58,27 +51,16 @@ def apply_chat_template( tokenize=False, add_generation_prompt=True, thinking=None, + enable_thinking=None, **kwargs, ): - from lightllm.models.deepseek3_2.encoding_dsv32 import encode_messages, render_tools - msgs = conversation if conversation is not None else messages if msgs is None: raise ValueError("Either 'conversation' or 'messages' must be provided") - # Deep copy to avoid mutating the caller's messages. msgs = copy.deepcopy(msgs) - # Determine thinking mode. - thinking_mode = "thinking" if thinking else "chat" - - # Inject tools into the first system message (or create one) so that - # encode_messages / render_message picks them up. if tools: - # build_prompt passes tools as bare function dicts: - # [{"name": "f", "description": "...", "parameters": {...}}] - # encoding_dsv32's render_message expects OpenAI wrapper format: - # [{"type": "function", "function": {...}}] wrapped_tools = [] for t in tools: if "function" in t: @@ -95,16 +77,27 @@ def apply_chat_template( break if not injected: - # Prepend a system message that carries the tools. 
msgs.insert(0, {"role": "system", "content": "", "tools": wrapped_tools}) - prompt = encode_messages( + if thinking is None: + thinking = bool(enable_thinking) if enable_thinking is not None else False + thinking_mode = "thinking" if thinking else "chat" + prompt = self._encode_messages(msgs, thinking_mode, kwargs) + + if tokenize: + return self.tokenizer.encode(prompt, add_special_tokens=False) + return prompt + + +class DeepSeekV32Tokenizer(DeepSeekChatTokenizerBase): + """Tokenizer wrapper for DeepSeek-V3.2's Python-based encoding_dsv32 module.""" + + def _encode_messages(self, msgs, thinking_mode, kwargs): + from lightllm.models.deepseek3_2.encoding_dsv32 import encode_messages + + return encode_messages( msgs, thinking_mode=thinking_mode, drop_thinking=kwargs.get("drop_thinking", True), add_default_bos_token=kwargs.get("add_default_bos_token", True), ) - - if tokenize: - return self.tokenizer.encode(prompt, add_special_tokens=False) - return prompt diff --git a/lightllm/models/deepseek_v4/layer_infer/attention.py b/lightllm/models/deepseek_v4/layer_infer/attention.py index a25a2aa3d1..a24949696f 100644 --- a/lightllm/models/deepseek_v4/layer_infer/attention.py +++ b/lightllm/models/deepseek_v4/layer_infer/attention.py @@ -1,34 +1,101 @@ +import os + import torch -import torch.nn.functional as F -# DeepSeek-V4 attention: MLA with a single shared KV head (head_dim=512), per-head learnable attention -# sink, and a candidate set = sliding-window tokens (size `window`) ++ compressed KV entries. Pure-torch -# transcription of the bundled reference (inference/model.py Attention.forward + kernel.py sparse_attn). -# Correctness-first prefill path. head_dim=512 > 256 so FlashAttention is unusable anyway; a fused -# triton sparse-gather kernel is a perf follow-up. 
+FLASHMLA_MIN_HEADS = 64 +FLASHMLA_TOPK_MULTIPLE = 128 +DSV4_DEBUG_TORCH_SPARSE_ATTN = os.getenv("DSV4_DEBUG_TORCH_SPARSE_ATTN", "0") == "1" + + +def _pad_topk_for_flashmla(topk_idxs): + K = topk_idxs.shape[-1] + padded_K = ((K + FLASHMLA_TOPK_MULTIPLE - 1) // FLASHMLA_TOPK_MULTIPLE) * FLASHMLA_TOPK_MULTIPLE + if padded_K == K: + return topk_idxs.contiguous() + padded = torch.full((*topk_idxs.shape[:-1], padded_K), -1, device=topk_idxs.device, dtype=topk_idxs.dtype) + padded[..., :K] = topk_idxs + return padded.contiguous() + + +def _compact_topk_indices(topk_idxs, kv_len): + valid = (topk_idxs >= 0) & (topk_idxs < kv_len) + topk_lens = valid.sum(dim=-1).to(torch.int32) + if valid.all(): + return topk_idxs.contiguous(), topk_lens.contiguous() + + compact = torch.full_like(topk_idxs, -1) + ranks = valid.to(torch.int32).cumsum(dim=-1) - 1 + rows = torch.arange(topk_idxs.shape[0], device=topk_idxs.device).unsqueeze(1).expand_as(topk_idxs) + compact[rows[valid], ranks[valid].long()] = topk_idxs[valid] + return compact.contiguous(), topk_lens.contiguous() + + +def _pad_heads_for_flashmla(q, attn_sink): + h = q.shape[1] + if h == FLASHMLA_MIN_HEADS: + return q.contiguous(), attn_sink.to(torch.float32).contiguous(), h + if h > FLASHMLA_MIN_HEADS: + raise RuntimeError(f"DeepSeek-V4 FlashMLA sparse attention only supports up to 64 local heads, got {h}") -def torch_sparse_attn(q, kv, attn_sink, topk_idxs, scale): - """Gather-then-softmax attention with a per-head sink, matching reference kernel.sparse_attn. + q_pad = q.new_zeros(q.shape[0], FLASHMLA_MIN_HEADS, q.shape[2]) + q_pad[:, :h] = q + sink_pad = torch.full((FLASHMLA_MIN_HEADS,), -float("inf"), device=q.device, dtype=torch.float32) + sink_pad[:h] = attn_sink.to(torch.float32) + return q_pad.contiguous(), sink_pad.contiguous(), h - q:[b,m,h,d], kv:[b,n,d] (single KV head shared over h), attn_sink:[h] (fp32), - topk_idxs:[b,m,K] int (-1 = invalid/skip). Returns o:[b,m,h,d]. 
+ +def _torch_sparse_attn(q, kv, attn_sink, topk_idxs, scale): + q0 = q[0].float() + kv0 = kv[0].float() + indices = topk_idxs[0].long() + valid = (indices >= 0) & (indices < kv0.shape[0]) + safe_indices = torch.where(valid, indices, torch.zeros_like(indices)) + kv_sel = kv0[safe_indices] + scores = torch.einsum("mhd,mkd->mhk", q0, kv_sel) * scale + scores = scores.masked_fill(~valid.unsqueeze(1), float("-inf")) + sink = attn_sink.float().view(1, -1) + max_scores = torch.maximum(scores.max(dim=-1).values, sink) + exp_scores = torch.exp(scores - max_scores.unsqueeze(-1)).masked_fill(~valid.unsqueeze(1), 0.0) + exp_sink = torch.exp(sink - max_scores) + denom = exp_scores.sum(dim=-1) + exp_sink + out = torch.einsum("mhk,mkd->mhd", exp_scores / denom.unsqueeze(-1), kv_sel) + return out.unsqueeze(0).to(q.dtype) + + +def vllm_sparse_attn(q, kv, attn_sink, topk_idxs, scale): + """DeepSeek-V4 sparse MLA through vLLM FlashMLA. + + q:[1,m,h,d], kv:[1,n,d] (single KV head shared over h), attn_sink:[h], + topk_idxs:[1,m,K] int (-1 = invalid/skip). Returns o:[1,m,h,d]. 
""" b, m, h, d = q.shape - n = kv.shape[1] - K = topk_idxs.shape[-1] - idx = topk_idxs.clamp(min=0).long() # [b,m,K] - keys = torch.gather(kv.unsqueeze(1).expand(b, m, n, d), 2, idx.unsqueeze(-1).expand(b, m, K, d)) # [b,m,K,d] - qf, kf = q.float(), keys.float() - scores = torch.einsum("bmhd,bmkd->bmhk", qf, kf) * scale # [b,m,h,K] - valid = (topk_idxs != -1).unsqueeze(2) # [b,m,1,K] - scores = scores.masked_fill(~valid, float("-inf")) - mx = scores.amax(dim=-1, keepdim=True) # [b,m,h,1] - mx = torch.nan_to_num(mx, neginf=0.0) - ex = (scores - mx).exp() # [b,m,h,K] - denom = ex.sum(-1) + (attn_sink.view(1, 1, h) - mx.squeeze(-1)).exp() # [b,m,h] - o = torch.einsum("bmhk,bmkd->bmhd", ex, kf) / denom.unsqueeze(-1) - return o.to(q.dtype) + if b != 1 or kv.shape[0] != 1 or topk_idxs.shape[0] != 1: + raise RuntimeError("DeepSeek-V4 FlashMLA sparse attention wrapper expects one request per call") + if d != 512: + raise RuntimeError(f"DeepSeek-V4 FlashMLA sparse attention requires head_dim=512, got {d}") + if q.dtype != torch.bfloat16 or kv.dtype != torch.bfloat16: + raise RuntimeError(f"DeepSeek-V4 FlashMLA sparse attention requires bf16 q/kv, got {q.dtype}/{kv.dtype}") + + if DSV4_DEBUG_TORCH_SPARSE_ATTN: + return _torch_sparse_attn(q, kv, attn_sink, topk_idxs, scale) + + from vllm.third_party.flashmla.flash_mla_interface import flash_mla_sparse_fwd + + q_pad, sink_pad, real_heads = _pad_heads_for_flashmla(q[0], attn_sink) + indices, topk_lens = _compact_topk_indices(topk_idxs[0].to(torch.int32), kv.shape[1]) + indices = _pad_topk_for_flashmla(indices).unsqueeze(1) + kv_flat = kv[0].unsqueeze(1).contiguous() + out, _, _ = flash_mla_sparse_fwd( + q=q_pad, + kv=kv_flat, + indices=indices, + sm_scale=scale, + attn_sink=sink_pad, + topk_length=topk_lens, + out=None, + ) + return out[:, :real_heads].unsqueeze(0).to(q.dtype) def build_prefill_topk_idxs(seqlen, window, ratio, n_window, device): @@ -40,11 +107,9 @@ def build_prefill_topk_idxs(seqlen, window, ratio, n_window, 
device): entries are attended (matches the reference for short context). """ t = torch.arange(seqlen, device=device) - # sliding window: query t attends tokens [max(0, t-window+1) .. t] - j = torch.arange(n_window, device=device) - win = j.unsqueeze(0).expand(seqlen, n_window).clone() # [s, n_window] - win_valid = (j.unsqueeze(0) <= t.unsqueeze(1)) & (j.unsqueeze(0) > (t.unsqueeze(1) - window)) - win = torch.where(win_valid, win, torch.full_like(win, -1)) + offsets = torch.arange(window, device=device) + win = t.unsqueeze(1) - (window - 1 - offsets).unsqueeze(0) + win = torch.where(win >= 0, win, torch.full_like(win, -1)) if ratio: ncomp = seqlen // ratio c = torch.arange(ncomp, device=device) diff --git a/lightllm/models/deepseek_v4/layer_infer/compressor.py b/lightllm/models/deepseek_v4/layer_infer/compressor.py index 902de113db..c91799f9ee 100644 --- a/lightllm/models/deepseek_v4/layer_infer/compressor.py +++ b/lightllm/models/deepseek_v4/layer_infer/compressor.py @@ -1,7 +1,20 @@ +import importlib.util +import logging +import sys +import types +from pathlib import Path + import torch import torch.nn.functional as F from ..triton_kernel.rotary_emb import apply_rotary_emb +logger = logging.getLogger(__name__) + +_SGLANG_COMPRESS_MOD = None +_SGLANG_COMPRESS_ERR = None +_SGLANG_COMPRESS_WARNED = False +_FREQ_CIS_CACHE = {} + # KV compressor: pools every `ratio` consecutive tokens into one compressed KV entry via gated # (softmax) pooling + a learned absolute-position bias (ape), RMSNorm, and rope on the trailing # rope_dim. ratio==4 uses overlapping windows (two-series Ca/Cb scheme). 
Pure-torch transcription of @@ -25,6 +38,231 @@ def _rmsnorm(x, weight, eps): return (xf * weight.float()).to(x.dtype) +def _load_file_module(name, path): + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + spec.loader.exec_module(mod) + return mod + + +def _load_sglang_compressor(): + global _SGLANG_COMPRESS_MOD, _SGLANG_COMPRESS_ERR + if _SGLANG_COMPRESS_MOD is not None: + return _SGLANG_COMPRESS_MOD + if _SGLANG_COMPRESS_ERR is not None: + raise _SGLANG_COMPRESS_ERR + try: + from sglang.jit_kernel.dsv4 import compress_old as mod + + _SGLANG_COMPRESS_MOD = mod + return mod + except Exception as first_exc: + root = Path("/data/wanzihao/sglang/python/sglang") + try: + if not root.exists(): + raise first_exc + if "sglang" not in sys.modules: + sglang_mod = types.ModuleType("sglang") + sglang_mod.__path__ = [str(root)] + sys.modules["sglang"] = sglang_mod + if "sglang.utils" not in sys.modules: + utils_mod = types.ModuleType("sglang.utils") + utils_mod.is_in_ci = lambda: False + sys.modules["sglang.utils"] = utils_mod + if "sglang.jit_kernel" not in sys.modules: + jit_mod = types.ModuleType("sglang.jit_kernel") + jit_mod.__path__ = [str(root / "jit_kernel")] + sys.modules["sglang.jit_kernel"] = jit_mod + if "sglang.jit_kernel.dsv4" not in sys.modules: + dsv4_mod = types.ModuleType("sglang.jit_kernel.dsv4") + dsv4_mod.__path__ = [str(root / "jit_kernel" / "dsv4")] + sys.modules["sglang.jit_kernel.dsv4"] = dsv4_mod + if "sglang.srt" not in sys.modules: + srt_mod = types.ModuleType("sglang.srt") + srt_mod.__path__ = [str(root / "srt")] + sys.modules["sglang.srt"] = srt_mod + if "sglang.srt.environ" not in sys.modules: + env_mod = types.ModuleType("sglang.srt.environ") + + class _FalseEnv: + def get(self): + return False + + class _Envs: + SGLANG_OPT_USE_ONLINE_COMPRESS = _FalseEnv() + + env_mod.envs = _Envs() + sys.modules["sglang.srt.environ"] = env_mod + if "sglang.jit_kernel.utils" not 
in sys.modules: + _load_file_module("sglang.jit_kernel.utils", root / "jit_kernel" / "utils.py") + if "sglang.jit_kernel.dsv4.utils" not in sys.modules: + _load_file_module( + "sglang.jit_kernel.dsv4.utils", + root / "jit_kernel" / "dsv4" / "utils.py", + ) + _SGLANG_COMPRESS_MOD = _load_file_module( + "sglang.jit_kernel.dsv4.compress_old", + root / "jit_kernel" / "dsv4" / "compress_old.py", + ) + return _SGLANG_COMPRESS_MOD + except Exception as exc: + _SGLANG_COMPRESS_ERR = exc + raise exc + + +def _warn_sglang_fallback(exc): + global _SGLANG_COMPRESS_WARNED + if not _SGLANG_COMPRESS_WARNED: + logger.warning("DeepSeek-V4 SGLang compressor JIT unavailable, fallback to torch: %s", exc) + _SGLANG_COMPRESS_WARNED = True + + +def _freq_cis(cos_table, sin_table): + key = ( + cos_table.data_ptr(), + sin_table.data_ptr(), + cos_table.device, + tuple(cos_table.shape), + tuple(sin_table.shape), + ) + cached = _FREQ_CIS_CACHE.get(key) + if cached is None: + cached = torch.complex(cos_table.float(), sin_table.float()) + _FREQ_CIS_CACHE[key] = cached + return cached + + +def _sglang_ape(ape, ratio, head_dim): + if ratio == 4: + return torch.cat([ape[:, :head_dim], ape[:, head_dim:]], dim=0).contiguous() + return ape.contiguous() + + +def _pack_kv_score(kv, score, ratio, head_dim): + if ratio == 4: + return torch.cat( + [ + kv[:, :head_dim], + kv[:, head_dim:], + score[:, :head_dim], + score[:, head_dim:], + ], + dim=1, + ).contiguous() + return torch.cat([kv, score], dim=1).contiguous() + + +def _build_state_from_kv_score(kv, score, ape, ratio, head_dim): + overlap = ratio == 4 + kv_state, score_state = new_compressor_state(ratio, head_dim, kv.device) + s = kv.shape[0] + remainder = s % ratio + cutoff = s - remainder + offset = ratio if overlap else 0 + if overlap and cutoff >= ratio: + kv_state[:ratio] = kv[cutoff - ratio : cutoff] + score_state[:ratio] = score[cutoff - ratio : cutoff] + ape.float() + if remainder > 0: + kv_state[offset : offset + remainder] = kv[cutoff:] + 
score_state[offset : offset + remainder] = score[cutoff:] + ape.float()[:remainder] + return kv_state, score_state + + +def _sglang_prefill_from_kv_score( + kv, + score, + norm_w, + ape, + ratio, + head_dim, + cos_table, + sin_table, + eps, + dtype, + state_pool=None, +): + if not kv.is_cuda or head_dim % 128 != 0 or ratio not in (4, 128): + return None, None + mod = _load_sglang_compressor() + kv_score = _pack_kv_score(kv, score, ratio, head_dim) + ape_sglang = _sglang_ape(ape.float(), ratio, head_dim) + slots = 8 if ratio == 4 else ratio + if state_pool is None: + state_pool = torch.zeros((1, slots, kv_score.shape[1]), device=kv.device, dtype=kv_score.dtype) + else: + state_pool.zero_() + seq_len = kv.shape[0] + plan = mod.CompressorPrefillPlan.generate( + ratio, + seq_len, + torch.tensor([seq_len], dtype=torch.int64), + torch.tensor([seq_len], dtype=torch.int64), + kv.device, + ) + indices = torch.zeros((1,), device=kv.device, dtype=torch.int32) + out = mod.compress_forward( + state_pool, + kv_score, + ape_sglang, + indices, + plan, + head_dim=head_dim, + compress_ratio=ratio, + ) + ncomp = seq_len // ratio + if ncomp: + mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) + ragged_ids = plan.compress_plan.view(torch.int32)[:ncomp, 0].long() + out = out.index_select(0, ragged_ids).to(dtype) + else: + out = kv.new_zeros(0, head_dim).to(dtype) + return out, state_pool + + +def _sglang_decode_step_from_state_pool( + x_new, + wkv_w, + wgate_w, + norm_w, + ape, + ratio, + head_dim, + cos_table, + sin_table, + eps, + start_pos, + state_pool, +): + if state_pool is None or not x_new.is_cuda or head_dim % 128 != 0 or ratio not in (4, 128): + return None, False + mod = _load_sglang_compressor() + xf = x_new.float().view(1, -1) + kv = F.linear(xf, wkv_w.float()) + score = F.linear(xf, wgate_w.float()) + kv_score = _pack_kv_score(kv, score, ratio, head_dim) + ape_sglang = _sglang_ape(ape.float(), ratio, head_dim) + seq_len = 
start_pos + 1 + plan = mod.CompressorDecodePlan( + ratio, + torch.tensor([seq_len], device=x_new.device, dtype=torch.int32), + ) + indices = torch.zeros((1,), device=x_new.device, dtype=torch.int32) + out = mod.compress_forward( + state_pool, + kv_score, + ape_sglang, + indices, + plan, + head_dim=head_dim, + compress_ratio=ratio, + ) + if seq_len % ratio != 0: + return None, True + mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) + return out[0].to(x_new.dtype), True + + def compress_prefill(x, wkv_w, wgate_w, norm_w, ape, ratio, head_dim, rope_dim, cos_table, sin_table, eps): """x:[s,dim] (one request, start_pos=0) -> compressed kv [nwin, head_dim] (rope applied to last rope_dim). @@ -69,7 +307,21 @@ def _finish_entry(kv, norm_w, ape_unused, rope_dim, cos_table, sin_table, positi return torch.cat([kv[:-rope_dim], kv_rope], dim=0) -def compressor_prefill_state(x, wkv_w, wgate_w, norm_w, ape, ratio, head_dim, rope_dim, cos_table, sin_table, eps): +def compressor_prefill_state( + x, + wkv_w, + wgate_w, + norm_w, + ape, + ratio, + head_dim, + rope_dim, + cos_table, + sin_table, + eps, + return_state_pool=False, + state_pool=None, +): """Faithful reference start_pos==0 path (incl. remainder). Returns (entries[ncomp,d], kv_state, score_state). entries have rope applied; kv_state/score_state carry the partial window for the decode path. 
@@ -83,21 +335,40 @@ def compressor_prefill_state(x, wkv_w, wgate_w, norm_w, ape, ratio, head_dim, ro kv = F.linear(xf, wkv_w.float()) # [s, coff*d] score = F.linear(xf, wgate_w.float()) # [s, coff*d] ape = ape.float() - kv_state, score_state = new_compressor_state(ratio, head_dim, x.device) + kv_state, score_state = _build_state_from_kv_score(kv, score, ape, ratio, head_dim) + sglang_state_pool = state_pool + try: + comp, sglang_state_pool = _sglang_prefill_from_kv_score( + kv, + score, + norm_w, + ape, + ratio, + head_dim, + cos_table, + sin_table, + eps, + dtype, + state_pool=sglang_state_pool, + ) + if comp is not None: + if return_state_pool: + return comp, kv_state, score_state, sglang_state_pool + return comp, kv_state, score_state + except Exception as exc: + _warn_sglang_fallback(exc) + should_compress = s >= ratio remainder = s % ratio cutoff = s - remainder - offset = ratio if overlap else 0 - if overlap and cutoff >= ratio: - kv_state[:ratio] = kv[cutoff - ratio : cutoff] - score_state[:ratio] = score[cutoff - ratio : cutoff] + ape if remainder > 0: - kv_state[offset : offset + remainder] = kv[cutoff:] - score_state[offset : offset + remainder] = score[cutoff:] + ape[:remainder] kv = kv[:cutoff] score = score[:cutoff] if not should_compress: - return x.new_zeros(0, head_dim), kv_state, score_state + comp = x.new_zeros(0, head_dim) + if return_state_pool: + return comp, kv_state, score_state, sglang_state_pool + return comp, kv_state, score_state nwin = cutoff // ratio kvw = kv.view(nwin, ratio, coff * d) scw = score.view(nwin, ratio, coff * d) + ape @@ -109,6 +380,8 @@ def compressor_prefill_state(x, wkv_w, wgate_w, norm_w, ape, ratio, head_dim, ro pos = torch.arange(nwin, device=x.device) * ratio comp_rope = apply_rotary_emb(comp[:, -rope_dim:], cos_table[pos], sin_table[pos]) comp = torch.cat([comp[:, :-rope_dim], comp_rope], dim=1) + if return_state_pool: + return comp, kv_state, score_state, sglang_state_pool return comp, kv_state, score_state @@ 
-127,12 +400,34 @@ def compressor_decode_step( kv_state, score_state, start_pos, + state_pool=None, ): """Faithful reference start_pos>0 path for one new token. Mutates kv_state/score_state in place. - Returns the new compressed entry [d] (rope applied) when a window completes, else None.""" + Returns the new compressed entry [d] (rope applied) when a window completes, else None. + """ overlap = ratio == 4 d = head_dim dtype = x_new.dtype + try: + entry, handled = _sglang_decode_step_from_state_pool( + x_new, + wkv_w, + wgate_w, + norm_w, + ape, + ratio, + head_dim, + cos_table, + sin_table, + eps, + start_pos, + state_pool, + ) + if handled: + return entry + except Exception as exc: + _warn_sglang_fallback(exc) + xf = x_new.float().view(-1) # [dim] kv = F.linear(xf, wkv_w.float()) # [coff*d] score = F.linear(xf, wgate_w.float()) + ape.float()[start_pos % ratio] # [coff*d] @@ -153,4 +448,14 @@ def compressor_decode_step( entry = (kv_state * torch.softmax(score_state, dim=0)).sum(dim=0) # [d] if not should_compress: return None - return _finish_entry(entry, norm_w, ape, rope_dim, cos_table, sin_table, start_pos + 1 - ratio, eps, dtype) + return _finish_entry( + entry, + norm_w, + ape, + rope_dim, + cos_table, + sin_table, + start_pos + 1 - ratio, + eps, + dtype, + ) diff --git a/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py index 75f540725b..78cdb3a3f8 100644 --- a/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py +++ b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py @@ -1,58 +1,50 @@ import torch -import torch.nn.functional as F - -# Manifold-constrained Hyper-Connections (mHC). Replaces the plain residual add: the hidden state is -# carried as ``hc_mult`` parallel streams. Each sub-layer (attn / ffn) collapses the streams to one -# vector (hc_pre), runs the sub-layer, then re-expands into the streams via learned post/comb weights -# (hc_post). 
A doubly-stochastic (Sinkhorn-normalized) ``comb`` matrix mixes the residual streams. -# Pure-torch transcription of the bundled reference inference/model.py (Block.hc_pre/hc_post, -# ParallelHead.hc_head) + inference/kernel.py (hc_split_sinkhorn). All math in fp32, as in the reference. - - -def hc_split_sinkhorn(mixes, hc_scale, hc_base, hc_mult, sinkhorn_iters, eps): - """mixes:[N, (2+hc)*hc] fp32 -> pre[N,hc], post[N,hc], comb[N,hc,hc] (doubly stochastic).""" - hc = hc_mult - pre = torch.sigmoid(mixes[:, :hc] * hc_scale[0] + hc_base[:hc]) + eps - post = 2.0 * torch.sigmoid(mixes[:, hc : 2 * hc] * hc_scale[1] + hc_base[hc : 2 * hc]) - comb = mixes[:, 2 * hc :].view(-1, hc, hc) * hc_scale[2] + hc_base[2 * hc :].view(hc, hc) - # comb = softmax(comb, dim=-1) + eps - comb = torch.softmax(comb, dim=-1) + eps - # one column normalization, then (iters-1) of (row, column) - comb = comb / (comb.sum(dim=-2, keepdim=True) + eps) - for _ in range(sinkhorn_iters - 1): - comb = comb / (comb.sum(dim=-1, keepdim=True) + eps) - comb = comb / (comb.sum(dim=-2, keepdim=True) + eps) - return pre, post, comb + + +def _ensure_vllm_mhc_ops(): + try: + import vllm.model_executor.layers.mhc # noqa: F401 + except Exception as e: + raise RuntimeError("DeepSeek-V4 requires vLLM mHC custom ops; failed to import vllm MHC kernels") from e def hc_pre(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, eps, sinkhorn_iters): - """streams:[N, hc*dim] -> (collapsed[N,dim], post[N,hc], comb[N,hc,hc]).""" - dtype = streams.dtype - x = streams.float() # [N, hc*dim] - rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + eps) - mixes = F.linear(x, hc_fn) * rsqrt # [N, (2+hc)*hc] - pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, hc_mult, sinkhorn_iters, eps) - streams3 = x.view(-1, hc_mult, dim) - collapsed = torch.sum(pre.unsqueeze(-1) * streams3, dim=1) # [N, dim] - return collapsed.to(dtype), post, comb + """streams:[N, hc*dim] -> (collapsed[N,dim], post[N,hc,1], comb[N,hc,hc]).""" + 
_ensure_vllm_mhc_ops() + post, comb, collapsed = torch.ops.vllm.mhc_pre( + residual=streams.view(-1, hc_mult, dim).contiguous(), + fn=hc_fn, + hc_scale=hc_scale, + hc_base=hc_base, + rms_eps=eps, + hc_pre_eps=eps, + hc_sinkhorn_eps=eps, + hc_post_mult_value=2.0, + sinkhorn_repeat=sinkhorn_iters, + ) + return collapsed, post, comb def hc_post(x, residual, post, comb, hc_mult, dim): """x:[N,dim] sub-layer output, residual:[N, hc*dim] -> [N, hc*dim].""" - res = residual.float().view(-1, hc_mult, dim) # [N, hc, dim] - xf = x.float() - # post: [N,hc] -> [N,hc,dim]; comb mixes residual streams: out[i] = post[i]*x + sum_j comb[i,j]*res[j] - y = post.unsqueeze(-1) * xf.unsqueeze(-2) + torch.einsum("nij,njd->nid", comb, res) - return y.reshape(-1, hc_mult * dim).to(x.dtype) + _ensure_vllm_mhc_ops() + out = torch.ops.vllm.mhc_post(x, residual.view(-1, hc_mult, dim).contiguous(), post, comb) + return out.reshape(-1, hc_mult * dim) def hc_head(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, eps): - """Final stream collapse before the lm_head. streams:[N, hc*dim] -> [N, dim] (sigmoid gate, no sinkhorn).""" - dtype = streams.dtype - x = streams.float() - rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + eps) - mixes = F.linear(x, hc_fn) * rsqrt # [N, hc] - pre = torch.sigmoid(mixes * hc_scale + hc_base) + eps # [N, hc] - streams3 = x.view(-1, hc_mult, dim) - collapsed = torch.sum(pre.unsqueeze(-1) * streams3, dim=1) - return collapsed.to(dtype) + """Final stream collapse before the lm_head. 
streams:[N, hc*dim] -> [N, dim].""" + _ensure_vllm_mhc_ops() + out = torch.empty(streams.shape[0], dim, device=streams.device, dtype=streams.dtype) + torch.ops.vllm.hc_head_fused_kernel( + streams.view(-1, hc_mult, dim).contiguous(), + hc_fn, + hc_scale, + hc_base, + out, + dim, + eps, + eps, + hc_mult, + ) + return out diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index a8dd0bb1e8..98dee7fd8a 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -1,3 +1,5 @@ +import os + import torch import torch.nn.functional as F import torch.distributed as dist @@ -7,21 +9,23 @@ from .hyper_connection import hc_pre, hc_post from ..triton_kernel.rotary_emb import apply_rotary_emb from .compressor import compressor_prefill_state, compressor_decode_step -from .attention import torch_sparse_attn -from ..triton_kernel.quant_convert import dequant_fp4_group_to_bf16 +from .attention import vllm_sparse_attn + + +DSV4_DEBUG_DIRECT_PREFILL_COMP = os.getenv("DSV4_DEBUG_DIRECT_PREFILL_COMP", "0") == "1" +DSV4_DEBUG_DISABLE_COMP_ATTN = os.getenv("DSV4_DEBUG_DISABLE_COMP_ATTN", "0") == "1" class DeepseekV4TransformerLayerInfer(TransformerLayerInferTpl): - """One V4 decoder layer: HC(attn) then HC(ffn). Correctness-first pure-torch. + """One V4 decoder layer: HC(attn) then HC(ffn). The residual is carried as ``hc_mult`` streams flattened to [T, hc_mult*hidden]; each sub-layer collapses (hc_pre), computes, and re-expands (hc_post). Attention is MLA over a sliding window + - compressed KV with a per-head sink (torch_sparse_attn); the MoE reuses lightllm's deepgemm FP8 + compressed KV with a per-head sink (vLLM FlashMLA sparse); the MoE reuses lightllm's deepgemm FP8 grouped GEMM driven by V4's custom router (sqrtsoftplus + hash/topk + bias-for-selection). 
Per-request decode state (window KV history + compressed KV + compressor running state) is kept in - a dict keyed by request id. NOTE: correctness-first — this should move into the KV mem manager for - production memory management / request eviction. + DeepseekV4ReqManager so request alloc/free owns its lifetime. """ def __init__(self, layer_num, network_config): @@ -32,6 +36,9 @@ def __init__(self, layer_num, network_config): self.n_heads = cfg["num_attention_heads"] self.head_dim = cfg["head_dim"] self.rope_dim = cfg["qk_rope_head_dim"] + self.index_n_heads = cfg["index_n_heads"] + self.index_head_dim = cfg["index_head_dim"] + self.index_topk = cfg["index_topk"] self.o_groups = cfg["o_groups"] self.o_lora = cfg["o_lora_rank"] self.hc_mult = cfg["hc_mult"] @@ -43,12 +50,14 @@ def __init__(self, layer_num, network_config): self.topk = cfg["num_experts_per_tok"] self.route_scale = cfg["routed_scaling_factor"] self.swiglu_limit = cfg["swiglu_limit"] - self.softmax_scale = self.head_dim**-0.5 + self.softmax_scale = self.head_dim ** -0.5 self.tp_q_heads = self.n_heads // self.tp_world_size_ + self.tp_index_heads = self.index_n_heads // self.tp_world_size_ self.tp_groups = self.o_groups // self.tp_world_size_ self.embed_dim_ = self.hc_mult * self.hidden self.enable_ep_moe = get_env_start_args().enable_ep_moe - self._state = {} # req_id -> dict(kv_hist, comp_kv, cstate_kv, cstate_score) + self.indexer_score_scale = self.index_head_dim ** -0.5 + self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5 # ------------------------------------------------------------------ forward (HC-wrapped) def _hc_block(self, streams, infer_state, lw, attn_fn): @@ -100,8 +109,14 @@ def _qkv(self, x, cos_tok, sin_tok, lw): dim=-1, ) kv = lw.kv_norm_(lw.wkv_.mm(x), eps=self.eps_) - kv = torch.cat([kv[:, : -self.rope_dim], apply_rotary_emb(kv[:, -self.rope_dim :], cos_tok, sin_tok)], dim=1) - return q, kv + kv = torch.cat( + [ + kv[:, : -self.rope_dim], + 
apply_rotary_emb(kv[:, -self.rope_dim :], cos_tok, sin_tok), + ], + dim=1, + ) + return q, kv, qa def _out_proj(self, o, infer_state, lw): # o: [T, tp_q_heads, head_dim] -> inverse rope -> grouped low-rank O -> [T, hidden] @@ -117,35 +132,130 @@ def _inv_rope(self, o, cos_tok, sin_tok): return torch.cat( [ o[..., : -self.rope_dim], - apply_rotary_emb(o[..., -self.rope_dim :], cos_tok.unsqueeze(1), sin_tok.unsqueeze(1), inverse=True), + apply_rotary_emb( + o[..., -self.rope_dim :], + cos_tok.unsqueeze(1), + sin_tok.unsqueeze(1), + inverse=True, + ), ], dim=-1, ) + def _post_dense_kv(self, infer_state, req, start_pos, mem_index, kv): + positions = torch.arange( + start_pos, + start_pos + kv.shape[0], + device=mem_index.device, + dtype=torch.long, + ) + infer_state.mem_manager.pack_mla_kv_to_cache( + layer_index=self.layer_num_, + mem_index=mem_index, + kv=kv.reshape(kv.shape[0], 1, kv.shape[-1]), + req_idx=req, + positions=positions, + ) + return + + def _write_compressed_kv(self, infer_state, req, entry_start, comp): + slots = infer_state.req_manager.ensure_compress_slots(self.layer_num_, req, entry_start, comp.shape[0]) + if comp.shape[0] == 0: + return slots + infer_state.mem_manager.pack_compressed_kv_to_cache(self.layer_num_, slots, comp) + return slots + + def _write_c4_indexer_k(self, infer_state, slots, idx_comp): + if idx_comp is None or idx_comp.shape[0] == 0: + return + infer_state.mem_manager.pack_c4_indexer_k_to_cache(self.layer_num_, slots, idx_comp) + return + + def _dense_kv_from_cache(self, infer_state, req, start_pos, end_pos): + if end_pos <= start_pos: + return torch.empty((0, self.head_dim), dtype=infer_state.mem_manager.dtype, device="cuda") + slots = infer_state.req_manager.req_to_token_indexs[req, start_pos:end_pos].long() + return infer_state.mem_manager.gather_mla_kv(self.layer_num_, slots) + + def _compressed_kv_from_cache(self, infer_state, req, ncomp): + if ncomp == 0: + return torch.empty((0, self.head_dim), 
dtype=infer_state.mem_manager.dtype, device="cuda") + if self.compress_ratio == 4: + slots = infer_state.req_manager.req_to_c4_indexs[req, :ncomp].long() + else: + slots = infer_state.req_manager.req_to_c128_indexs[req, :ncomp].long() + return infer_state.mem_manager.gather_compressed_kv(self.layer_num_, slots) + + def _c4_indexer_k_from_cache(self, infer_state, req, ncomp): + if self.compress_ratio != 4 or ncomp == 0: + return None + slots = infer_state.req_manager.req_to_c4_indexs[req, :ncomp].long() + return infer_state.mem_manager.gather_c4_indexer_k(self.layer_num_, slots) + # ------------------------------------------------------------------ attention (prefill) def _attention_prefill(self, x, infer_state, lw): T = x.shape[0] if self.compress_ratio: - cos_tok, sin_tok = infer_state.position_cos_compress, infer_state.position_sin_compress + cos_tok, sin_tok = ( + infer_state.position_cos_compress, + infer_state.position_sin_compress, + ) else: - cos_tok, sin_tok = infer_state.position_cos_sliding, infer_state.position_sin_sliding - q, kv = self._qkv(x, cos_tok, sin_tok, lw) + cos_tok, sin_tok = ( + infer_state.position_cos_sliding, + infer_state.position_sin_sliding, + ) + q, kv, qa = self._qkv(x, cos_tok, sin_tok, lw) sink = lw.attn_sink_.weight o = x.new_empty(T, self.tp_q_heads, self.head_dim) b_req = infer_state.b_req_idx.tolist() starts = infer_state.b_q_start_loc.tolist() lens = infer_state.b_q_seq_len.tolist() - for req, st, ln in zip(b_req, starts, lens): + ready_lens = infer_state.b_ready_cache_len.tolist() + idx_q, idx_weight = self._indexer_q_weight( + x, + qa, + infer_state.position_cos_compress, + infer_state.position_sin_compress, + lw, + ) + for req, st, ln, ready_len in zip(b_req, starts, lens, ready_lens): q_r, kv_r, x_r = q[st : st + ln], kv[st : st + ln], x[st : st + ln] - kv_all, n_window, ncomp = self._gather_prefill(x_r, kv_r, req, lw, infer_state) - ti = self._topk_idxs_prefill(ln, n_window, ncomp, x.device) - o[st : st + ln] = 
torch_sparse_attn(q_r.unsqueeze(0), kv_all.unsqueeze(0), sink, ti, self.softmax_scale)[0] + idx_q_r = None if idx_q is None else idx_q[st : st + ln] + idx_weight_r = None if idx_weight is None else idx_weight[st : st + ln] + kv_all, dense_base, n_window, ncomp, idx_comp = self._gather_prefill( + x_r, kv_r, req, ready_len, lw, infer_state + ) + ti = self._topk_idxs_prefill( + ln, + dense_base, + n_window, + ncomp, + x.device, + ready_len, + idx_q_r, + idx_comp, + idx_weight_r, + infer_state, + ) + o[st : st + ln] = vllm_sparse_attn(q_r.unsqueeze(0), kv_all.unsqueeze(0), sink, ti, self.softmax_scale)[0] + self._post_dense_kv( + infer_state, + req, + ready_len, + infer_state.mem_index[st : st + ln], + kv_r, + ) return self._out_proj(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw) - def _gather_prefill(self, x_r, kv_r, req, lw, infer_state): + def _gather_prefill(self, x_r, kv_r, req, ready_len, lw, infer_state): ln = kv_r.shape[0] + idx_comp = None + if ready_len > 0: + return self._gather_prefill_extend(x_r, kv_r, req, ready_len, lw, infer_state) if self.compress_ratio: - comp, ks, ss = compressor_prefill_state( + cstate_pool = infer_state.req_manager.get_compress_state_pool_for_req(self.layer_num_, req) + comp, ks, ss, cstate_pool = compressor_prefill_state( x_r, lw.compressor_wkv_.mm_param.weight, lw.compressor_wgate_.mm_param.weight, @@ -157,27 +267,180 @@ def _gather_prefill(self, x_r, kv_r, req, lw, infer_state): infer_state.cos_compress_table, infer_state.sin_compress_table, self.eps_, + return_state_pool=True, + state_pool=cstate_pool, ) - self._state[req] = {"kv_hist": kv_r.detach(), "comp_kv": comp.detach(), "cstate_kv": ks, "cstate_score": ss} - return torch.cat([kv_r, comp], dim=0), ln, comp.shape[0] - self._state[req] = {"kv_hist": kv_r.detach()} - return kv_r, ln, 0 + comp_slots = self._write_compressed_kv(infer_state, req, 0, comp) + ( + cstate_kv, + cstate_score, + ) = infer_state.req_manager.get_compress_state_for_req(self.layer_num_, req) + 
cstate_kv.copy_(ks) + cstate_score.copy_(ss) + state = { + "cstate_kv": cstate_kv, + "cstate_score": cstate_score, + } + if self.compress_ratio == 4: + idx_cstate_pool = infer_state.req_manager.get_c4_indexer_state_pool_for_req(self.layer_num_, req) + idx_comp, idx_ks, idx_ss, idx_cstate_pool = compressor_prefill_state( + x_r, + lw.idx_cmp_wkv_.mm_param.weight, + lw.idx_cmp_wgate_.mm_param.weight, + lw.idx_cmp_norm_.weight, + lw.idx_cmp_ape_.weight, + 4, + self.index_head_dim, + self.rope_dim, + infer_state.cos_compress_table, + infer_state.sin_compress_table, + self.eps_, + return_state_pool=True, + state_pool=idx_cstate_pool, + ) + self._write_c4_indexer_k(infer_state, comp_slots, idx_comp) + idx_state = infer_state.req_manager.get_c4_indexer_compress_state(self.layer_num_) + idx_cstate_kv = idx_state[req, 0] + idx_cstate_score = idx_state[req, 1] + idx_cstate_kv.copy_(idx_ks) + idx_cstate_score.copy_(idx_ss) + state.update( + { + "idx_cstate_kv": idx_cstate_kv, + "idx_cstate_score": idx_cstate_score, + } + ) + infer_state.req_manager.set_runtime_state( + req, + self.layer_num_, + state, + ) + ncomp = comp.shape[0] + if DSV4_DEBUG_DISABLE_COMP_ATTN: + return kv_r, 0, ln, 0, None + if not DSV4_DEBUG_DIRECT_PREFILL_COMP: + comp = self._compressed_kv_from_cache(infer_state, req, ncomp) + idx_comp = self._c4_indexer_k_from_cache(infer_state, req, ncomp) + return torch.cat([kv_r, comp], dim=0), 0, ln, ncomp, idx_comp + return kv_r, 0, ln, 0, None - def _topk_idxs_prefill(self, seqlen, n_window, ncomp, device): - t = torch.arange(seqlen, device=device) - j = torch.arange(n_window, device=device) - win = torch.where( - (j.unsqueeze(0) <= t.unsqueeze(1)) & (j.unsqueeze(0) > (t.unsqueeze(1) - self.window)), - j.unsqueeze(0).expand(seqlen, n_window), - torch.full((seqlen, n_window), -1, device=device, dtype=torch.long), + def _gather_prefill_extend(self, x_r, kv_r, req, ready_len, lw, infer_state): + if self.compress_ratio: + try: + state = 
infer_state.req_manager.get_runtime_state(req, self.layer_num_) + except KeyError as exc: + raise RuntimeError( + "DeepSeek-V4 prefill chunk is missing runtime state; radix prompt cache " + "must stay disabled until V4 managed token cache is implemented." + ) from exc + cstate_pool = infer_state.req_manager.get_compress_state_pool_for_req(self.layer_num_, req) + idx_cstate_pool = ( + infer_state.req_manager.get_c4_indexer_state_pool_for_req(self.layer_num_, req) + if self.compress_ratio == 4 + else None + ) + + for j in range(x_r.shape[0]): + start_pos = ready_len + j + entry = compressor_decode_step( + x_r[j], + lw.compressor_wkv_.mm_param.weight, + lw.compressor_wgate_.mm_param.weight, + lw.compressor_norm_.weight, + lw.compressor_ape_.weight, + self.compress_ratio, + self.head_dim, + self.rope_dim, + infer_state.cos_compress_table, + infer_state.sin_compress_table, + self.eps_, + state["cstate_kv"], + state["cstate_score"], + start_pos, + state_pool=cstate_pool, + ) + if entry is not None: + entry_start = (start_pos + 1) // self.compress_ratio - 1 + slots = self._write_compressed_kv(infer_state, req, entry_start, entry.unsqueeze(0)) + if self.compress_ratio == 4: + idx_entry = compressor_decode_step( + x_r[j], + lw.idx_cmp_wkv_.mm_param.weight, + lw.idx_cmp_wgate_.mm_param.weight, + lw.idx_cmp_norm_.weight, + lw.idx_cmp_ape_.weight, + 4, + self.index_head_dim, + self.rope_dim, + infer_state.cos_compress_table, + infer_state.sin_compress_table, + self.eps_, + state["idx_cstate_kv"], + state["idx_cstate_score"], + start_pos, + state_pool=idx_cstate_pool, + ) + if idx_entry is not None: + if entry is None: + entry_start = (start_pos + 1) // self.compress_ratio - 1 + slots = infer_state.req_manager.ensure_compress_slots(self.layer_num_, req, entry_start, 1) + self._write_c4_indexer_k(infer_state, slots, idx_entry.unsqueeze(0)) + dense_end = ready_len + x_r.shape[0] + ncomp = dense_end // self.compress_ratio + dense_base = max(0, ready_len - self.window + 1) + 
cached_dense = self._dense_kv_from_cache(infer_state, req, dense_base, ready_len) + dense = torch.cat([cached_dense, kv_r], dim=0) + comp = self._compressed_kv_from_cache(infer_state, req, ncomp) + idx_comp = self._c4_indexer_k_from_cache(infer_state, req, ncomp) + if DSV4_DEBUG_DISABLE_COMP_ATTN: + return dense, dense_base, dense.shape[0], 0, None + return ( + torch.cat([dense, comp], dim=0), + dense_base, + dense.shape[0], + ncomp, + idx_comp, + ) + dense_base = max(0, ready_len - self.window + 1) + cached_dense = self._dense_kv_from_cache(infer_state, req, dense_base, ready_len) + dense = torch.cat([cached_dense, kv_r], dim=0) + return ( + dense, + dense_base, + dense.shape[0], + 0, + None, ) + + def _topk_idxs_prefill( + self, + seqlen, + dense_base, + n_window, + ncomp, + device, + base_pos, + idx_q, + idx_comp, + idx_weight, + infer_state, + ): + t = torch.arange(seqlen, device=device) + abs_pos = t + base_pos + offsets = torch.arange(self.window, device=device) + win_abs = abs_pos.unsqueeze(1) - (self.window - 1 - offsets).unsqueeze(0) + valid = (win_abs >= dense_base) & (win_abs < dense_base + n_window) + win = torch.where(valid, win_abs - dense_base, torch.full_like(win_abs, -1)) if ncomp: - c = torch.arange(ncomp, device=device) - comp = torch.where( - c.unsqueeze(0) < ((t.unsqueeze(1) + 1) // self.compress_ratio), - (c.unsqueeze(0) + n_window).expand(seqlen, ncomp), - torch.full((seqlen, ncomp), -1, device=device, dtype=torch.long), - ) + if self.compress_ratio == 4 and ncomp > self.index_topk: + comp = self._indexer_topk(idx_q, idx_comp, idx_weight, abs_pos + 1, n_window, infer_state) + else: + c = torch.arange(ncomp, device=device) + comp = torch.where( + c.unsqueeze(0) < ((abs_pos.unsqueeze(1) + 1) // self.compress_ratio), + (c.unsqueeze(0) + n_window).expand(seqlen, ncomp), + torch.full((seqlen, ncomp), -1, device=device, dtype=torch.long), + ) return torch.cat([win, comp], dim=1).int().unsqueeze(0) return win.int().unsqueeze(0) @@ -185,19 +448,45 @@ 
def _topk_idxs_prefill(self, seqlen, n_window, ncomp, device): def _attention_decode(self, x, infer_state, lw): B = x.shape[0] # one new token per request if self.compress_ratio: - cos_tok, sin_tok = infer_state.position_cos_compress, infer_state.position_sin_compress + cos_tok, sin_tok = ( + infer_state.position_cos_compress, + infer_state.position_sin_compress, + ) else: - cos_tok, sin_tok = infer_state.position_cos_sliding, infer_state.position_sin_sliding - q, kv = self._qkv(x, cos_tok, sin_tok, lw) # [B, heads, hd], [B, hd] + cos_tok, sin_tok = ( + infer_state.position_cos_sliding, + infer_state.position_sin_sliding, + ) + q, kv, qa = self._qkv(x, cos_tok, sin_tok, lw) # [B, heads, hd], [B, hd] + idx_q, idx_weight = self._indexer_q_weight( + x, + qa, + infer_state.position_cos_compress, + infer_state.position_sin_compress, + lw, + ) sink = lw.attn_sink_.weight b_req = infer_state.b_req_idx.tolist() seqlens = infer_state.b_seq_len.tolist() o = x.new_empty(B, self.tp_q_heads, self.head_dim) for i, (req, seq) in enumerate(zip(b_req, seqlens)): - stt = self._state[req] - stt["kv_hist"] = torch.cat([stt["kv_hist"], kv[i : i + 1]], dim=0) start_pos = seq - 1 + self._post_dense_kv( + infer_state, + req, + start_pos, + infer_state.mem_index[i : i + 1], + kv[i : i + 1], + ) if self.compress_ratio: + try: + stt = infer_state.req_manager.get_runtime_state(req, self.layer_num_) + except KeyError as exc: + raise RuntimeError( + "DeepSeek-V4 decode is missing runtime state; radix prompt cache " + "must stay disabled until V4 managed token cache is implemented." 
+ ) from exc + cstate_pool = infer_state.req_manager.get_compress_state_pool_for_req(self.layer_num_, req) e = compressor_decode_step( x[i], lw.compressor_wkv_.mm_param.weight, @@ -213,58 +502,257 @@ def _attention_decode(self, x, infer_state, lw): stt["cstate_kv"], stt["cstate_score"], start_pos, + state_pool=cstate_pool, ) + entry_slots = None if e is not None: - stt["comp_kv"] = torch.cat([stt["comp_kv"], e.unsqueeze(0)], dim=0) - win_kv = stt["kv_hist"][-self.window :] - kv_all = torch.cat([win_kv, stt["comp_kv"]], dim=0) + entry_start = (start_pos + 1) // self.compress_ratio - 1 + entry_slots = self._write_compressed_kv(infer_state, req, entry_start, e.unsqueeze(0)) + if self.compress_ratio == 4: + idx_cstate_pool = infer_state.req_manager.get_c4_indexer_state_pool_for_req(self.layer_num_, req) + idx_e = compressor_decode_step( + x[i], + lw.idx_cmp_wkv_.mm_param.weight, + lw.idx_cmp_wgate_.mm_param.weight, + lw.idx_cmp_norm_.weight, + lw.idx_cmp_ape_.weight, + 4, + self.index_head_dim, + self.rope_dim, + infer_state.cos_compress_table, + infer_state.sin_compress_table, + self.eps_, + stt["idx_cstate_kv"], + stt["idx_cstate_score"], + start_pos, + state_pool=idx_cstate_pool, + ) + if idx_e is not None: + if entry_slots is None: + entry_start = (start_pos + 1) // self.compress_ratio - 1 + entry_slots = infer_state.req_manager.ensure_compress_slots( + self.layer_num_, req, entry_start, 1 + ) + self._write_c4_indexer_k(infer_state, entry_slots, idx_e.unsqueeze(0)) + win_start = max(0, seq - self.window) + win_kv = self._dense_kv_from_cache(infer_state, req, win_start, seq) + comp_kv = self._compressed_kv_from_cache(infer_state, req, seq // self.compress_ratio) + idx_comp = self._c4_indexer_k_from_cache(infer_state, req, comp_kv.shape[0]) + if DSV4_DEBUG_DISABLE_COMP_ATTN: + comp_kv = None + idx_comp = None + kv_all = win_kv + else: + kv_all = torch.cat([win_kv, comp_kv], dim=0) else: - win_kv = stt["kv_hist"][-self.window :] + win_start = max(0, seq - self.window) 
+ win_kv = self._dense_kv_from_cache(infer_state, req, win_start, seq) kv_all = win_kv - ti = torch.arange(kv_all.shape[0], device=x.device).view(1, 1, -1).int() - o[i] = torch_sparse_attn( - q[i].view(1, 1, self.tp_q_heads, self.head_dim), kv_all.unsqueeze(0), sink, ti, self.softmax_scale + comp_kv = None + idx_comp = None + ti = self._topk_idxs_decode( + win_kv.shape[0], + comp_kv, + None if idx_q is None else idx_q[i : i + 1], + idx_comp, + None if idx_weight is None else idx_weight[i : i + 1], + seq, + x.device, + infer_state, + ) + o[i] = vllm_sparse_attn( + q[i].view(1, 1, self.tp_q_heads, self.head_dim), + kv_all.unsqueeze(0), + sink, + ti, + self.softmax_scale, )[0, 0] return self._out_proj(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw) + def _indexer_q_weight(self, x, qa, cos_tok, sin_tok, lw): + if self.compress_ratio != 4: + return None, None + idx_q = lw.idx_wq_b_.mm(qa).view(x.shape[0], self.tp_index_heads, self.index_head_dim) + idx_q = torch.cat( + [ + idx_q[..., : -self.rope_dim], + apply_rotary_emb( + idx_q[..., -self.rope_dim :], + cos_tok.unsqueeze(1), + sin_tok.unsqueeze(1), + ), + ], + dim=-1, + ) + idx_weight = lw.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale + return idx_q, idx_weight + + def _indexer_topk(self, idx_q, idx_comp, idx_weight, positions_1based, offset, infer_state): + ncomp = idx_comp.shape[0] + k = min(self.index_topk, ncomp) + if k == 0: + return torch.empty((idx_q.shape[0], 0), device=idx_q.device, dtype=torch.long) + + scores = torch.einsum("thd,nd->thn", idx_q.float(), idx_comp.float()) + scores = F.relu(scores) * self.indexer_score_scale + index_scores = (scores * idx_weight.unsqueeze(-1)).sum(dim=1) + if self.tp_world_size_ > 1: + all_reduce( + index_scores, + op=dist.ReduceOp.SUM, + group=infer_state.dist_group, + async_op=False, + ) + + causal_threshold = positions_1based // 4 + top = self._indexer_topk_kernel(index_scores, causal_threshold, k) + valid = top >= 0 + return torch.where(valid, top + 
offset, torch.full_like(top, -1)) + + def _indexer_topk_kernel(self, index_scores, causal_threshold, topk): + if index_scores.is_cuda: + try: + import vllm._C # noqa: F401 + + scores = index_scores.contiguous() + lengths = causal_threshold.to(torch.int32).contiguous() + starts = torch.zeros_like(lengths, dtype=torch.int32) + top = torch.empty((scores.shape[0], topk), dtype=torch.int32, device=scores.device) + torch.ops._C.top_k_per_row_prefill( + scores, + starts, + lengths, + top, + scores.shape[0], + scores.stride(0), + scores.stride(1), + topk, + ) + return top.long() + except Exception: + pass + + entry_indices = torch.arange(index_scores.shape[1], device=index_scores.device) + index_scores = index_scores.masked_fill( + entry_indices.unsqueeze(0) >= causal_threshold.unsqueeze(1), float("-inf") + ) + top = index_scores.topk(topk, dim=-1).indices + valid = top < causal_threshold.unsqueeze(1) + return torch.where(valid, top, torch.full_like(top, -1)) + + def _topk_idxs_decode( + self, + win_len, + comp_kv, + idx_q, + idx_comp, + idx_weight, + seq_len, + device, + infer_state, + ): + win = torch.arange(win_len, device=device, dtype=torch.long) + if comp_kv is None or comp_kv.shape[0] == 0: + return win.view(1, 1, -1).int() + ncomp = comp_kv.shape[0] + if self.compress_ratio == 4 and ncomp > self.index_topk: + comp = self._indexer_topk( + idx_q, + idx_comp, + idx_weight, + torch.tensor([seq_len], device=device, dtype=torch.long), + win_len, + infer_state, + )[0] + else: + comp = torch.arange(ncomp, device=device, dtype=torch.long) + win_len + return torch.cat([win, comp], dim=0).view(1, 1, -1).int() + # ------------------------------------------------------------------ moe def _fp4_experts(self, x, weights, indices, lw): experts = lw.experts_ - out = torch.zeros(x.shape, device=x.device, dtype=torch.float32) - counts = torch.bincount(indices.reshape(-1), minlength=experts.n_routed_experts) - for expert_id in torch.nonzero(counts, as_tuple=False).flatten().tolist(): 
- token_idx, top_idx = torch.where(indices == expert_id) - if token_idx.numel() == 0: - continue - x_i = x[token_idx] - w1 = dequant_fp4_group_to_bf16(experts.w1[expert_id], experts.w1_scale[expert_id]) - w3 = dequant_fp4_group_to_bf16(experts.w3[expert_id], experts.w3_scale[expert_id]) - gate = F.linear(x_i, w1).float().clamp(max=self.swiglu_limit) - up = F.linear(x_i, w3).float().clamp(min=-self.swiglu_limit, max=self.swiglu_limit) - hidden = F.silu(gate) * up - hidden.mul_(weights[token_idx, top_idx].unsqueeze(-1)) - w2 = dequant_fp4_group_to_bf16(experts.w2[expert_id], experts.w2_scale[expert_id]) - out.index_add_(0, token_idx, F.linear(hidden.to(x.dtype), w2).float()) - return out.to(x.dtype) + if getattr(experts, "moe_backend", None) != "marlin": + err = getattr(experts, "moe_backend_error", "unknown") + raise RuntimeError(f"DeepSeek-V4 FP4 MoE requires vLLM Marlin backend, init_error={err}") + return self._fp4_experts_marlin(x, weights, indices, experts) + + def _fp4_experts_marlin(self, x, weights, indices, experts): + from vllm.model_executor.layers.fused_moe.activation import MoEActivation + from vllm.model_executor.layers.fused_moe.experts.marlin_moe import ( + fused_marlin_moe, + ) + from vllm.scalar_type import scalar_types + + return fused_marlin_moe( + hidden_states=x.contiguous(), + w1=experts.marlin_w13, + w2=experts.marlin_w2, + bias1=None, + bias2=None, + w1_scale=experts.marlin_w13_scale, + w2_scale=experts.marlin_w2_scale, + topk_weights=weights.to(torch.float32).contiguous(), + topk_ids=indices.to(torch.long).contiguous(), + quant_type_id=scalar_types.float4_e2m1f.id, + global_num_experts=experts.n_routed_experts, + activation=MoEActivation.SILU, + clamp_limit=float(self.swiglu_limit), + ) def _moe_ffn(self, x, infer_state, lw): gw = lw.gate_weight_.mm_param.weight - scores = F.softplus(F.linear(x.float(), gw.float())).sqrt() # sqrtsoftplus - if self.is_hash: - indices = lw.gate_tid2eid_.weight[infer_state.input_ids.long()] - else: - indices = 
(scores + lw.gate_bias_.weight.unsqueeze(0)).topk(self.topk, dim=-1)[1] - weights = scores.gather(1, indices) - weights = (weights / (weights.sum(-1, keepdim=True) + 1e-20) * self.route_scale).to(torch.float32) - routed = self._fp4_experts(x, weights, indices.long(), lw) + logits = F.linear(x.float(), gw.float()).contiguous() + weights, indices = self._select_experts(logits, infer_state, lw) + routed = self._fp4_experts(x, weights, indices, lw) g = lw.shared_gate_.mm(x).float().clamp(max=self.swiglu_limit) u = lw.shared_up_.mm(x).float().clamp(min=-self.swiglu_limit, max=self.swiglu_limit) shared = lw.shared_down_.mm((F.silu(g) * u).to(x.dtype)) if self.enable_ep_moe and getattr(lw.experts_, "is_ep", False): if self.tp_world_size_ > 1: - all_reduce(shared, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + all_reduce( + shared, + op=dist.ReduceOp.SUM, + group=infer_state.dist_group, + async_op=False, + ) return routed + shared out = routed + shared if self.tp_world_size_ > 1: all_reduce(out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return out + + def _select_experts(self, logits, infer_state, lw): + return self._select_experts_vllm(logits, infer_state, lw) + + def _select_experts_vllm(self, logits, infer_state, lw): + from vllm import _custom_ops as ops + + M = logits.shape[0] + bias = None + input_tokens = None + hash_indices_table = None + indices_dtype = torch.int64 + if self.is_hash: + hash_indices_table = lw.gate_tid2eid_.weight + if not hash_indices_table.is_contiguous(): + hash_indices_table = hash_indices_table.contiguous() + indices_dtype = hash_indices_table.dtype + input_tokens = infer_state.input_ids.to(dtype=indices_dtype).contiguous() + else: + bias = lw.gate_bias_.weight + + weights = torch.empty((M, self.topk), dtype=torch.float32, device=logits.device) + indices = torch.empty((M, self.topk), dtype=indices_dtype, device=logits.device) + token_expert_indices = torch.empty((M, self.topk), 
dtype=torch.int32, device=logits.device) + ops.topk_hash_softplus_sqrt( + weights, + indices, + token_expert_indices, + logits, + True, + self.route_scale, + bias, + input_tokens, + hash_indices_table, + ) + return weights, indices.long() diff --git a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py index 7c12f714db..cdaaac2cdb 100644 --- a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py @@ -1,3 +1,5 @@ +import threading + import torch from lightllm.common.basemodel import TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( @@ -10,10 +12,16 @@ ) from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl from lightllm.common.quantization.registry import QUANTMETHODS +from lightllm.utils.log_utils import init_logger from ..triton_kernel.quant_convert import dequant_fp8_block_to_bf16 +logger = init_logger(__name__) + + class DeepseekV4FP4ExpertsWeight(BaseWeightTpl): + _marlin_pack_lock = threading.Lock() + def __init__(self, weight_prefix, n_routed_experts, hidden_size, moe_intermediate_size, data_type): super().__init__(data_type=data_type) self.weight_prefix = weight_prefix @@ -23,10 +31,21 @@ def __init__(self, weight_prefix, n_routed_experts, hidden_size, moe_intermediat self.split_inter_size = moe_intermediate_size // self.tp_world_size_ self.local_expert_ids = list(range(n_routed_experts)) self.expert_idx_to_local_idx = {expert_idx: expert_idx for expert_idx in self.local_expert_ids} - self._create_weight() + self.moe_backend = None + self.moe_backend_error = None + self._marlin_checked = False + self._load_lock = threading.Lock() + self.load_ok = { + name: [False] * n_routed_experts for name in ("w1", "w1_scale", "w2", "w2_scale", "w3", "w3_scale") + } def _create_weight(self): - device = 
f"cuda:{self.device_id_}" + self._ensure_raw_fp4_weight() + + def _ensure_raw_fp4_weight(self): + if hasattr(self, "w1"): + return + device = "cpu" n = self.n_routed_experts h = self.hidden_size inter = self.split_inter_size @@ -36,10 +55,6 @@ def _create_weight(self): self.w1_scale = torch.empty((n, inter, h // 32), dtype=torch.float8_e8m0fnu, device=device) self.w3_scale = torch.empty((n, inter, h // 32), dtype=torch.float8_e8m0fnu, device=device) self.w2_scale = torch.empty((n, h, inter // 32), dtype=torch.float8_e8m0fnu, device=device) - self.load_ok = { - name: [False] * n - for name in ("w1", "w1_scale", "w2", "w2_scale", "w3", "w3_scale") - } def _copy_expert_weight(self, dst, weight, expert_idx, name, is_down=False): if is_down: @@ -66,30 +81,122 @@ def _copy_expert_scale(self, dst, scale, expert_idx, name, is_down=False): self.load_ok[name][expert_idx] = True def load_hf_weights(self, weights): + if self._marlin_checked: + return + has_weight = False for expert_idx in self.local_expert_ids: prefix = f"{self.weight_prefix}.{expert_idx}" - w1 = f"{prefix}.w1.weight" - w1_scale = f"{prefix}.w1.scale" - w2 = f"{prefix}.w2.weight" - w2_scale = f"{prefix}.w2.scale" - w3 = f"{prefix}.w3.weight" - w3_scale = f"{prefix}.w3.scale" - if w1 in weights: - self._copy_expert_weight(self.w1, weights[w1], expert_idx, "w1") - if w1_scale in weights: - self._copy_expert_scale(self.w1_scale, weights[w1_scale], expert_idx, "w1_scale") - if w3 in weights: - self._copy_expert_weight(self.w3, weights[w3], expert_idx, "w3") - if w3_scale in weights: - self._copy_expert_scale(self.w3_scale, weights[w3_scale], expert_idx, "w3_scale") - if w2 in weights: - self._copy_expert_weight(self.w2, weights[w2], expert_idx, "w2", is_down=True) - if w2_scale in weights: - self._copy_expert_scale(self.w2_scale, weights[w2_scale], expert_idx, "w2_scale", is_down=True) + if ( + f"{prefix}.w1.weight" in weights + or f"{prefix}.w1.scale" in weights + or f"{prefix}.w2.weight" in weights + or 
f"{prefix}.w2.scale" in weights + or f"{prefix}.w3.weight" in weights + or f"{prefix}.w3.scale" in weights + ): + has_weight = True + break + if not has_weight: + return + + with self._load_lock: + if self._marlin_checked: + return + self._ensure_raw_fp4_weight() + for expert_idx in self.local_expert_ids: + prefix = f"{self.weight_prefix}.{expert_idx}" + w1 = f"{prefix}.w1.weight" + w1_scale = f"{prefix}.w1.scale" + w2 = f"{prefix}.w2.weight" + w2_scale = f"{prefix}.w2.scale" + w3 = f"{prefix}.w3.weight" + w3_scale = f"{prefix}.w3.scale" + if w1 in weights: + self._copy_expert_weight(self.w1, weights[w1], expert_idx, "w1") + if w1_scale in weights: + self._copy_expert_scale(self.w1_scale, weights[w1_scale], expert_idx, "w1_scale") + if w3 in weights: + self._copy_expert_weight(self.w3, weights[w3], expert_idx, "w3") + if w3_scale in weights: + self._copy_expert_scale(self.w3_scale, weights[w3_scale], expert_idx, "w3_scale") + if w2 in weights: + self._copy_expert_weight(self.w2, weights[w2], expert_idx, "w2", is_down=True) + if w2_scale in weights: + self._copy_expert_scale(self.w2_scale, weights[w2_scale], expert_idx, "w2_scale", is_down=True) + if self._raw_load_complete(): + self._try_init_marlin() def verify_load(self): + with self._load_lock: + ok = self._raw_load_complete() + if ok and not self._marlin_checked: + self._try_init_marlin() + return ok + + def _raw_load_complete(self): return all(all(ok_list) for ok_list in self.load_ok.values()) + def _try_init_marlin(self): + try: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_mxfp4_layer_for_marlin, + ) + + class _MarlinLayer: + pass + + with self._marlin_pack_lock: + torch.cuda.set_device(self.device_id_) + device = torch.device("cuda", self.device_id_) + layer = _MarlinLayer() + layer.params_dtype = self.data_type_ + w13_cpu, w13_scale_cpu = self._build_w13_weight() + w13 = w13_cpu.to(device=device, non_blocking=True).contiguous() + w2 = 
self.w2.view(torch.uint8).to(device=device, non_blocking=True).contiguous() + w13_scale = w13_scale_cpu.to(device=device, non_blocking=True).contiguous() + w2_scale = self.w2_scale.to(device=device, non_blocking=True).contiguous() + ( + self.marlin_w13, + self.marlin_w2, + self.marlin_w13_scale, + self.marlin_w2_scale, + _, + _, + ) = prepare_moe_mxfp4_layer_for_marlin(layer, w13, w2, w13_scale, w2_scale, None, None) + del w13_cpu, w13_scale_cpu, w13, w2, w13_scale, w2_scale + self.moe_backend = "marlin" + self._marlin_checked = True + self._release_raw_fp4_weight() + torch.cuda.empty_cache() + logger.info( + "DeepSeek-V4 FP4 experts use vLLM Marlin backend, prefix=%s, rank=%s", + self.weight_prefix, + self.tp_rank_, + ) + except Exception as e: + self.moe_backend_error = repr(e) + raise RuntimeError( + "DeepSeek-V4 FP4 experts require vLLM Marlin backend, " + f"prefix={self.weight_prefix}, rank={self.tp_rank_}, error={self.moe_backend_error}" + ) from e + + def _build_w13_weight(self): + n = self.n_routed_experts + h = self.hidden_size + inter = self.split_inter_size + w13 = torch.empty((n, 2 * inter, h // 2), dtype=torch.uint8, device=self.w1.device) + w13[:, :inter, :].copy_(self.w1.view(torch.uint8)) + w13[:, inter:, :].copy_(self.w3.view(torch.uint8)) + w13_scale = torch.empty((n, 2 * inter, h // 32), dtype=self.w1_scale.dtype, device=self.w1_scale.device) + w13_scale[:, :inter, :].copy_(self.w1_scale) + w13_scale[:, inter:, :].copy_(self.w3_scale) + return w13.contiguous(), w13_scale.contiguous() + + def _release_raw_fp4_weight(self): + for name in ("w1", "w1_scale", "w2", "w2_scale", "w3", "w3_scale"): + if hasattr(self, name): + delattr(self, name) + class DeepseekV4TransformerLayerWeight(TransformerLayerWeight): """Per-layer weights for DeepSeek-V4-Flash. 
diff --git a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py index 02c71f01b2..c87f2fdebd 100644 --- a/lightllm/models/deepseek_v4/model.py +++ b/lightllm/models/deepseek_v4/model.py @@ -1,16 +1,37 @@ +import importlib.util +import os + import torch from lightllm.models.registry import ModelRegistry from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.common.req_manager import ReqManager, DeepseekV4ReqManager +from lightllm.common.req_manager import DeepseekV4ReqManager from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager -from lightllm.common.basemodel.attention.triton.fp import TritonAttBackend -from lightllm.models.deepseek_v4.layer_weights.pre_and_post_layer_weight import DeepseekV4PreAndPostLayerWeight -from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import DeepseekV4TransformerLayerWeight -from lightllm.models.deepseek_v4.layer_infer.pre_layer_infer import DeepseekV4PreLayerInfer -from lightllm.models.deepseek_v4.layer_infer.post_layer_infer import DeepseekV4PostLayerInfer -from lightllm.models.deepseek_v4.layer_infer.transformer_layer_infer import DeepseekV4TransformerLayerInfer +from lightllm.common.basemodel.attention.base_att import ( + BaseAttBackend, + BasePrefillAttState, + BaseDecodeAttState, +) +from lightllm.models.deepseek_v4.layer_weights.pre_and_post_layer_weight import ( + DeepseekV4PreAndPostLayerWeight, +) +from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import ( + DeepseekV4TransformerLayerWeight, +) +from lightllm.models.deepseek_v4.layer_infer.pre_layer_infer import ( + DeepseekV4PreLayerInfer, +) +from lightllm.models.deepseek_v4.layer_infer.post_layer_infer import ( + DeepseekV4PostLayerInfer, +) +from lightllm.models.deepseek_v4.layer_infer.transformer_layer_infer import ( + DeepseekV4TransformerLayerInfer, +) from lightllm.models.deepseek_v4.infer_struct import DeepseekV4InferStateInfo -from 
lightllm.models.llama.yarn_rotary_utils import find_correction_range, linear_ramp_mask +from lightllm.models.deepseek3_2.model import DeepSeekChatTokenizerBase +from lightllm.models.llama.yarn_rotary_utils import ( + find_correction_range, + linear_ramp_mask, +) from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -18,9 +39,38 @@ logger = init_logger(__name__) +class DeepseekV4DirectSparseAttBackend(BaseAttBackend): + """Lifecycle placeholder for V4 direct attention. + + V4 attention is currently driven inside the layer by `vllm_sparse_attn()`, not by the generic + `infer_state.prefill_att_state.prefill_att()` / `decode_att()` backend selector. + """ + + def create_att_prefill_state(self, infer_state): + return DeepseekV4DirectSparsePrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return DeepseekV4DirectSparseDecodeAttState(backend=self, infer_state=infer_state) + + +class DeepseekV4DirectSparsePrefillAttState(BasePrefillAttState): + def init_state(self): + return + + def prefill_att(self, *args, **kwargs): + raise RuntimeError("DeepSeek-V4 attention is executed directly by vllm_sparse_attn() in layer_infer.") + + +class DeepseekV4DirectSparseDecodeAttState(BaseDecodeAttState): + def init_state(self): + return + + def decode_att(self, *args, **kwargs): + raise RuntimeError("DeepSeek-V4 attention is executed directly by vllm_sparse_attn() in layer_infer.") + + @ModelRegistry("deepseek_v4") class DeepseekV4TpPartModel(LlamaTpPartModel): - pre_and_post_weight_class = DeepseekV4PreAndPostLayerWeight transformer_weight_class = DeepseekV4TransformerLayerWeight @@ -50,35 +100,46 @@ def _init_req_manager(self): create_max_seq_len = max(create_max_seq_len, self.max_seq_length) self._dsv4_req_manager_seq_len = create_max_seq_len - self.req_manager = ReqManager(self.max_req_num, 
create_max_seq_len, None) + layer_num = self.config["n_layer"] + get_added_mtp_kv_layer_num() + self._dsv4_compress_rates = self._get_compress_rates(layer_num) + self.req_manager = DeepseekV4ReqManager( + self.max_req_num, + create_max_seq_len, + compress_rates=self._dsv4_compress_rates, + head_dim=self.config["head_dim"], + indexer_head_dim=self.config["index_head_dim"], + ) return def _get_compress_rates(self, layer_num): - rates = list(self.config.get("compress_ratios", [])) - if len(rates) < layer_num: - rates.extend([0] * (layer_num - len(rates))) + rates = list(self.config["compress_ratios"]) + assert ( + len(rates) >= layer_num + ), f"DeepSeek-V4 compress_ratios length {len(rates)} is shorter than layer_num {layer_num}" return rates[:layer_num] def _init_mem_manager(self): layer_num = self.config["n_layer"] + get_added_mtp_kv_layer_num() + compress_rates = getattr(self, "_dsv4_compress_rates", self._get_compress_rates(layer_num)) self.mem_manager = DeepseekV4MemoryManager( self.max_total_token_num, dtype=self.data_type, head_num=1, head_dim=self.config["head_dim"], layer_num=layer_num, - compress_rates=self._get_compress_rates(layer_num), + compress_rates=compress_rates, indexer_head_dim=self.config["index_head_dim"], + max_request_num=self.max_req_num, + sliding_window=self.config["sliding_window"], mem_fraction=self.mem_fraction, ) - self.req_manager = DeepseekV4ReqManager( - self.max_req_num, self._dsv4_req_manager_seq_len, self.mem_manager - ) + assert isinstance(self.req_manager, DeepseekV4ReqManager) + self.req_manager.bind_mem_manager(self.mem_manager) return def _init_att_backend(self): - self.prefill_att_backend = TritonAttBackend(model=self) - self.decode_att_backend = TritonAttBackend(model=self) + self.prefill_att_backend = DeepseekV4DirectSparseAttBackend(model=self) + self.decode_att_backend = DeepseekV4DirectSparseAttBackend(model=self) return def _init_custom(self): @@ -114,8 +175,49 @@ def build(base, factor, orig_max): f = 
torch.outer(torch.arange(max_seq, dtype=torch.float32, device="cuda"), freqs) # [max_seq, dim//2] return f.cos(), f.sin() - self._cos_cached_sliding, self._sin_cached_sliding = build(cfg["rope_theta"], rs.get("factor", 1.0), 0) + self._cos_cached_sliding, self._sin_cached_sliding = build( + cfg["rope_theta"], + rs.get("factor", 16), + rs.get("original_max_position_embeddings", 65536), + ) self._cos_cached_compress, self._sin_cached_compress = build( - cfg["compress_rope_theta"], rs.get("factor", 16), rs.get("original_max_position_embeddings", 65536) + cfg["compress_rope_theta"], + rs.get("factor", 16), + rs.get("original_max_position_embeddings", 65536), ) return + + +class DeepSeekV4Tokenizer(DeepSeekChatTokenizerBase): + """Tokenizer wrapper for DeepSeek-V4's Python prompt encoding.""" + + def __init__(self, tokenizer, model_dir): + super().__init__(tokenizer) + self.model_dir = model_dir + self._encoding_module = None + + def _get_encoding_module(self): + if self._encoding_module is not None: + return self._encoding_module + + encoding_path = os.path.join(self.model_dir, "encoding", "encoding_dsv4.py") + if not os.path.exists(encoding_path): + raise FileNotFoundError(f"DeepSeek-V4 encoding file not found: {encoding_path}") + + spec = importlib.util.spec_from_file_location("lightllm_deepseek_v4_encoding_dsv4", encoding_path) + if spec is None or spec.loader is None: + raise ImportError(f"failed to load DeepSeek-V4 encoding module from {encoding_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self._encoding_module = module + return module + + def _encode_messages(self, msgs, thinking_mode, kwargs): + encoding = self._get_encoding_module() + return encoding.encode_messages( + msgs, + thinking_mode=thinking_mode, + drop_thinking=kwargs.get("drop_thinking", True), + add_default_bos_token=kwargs.get("add_default_bos_token", True), + reasoning_effort=kwargs.get("reasoning_effort"), + ) diff --git 
a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 654ba0f3e5..5a92a339cb 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -10,7 +10,11 @@ from .metrics.manager import start_metric_manager from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name +from lightllm.utils.envs_utils import ( + set_env_start_args, + set_unique_server_name, + get_unique_server_name, +) from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive from .detokenization.manager import start_detokenization_process from .router.manager import start_router_process @@ -23,6 +27,8 @@ has_vision_module, is_linear_att_mixed_model, auto_set_max_req_total_len, + get_model_type, + get_config_json, ) from lightllm.utils.dist_check_utils import auto_configure_allreduce_flags_from_args @@ -83,7 +89,14 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "visual_only"]: + if args.run_mode not in [ + "normal", + "prefill", + "decode", + "nixl_prefill", + "nixl_decode", + "visual_only", + ]: return # 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块 @@ -108,6 +121,23 @@ def normal_or_p_d_start(args): else: args.enable_multimodal = True + model_type = get_model_type(args.model_dir) + if model_type == "deepseek_v4": + if args.run_mode != "normal": + raise NotImplementedError("DeepSeek-V4 currently supports only run_mode=normal in LightLLM.") + if args.enable_cpu_cache or args.enable_disk_cache: + raise NotImplementedError("DeepSeek-V4 CPU/disk KV cache is not supported yet.") + if args.mtp_mode is not None or args.mtp_draft_model_dir is not None or args.mtp_step != 0: + raise NotImplementedError("DeepSeek-V4 MTP/speculative decoding is not supported yet.") + if args.enable_ep_moe: + raise NotImplementedError("DeepSeek-V4 EP MoE is not 
supported yet; use TP for now.") + if "prompt_cache_kv_buffer" in get_config_json(args.model_dir): + raise NotImplementedError("DeepSeek-V4 prompt_cache_kv_buffer is not supported yet.") + if not args.disable_dynamic_prompt_cache: + logger.info("DeepSeek-V4 runtime state does not support radix prompt cache yet; disabling it.") + args.disable_dynamic_prompt_cache = True + args.use_dynamic_prompt_cache = False + if args.enable_cpu_cache: # 生成一个用于创建cpu kv cache的共享内存id。 args.cpu_kv_cache_shm_id = uuid.uuid1().int % 123456789 @@ -333,7 +363,14 @@ def normal_or_p_d_start(args): from lightllm.utils.config_utils import get_dtype args.data_type = get_dtype(args.model_dir) - assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] + assert args.data_type in [ + "fp16", + "float16", + "bf16", + "bfloat16", + "fp32", + "float32", + ] already_uesd_ports = [args.port] if args.nccl_port is not None: @@ -432,7 +469,6 @@ def normal_or_p_d_start(args): ) if not args.disable_vision: - if not args.visual_use_proxy_mode: from .visualserver.manager import start_visual_process @@ -616,7 +652,14 @@ def visual_only_start(args): from lightllm.utils.config_utils import get_dtype args.data_type = get_dtype(args.model_dir) - assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] + assert args.data_type in [ + "fp16", + "float16", + "bf16", + "bfloat16", + "fp32", + "float32", + ] logger.info(f"alloced ports: {can_use_ports}") diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 5e90c9b34a..abeb8d61e9 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -51,7 +51,9 @@ def register( vocab_size: int, ): self.args = get_env_start_args() - from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend + from lightllm.server.router.model_infer.mode_backend.base_backend import ( + 
ModeBackend, + ) self.backend: ModeBackend = backend self.req_manager = req_manager @@ -122,7 +124,21 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs - def free_a_req_mem(self, free_token_index: List, req: "InferReq"): + def free_a_req_mem( + self, + free_token_index: List, + req: "InferReq", + free_c4_index: Optional[List] = None, + free_c128_index: Optional[List] = None, + ): + if hasattr(self.req_manager, "pop_compress_indices_for_req"): + c4, c128 = self.req_manager.pop_compress_indices_for_req(req.req_idx) + if c4 is not None and free_c4_index is not None: + free_c4_index.append(c4) + if c128 is not None and free_c128_index is not None: + free_c128_index.append(c128) + self.req_manager.clear_runtime_state(req.req_idx) + if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) else: @@ -258,11 +274,13 @@ def _filter(self, finished_request_ids: List[int]): free_req_index = [] free_token_index = [] + free_c4_index = [] + free_c128_index = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() - self.free_a_req_mem(free_token_index, req) + self.free_a_req_mem(free_token_index, req, free_c4_index, free_c128_index) free_req_index.append(req.req_idx) # logger.info(f"infer release req id {req.shm_req.request_id}") @@ -270,7 +288,17 @@ def _filter(self, finished_request_ids: List[int]): self.shm_req_manager.put_back_req_obj(req.shm_req) free_token_index = custom_cat(free_token_index) - self.req_manager.free(free_req_index, free_token_index) + if hasattr(self.req_manager, "free_compress_indices"): + free_c4_index = custom_cat(free_c4_index) if free_c4_index else None + free_c128_index = custom_cat(free_c128_index) if free_c128_index else None + self.req_manager.free( + free_req_index, + free_token_index, + free_c4_index=free_c4_index, + 
free_c128_index=free_c128_index, + ) + else: + self.req_manager.free(free_req_index, free_token_index) finished_req_ids_set = set(finished_request_ids) self.infer_req_ids = [_id for _id in self.infer_req_ids if _id not in finished_req_ids_set] @@ -299,11 +327,13 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): g_infer_state_lock.acquire() free_token_index = [] + free_c4_index = [] + free_c128_index = [] for req in pause_reqs: if self.args.diverse_mode: # 发生暂停的时候,需要清除 diverse 模式下的主从关系 req.clear_master_slave_state() - self.free_a_req_mem(free_token_index, req) + self.free_a_req_mem(free_token_index, req, free_c4_index, free_c128_index) assert req.wait_pause is True req.wait_pause = False req.paused = True @@ -314,11 +344,23 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): if len(free_token_index) != 0: free_token_index = custom_cat(free_token_index) self.req_manager.free_token(free_token_index) + if hasattr(self.req_manager, "free_compress_indices"): + free_c4_index = custom_cat(free_c4_index) if free_c4_index else None + free_c128_index = custom_cat(free_c128_index) if free_c128_index else None + self.req_manager.free_compress_indices( + free_c4_index=free_c4_index, + free_c128_index=free_c128_index, + ) g_infer_state_lock.release() return self - def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int): + def recover_paused_reqs( + self, + paused_reqs: List["InferReq"], + is_master_in_dp: bool, + can_alloc_token_num: int, + ): if paused_reqs: g_infer_state_lock.acquire() @@ -375,7 +417,9 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L big_page_buffer_ids = torch.tensor(big_page_buffer_ids, dtype=torch.int32, requires_grad=False, device="cpu") big_page_buffer_ids = big_page_buffer_ids.cuda(non_blocking=True) - from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer + from 
lightllm.common.basemodel.triton_kernel.linear_att_copy import ( + copy_linear_att_state_to_kv_buffer, + ) copy_linear_att_state_to_kv_buffer( b_req_idx=b_req_idx, @@ -405,9 +449,10 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, src_buffer_idx, ...] dst_buffer_idx = req.tail_linear_att_small_page_buffer_id - dst_conv_state, dst_ssm_state = self.radix_cache.linear_att_small_page_buffers.get_state_cache( - buffer_idx=dst_buffer_idx - ) + ( + dst_conv_state, + dst_ssm_state, + ) = self.radix_cache.linear_att_small_page_buffers.get_state_cache(buffer_idx=dst_buffer_idx) # TODO 对于非连续对象调用 copy_ 效率并不高 dst_conv_state.copy_(gpu_conv_state, non_blocking=True) dst_ssm_state.copy_(gpu_ssm_state, non_blocking=True) @@ -640,7 +685,10 @@ def _linear_match_radix_cache(self): enable_prompt_cache = (not self.sampling_param.disable_prompt_cache) and g_infer_context.radix_cache is not None linear_hash_list = self.shm_req.linear_att_token_hash_list.get_all() linear_att_hash_page_size = self.args.linear_att_hash_page_size - match_tokens = min(len(linear_hash_list) * linear_att_hash_page_size, self.get_cur_total_len() - 1) + match_tokens = min( + len(linear_hash_list) * linear_att_hash_page_size, + self.get_cur_total_len() - 1, + ) match_tokens = max(0, match_tokens) match_tokens = (match_tokens // linear_att_hash_page_size) * linear_att_hash_page_size match_block_num = match_tokens // linear_att_hash_page_size @@ -706,7 +754,8 @@ def _linear_match_radix_cache(self): # 将 对应的 value_tensors 中的 kv 数据 拷贝到 tail_mems 中对应的数据去 radix_cache.mem_manager.operator.copy_mem_to_mem( - value_tensor[cur_big_page_tokens:shared_kv_len], tail_mems + value_tensor[cur_big_page_tokens:shared_kv_len], + tail_mems, ) self.shared_kv_node = share_node # 只是为了保证 copy_small_page_buffer_to_linear_att_state 正确调用 @@ -737,7 +786,8 @@ def _linear_match_radix_cache(self): assert self.tail_linear_att_small_page_buffer_id is None 
# 恢复linear att 状态 g_infer_context.req_manager.copy_big_page_buffer_to_linear_att_state( - big_page_buffer_idx=share_node.big_page_buffer_idx, req=self + big_page_buffer_idx=share_node.big_page_buffer_idx, + req=self, ) self.shm_req.shm_cur_kv_len = self.cur_kv_len diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index f84e6359ba..18d5eafcc0 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -89,6 +89,11 @@ def get_tokenizer( ) logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.") return DeepSeekV32Tokenizer(hf_tokenizer) + if model_type == "deepseek_v4": + from ..models.deepseek_v4.model import DeepSeekV4Tokenizer + + logger.info("Using DeepSeek-V4 tokenizer mode with Python-based chat template encoding.") + return DeepSeekV4Tokenizer(tokenizer, tokenizer_name) if model_cfg["architectures"][0] == "TarsierForConditionalGeneration": from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor From a1612445898361b4a48a6f4c4778ad2043b24072 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 5 Jun 2026 05:03:58 +0000 Subject: [PATCH 03/30] add prompt cache --- lightllm/common/basemodel/basemodel.py | 28 ++ .../deepseek4_mem_manager.py | 204 +++++++++++++- lightllm/common/req_manager.py | 255 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 32 ++- lightllm/server/api_start.py | 4 - .../router/dynamic_prompt/radix_cache.py | 172 +++++++++--- .../server/router/model_infer/infer_batch.py | 124 ++++++++- .../model_infer/mode_backend/base_backend.py | 38 +++ 8 files changed, 803 insertions(+), 54 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 473dcbafda..d785991808 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -527,12 +527,22 @@ def _prefill( alloc_mem_index=infer_state.mem_index, max_q_seq_len=infer_state.max_q_seq_len, ) + if 
hasattr(self.mem_manager, "prepare_prefill_swa_slots"): + self.mem_manager.prepare_prefill_swa_slots( + b_req_idx=infer_state.b_req_idx, + b_seq_len=infer_state.b_seq_len, + b_ready_cache_len=infer_state.b_ready_cache_len, + b_start_loc=model_input.b_prefill_start_loc, + mem_index=infer_state.mem_index, + ) prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() infer_state.init_some_extra_state(self) infer_state.init_att_state() model_output = self._context_forward(infer_state) + if hasattr(self.mem_manager, "commit_prefill_swa_slots"): + self.mem_manager.commit_prefill_swa_slots() model_output = self._create_unpad_prefill_model_output( padded_model_output=model_output, @@ -747,6 +757,14 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod alloc_mem_index=infer_state0.mem_index, max_q_seq_len=infer_state0.max_q_seq_len, ) + if hasattr(self.mem_manager, "prepare_prefill_swa_slots"): + self.mem_manager.prepare_prefill_swa_slots( + b_req_idx=infer_state0.b_req_idx, + b_seq_len=infer_state0.b_seq_len, + b_ready_cache_len=infer_state0.b_ready_cache_len, + b_start_loc=model_input0.b_prefill_start_loc, + mem_index=infer_state0.mem_index, + ) infer_state0.init_some_extra_state(self) infer_state0.init_att_state() @@ -760,6 +778,14 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod alloc_mem_index=infer_state1.mem_index, max_q_seq_len=infer_state1.max_q_seq_len, ) + if hasattr(self.mem_manager, "prepare_prefill_swa_slots"): + self.mem_manager.prepare_prefill_swa_slots( + b_req_idx=infer_state1.b_req_idx, + b_seq_len=infer_state1.b_seq_len, + b_ready_cache_len=infer_state1.b_ready_cache_len, + b_start_loc=model_input1.b_prefill_start_loc, + mem_index=infer_state1.mem_index, + ) infer_state1.init_some_extra_state(self) infer_state1.init_att_state() @@ -767,6 +793,8 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod 
prefill_mem_indexes_ready_event.record() model_output0, model_output1 = self._overlap_tpsp_context_forward(infer_state0, infer_state1=infer_state1) + if hasattr(self.mem_manager, "commit_prefill_swa_slots"): + self.mem_manager.commit_prefill_swa_slots() model_output0 = self._create_unpad_prefill_model_output( padded_model_output=model_output0, diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py index 739e0bd51a..6735b2deed 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py @@ -28,7 +28,7 @@ DSV4_SWA_PAGE_SIZE = 128 DSV4_C4_PAGE_SIZE = 64 DSV4_C128_PAGE_SIZE = 2 -DSV4_PROFILE_MAX_FULL_TOKENS = 2_000_000 +DSV4_PROFILE_MAX_FULL_TOKENS = 1_500_000 def _ceil_div(a: int, b: int) -> int: @@ -294,6 +294,7 @@ def __init__( self.cache_dtype = torch.uint8 self.max_request_num = max_request_num self.sliding_window = sliding_window + self._pending_prefill_swa: Dict[int, Dict[str, torch.Tensor]] = {} # 全局层号 -> 各压缩池内的压实层号(同 qwen3next 的层号压实手法) self.layer_to_c4_idx: Dict[int, int] = {} @@ -560,6 +561,127 @@ def ensure_swa_slots(self, req_idx: int, positions: torch.Tensor, full_slots: to out[i] = swa return out + def _reserve_prefill_swa_slots( + self, + req_idx: int, + positions: torch.Tensor, + full_slots: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + full_slots = full_slots.long().reshape(-1) + positions = positions.long().reshape(-1) + assert positions.numel() == full_slots.numel() + + out = torch.empty_like(full_slots, dtype=torch.long) + ring_to_swa: Dict[int, int] = {} + ring_to_old_full: Dict[int, int] = {} + ring_to_final_full: Dict[int, int] = {} + hold = self.swa_pool.HOLD_TOKEN_MEMINDEX + + for i, (pos, full) in enumerate(zip(positions.tolist(), full_slots.tolist())): + if full == self.HOLD_TOKEN_MEMINDEX: + out[i] = hold + continue + + ring_pos = int(pos) % int(self.sliding_window) + swa = 
ring_to_swa.get(ring_pos) + if swa is None: + old_swa = int(self.req_to_swa_indexs[req_idx, ring_pos].item()) + old_full = int(self.req_to_swa_full_indexs[req_idx, ring_pos].item()) + if old_swa == hold: + old_swa = int(self.swa_allocator.alloc(1)[0].item()) + swa = old_swa + ring_to_swa[ring_pos] = swa + ring_to_old_full[ring_pos] = old_full + + ring_to_final_full[ring_pos] = int(full) + out[i] = swa + + rings = sorted(ring_to_final_full) + return { + "positions": positions.detach().clone(), + "full_slots": full_slots.detach().clone(), + "swa_slots": out.detach().clone(), + "commit_rings": torch.tensor(rings, dtype=torch.long, device=full_slots.device), + "commit_full_slots": torch.tensor( + [ring_to_final_full[r] for r in rings], + dtype=torch.long, + device=full_slots.device, + ), + "commit_swa_slots": torch.tensor( + [ring_to_swa[r] for r in rings], + dtype=torch.long, + device=full_slots.device, + ), + "commit_old_full_slots": torch.tensor( + [ring_to_old_full[r] for r in rings], + dtype=torch.long, + device=full_slots.device, + ), + } + + def prepare_prefill_swa_slots( + self, + b_req_idx: torch.Tensor, + b_seq_len: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_start_loc: torch.Tensor, + mem_index: torch.Tensor, + ) -> None: + if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None: + return + + self._pending_prefill_swa = {} + req_list = b_req_idx.detach().cpu().tolist() + seq_list = b_seq_len.detach().cpu().tolist() + ready_list = b_ready_cache_len.detach().cpu().tolist() + start_list = b_start_loc.detach().cpu().tolist() + for req_idx, seq_len, ready_len, start_loc in zip(req_list, seq_list, ready_list, start_list): + token_num = int(seq_len) - int(ready_len) + if token_num <= 0: + continue + pos = torch.arange(int(ready_len), int(seq_len), dtype=torch.long, device=mem_index.device) + slots = mem_index[int(start_loc) : int(start_loc) + token_num] + self._pending_prefill_swa[int(req_idx)] = 
self._reserve_prefill_swa_slots(int(req_idx), pos, slots) + return + + def _get_pending_prefill_swa_slots( + self, + req_idx: int, + positions: torch.Tensor, + full_slots: torch.Tensor, + ) -> Optional[torch.Tensor]: + pending = self._pending_prefill_swa.get(int(req_idx)) + if pending is None: + return None + if pending["positions"].numel() != positions.numel(): + return None + if not torch.equal(pending["positions"].to(positions.device), positions.long().reshape(-1)): + return None + if not torch.equal(pending["full_slots"].to(full_slots.device), full_slots.long().reshape(-1)): + return None + return pending["swa_slots"].to(full_slots.device) + + def commit_prefill_swa_slots(self) -> None: + if not self._pending_prefill_swa: + return + for req_idx, pending in self._pending_prefill_swa.items(): + rings = pending["commit_rings"].to(self.req_to_swa_indexs.device) + if rings.numel() == 0: + continue + old_full = pending["commit_old_full_slots"].to(self.full_to_swa_indexs.device) + valid_old = old_full >= 0 + if valid_old.any(): + self.full_to_swa_indexs[old_full[valid_old].long()] = -1 + + full_slots = pending["commit_full_slots"].to(self.full_to_swa_indexs.device) + swa_slots = pending["commit_swa_slots"].to(self.full_to_swa_indexs.device) + self.req_to_swa_indexs[int(req_idx), rings] = swa_slots.to(torch.int32) + self.req_to_swa_full_indexs[int(req_idx), rings] = full_slots.to(torch.int32) + self.full_to_swa_indexs[full_slots.long()] = swa_slots.to(torch.int32) + self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX + self._pending_prefill_swa = {} + return + def _swa_slots_from_full(self, full_slots: torch.Tensor) -> torch.Tensor: full_slots = full_slots.long().reshape(-1) if full_slots.numel() == 0: @@ -597,6 +719,79 @@ def free_swa_for_req(self, req_idx: int) -> None: self.req_to_swa_full_indexs[req_idx].fill_(-1) self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX + def 
snapshot_swa_for_prompt_cache(self, req_idx: int, cache_len: int, full_slots: torch.Tensor): + if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None or cache_len <= 0: + return None + tail_start = max(0, int(cache_len) - int(self.sliding_window)) + full_slots = full_slots[tail_start:cache_len].long().to(self.kv_buffer.device) + if full_slots.numel() == 0: + return None + swa_slots = self.full_to_swa_indexs[full_slots].long() + if (swa_slots < 0).any(): + bad = int(full_slots[swa_slots < 0][0].item()) + raise RuntimeError(f"DeepSeek-V4 prompt cache cannot snapshot evicted SWA full slot {bad}") + return { + "positions": torch.arange(tail_start, cache_len, dtype=torch.int64, device="cpu"), + "full_slots": full_slots.detach().cpu(), + "swa_slots": swa_slots.detach().cpu(), + } + + def clone_swa_for_prompt_cache(self, req_idx: int, cache_len: int, full_slots: torch.Tensor): + payload = self.snapshot_swa_for_prompt_cache(req_idx, cache_len, full_slots) + if payload is None: + return None + + src_slots = payload["swa_slots"].long().to(self.kv_buffer.device) + dst_slots = self.swa_allocator.alloc(src_slots.numel()).long().to(self.kv_buffer.device) + for layer_idx in range(self.layer_num): + self.swa_pool.write(layer_idx, dst_slots, self.swa_pool.read(layer_idx, src_slots)) + payload["swa_slots"] = dst_slots.detach().cpu() + return payload + + def detach_swa_for_prompt_cache(self, req_idx: int, swa_payload) -> None: + if ( + swa_payload is None + or self.req_to_swa_indexs is None + or self.req_to_swa_full_indexs is None + or len(swa_payload["positions"]) == 0 + ): + return + req_idx = int(req_idx) + positions = swa_payload["positions"].tolist() + full_slots = swa_payload["full_slots"].tolist() + swa_slots = swa_payload["swa_slots"].tolist() + for pos, full, swa in zip(positions, full_slots, swa_slots): + ring_pos = int(pos) % int(self.sliding_window) + if int(self.req_to_swa_indexs[req_idx, ring_pos].item()) == int(swa) and int( + 
self.req_to_swa_full_indexs[req_idx, ring_pos].item() + ) == int(full): + self.req_to_swa_indexs[req_idx, ring_pos] = self.swa_pool.HOLD_TOKEN_MEMINDEX + self.req_to_swa_full_indexs[req_idx, ring_pos] = -1 + return + + def restore_swa_from_prompt_cache(self, swa_payload) -> None: + if swa_payload is None or len(swa_payload["full_slots"]) == 0: + return + full_slots = swa_payload["full_slots"].long().to(self.kv_buffer.device) + swa_slots = swa_payload["swa_slots"].long().to(self.kv_buffer.device) + self.full_to_swa_indexs[full_slots] = swa_slots.to(torch.int32) + self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX + return + + def free_swa_prompt_cache(self, swa_payload) -> None: + if swa_payload is None or len(swa_payload["swa_slots"]) == 0: + return + swa_slots = torch.unique(swa_payload["swa_slots"].long()).detach().cpu() + self.swa_allocator.free(swa_slots) + full_slots = swa_payload["full_slots"].long().to(self.kv_buffer.device) + mapped = self.full_to_swa_indexs[full_slots].long() + expected = swa_payload["swa_slots"].long().to(self.kv_buffer.device) + same = mapped == expected + if same.any(): + self.full_to_swa_indexs[full_slots[same]] = -1 + self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX + return + def _keep_last_swa_writes(self, swa_slots: torch.Tensor, packed: torch.Tensor): """Drop duplicate SWA writes generated by long prefill ring reuse.""" if swa_slots.numel() <= 1: @@ -634,7 +829,11 @@ def pack_mla_kv_to_cache( if req_idx is None or positions is None: swa_slots = self._identity_swa_slots(mem_index).to(kv.device) else: - swa_slots = self.ensure_swa_slots(req_idx, positions, mem_index).to(kv.device) + pending_slots = self._get_pending_prefill_swa_slots(req_idx, positions, mem_index) + if pending_slots is None: + swa_slots = self.ensure_swa_slots(req_idx, positions, mem_index).to(kv.device) + else: + swa_slots = pending_slots.to(kv.device) swa_slots, packed = 
self._keep_last_swa_writes(swa_slots, packed) if swa_slots.numel() == 0: return @@ -712,6 +911,7 @@ def free_all(self): if getattr(self, "req_to_swa_indexs", None) is not None: self.req_to_swa_indexs.fill_(self.swa_pool.HOLD_TOKEN_MEMINDEX) self.req_to_swa_full_indexs.fill_(-1) + self._pending_prefill_swa = {} if self.c4_pool is not None: self.c4_pool.free_all() if self.c128_pool is not None: diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 1cdea03381..606469d48e 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,5 +1,6 @@ import torch import collections +from dataclasses import dataclass from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig from lightllm.utils.log_utils import init_logger @@ -23,6 +24,33 @@ logger = init_logger(__name__) +@dataclass +class DeepseekV4PromptCachePayload: + cache_len: int + c4_slots: Optional[torch.Tensor] = None + c128_slots: Optional[torch.Tensor] = None + c4_state: Optional[torch.Tensor] = None + c4_state_pool: Optional[torch.Tensor] = None + c4_indexer_state: Optional[torch.Tensor] = None + c4_indexer_state_pool: Optional[torch.Tensor] = None + swa: Optional[dict] = None + + +class DeepseekV4PromptCacheValueOps: + def __init__(self, req_manager: "DeepseekV4ReqManager"): + self.req_manager = req_manager + + def slice(self, payload: DeepseekV4PromptCachePayload, start: int, end: int): + return self.req_manager.slice_prompt_cache_payload(payload, start, end) + + def concat(self, payloads: List[DeepseekV4PromptCachePayload]): + return self.req_manager.concat_prompt_cache_payloads(payloads) + + def free(self, payload: DeepseekV4PromptCachePayload): + self.req_manager.free_prompt_cache_payload(payload) + return + + class _ReqNode: def __init__(self, index): self.index = index @@ -594,6 +622,233 @@ def get_c4_indexer_state_pool_for_req(self, layer_index: int, req_idx: int) -> t local = self.layer_to_c4_idx[layer_index] return 
self.req_to_c4_indexer_state_pool.buffer[local, req_idx] + def get_prompt_cache_value_ops(self): + return DeepseekV4PromptCacheValueOps(self) + + def get_prompt_cache_page_size(self): + return 128 + + def _slice_cpu_slots(self, slots: Optional[torch.Tensor], start: int, end: int, ratio: int): + if slots is None: + return None + return slots[start // ratio : end // ratio].clone() + + def _slice_swa_payload(self, swa_payload, start: int, end: int): + if swa_payload is None: + return None + positions = swa_payload["positions"] + mask = (positions >= start) & (positions < end) + if not bool(mask.any()): + return None + return { + "positions": positions[mask].clone(), + "full_slots": swa_payload["full_slots"][mask].clone(), + "swa_slots": swa_payload["swa_slots"][mask].clone(), + } + + def slice_prompt_cache_payload(self, payload: DeepseekV4PromptCachePayload, start: int, end: int): + start = int(start) + end = int(end) + # c4/c128/indexer-K slots are true historical KV and can be sliced by ratio. + # compressor running state only describes the payload end boundary; it is valid + # for a slice only when that slice keeps the original end boundary. 
+ keep_end_state = end == payload.cache_len + return DeepseekV4PromptCachePayload( + cache_len=end - start, + c4_slots=self._slice_cpu_slots(payload.c4_slots, start, end, 4), + c128_slots=self._slice_cpu_slots(payload.c128_slots, start, end, 128), + c4_state=payload.c4_state.clone() if keep_end_state and payload.c4_state is not None else None, + c4_state_pool=payload.c4_state_pool.clone() + if keep_end_state and payload.c4_state_pool is not None + else None, + c4_indexer_state=payload.c4_indexer_state.clone() + if keep_end_state and payload.c4_indexer_state is not None + else None, + c4_indexer_state_pool=payload.c4_indexer_state_pool.clone() + if keep_end_state and payload.c4_indexer_state_pool is not None + else None, + swa=self._slice_swa_payload(payload.swa, start, end), + ) + + def concat_prompt_cache_payloads(self, payloads: List[DeepseekV4PromptCachePayload]): + if len(payloads) == 0: + return None + c4_slots = [p.c4_slots for p in payloads if p.c4_slots is not None and len(p.c4_slots) > 0] + c128_slots = [p.c128_slots for p in payloads if p.c128_slots is not None and len(p.c128_slots) > 0] + last = payloads[-1] + return DeepseekV4PromptCachePayload( + cache_len=sum(p.cache_len for p in payloads), + c4_slots=torch.cat(c4_slots, dim=0) if c4_slots else None, + c128_slots=torch.cat(c128_slots, dim=0) if c128_slots else None, + c4_state=last.c4_state, + c4_state_pool=last.c4_state_pool, + c4_indexer_state=last.c4_indexer_state, + c4_indexer_state_pool=last.c4_indexer_state_pool, + swa=last.swa, + ) + + def build_prompt_cache_payload( + self, + req_idx: int, + cache_len: int, + clone_swa: bool = False, + ) -> DeepseekV4PromptCachePayload: + assert self.mem_manager is not None + cache_len = int(cache_len) + full_slots = self.req_to_token_indexs[req_idx, :cache_len].detach().cpu() + c4_count = cache_len // 4 + c128_count = cache_len // 128 + c4_slots = self.req_to_c4_indexs[req_idx, :c4_count].detach().cpu().clone() if c4_count > 0 else None + c128_slots = 
self.req_to_c128_indexs[req_idx, :c128_count].detach().cpu().clone() if c128_count > 0 else None + if clone_swa: + swa_payload = self.mem_manager.clone_swa_for_prompt_cache(req_idx, cache_len, full_slots) + else: + swa_payload = self.mem_manager.snapshot_swa_for_prompt_cache(req_idx, cache_len, full_slots) + return DeepseekV4PromptCachePayload( + cache_len=cache_len, + c4_slots=c4_slots, + c128_slots=c128_slots, + c4_state=self.req_to_c4_state.buffer[:, req_idx].detach().clone() if self.n_c4 > 0 else None, + c4_state_pool=self.req_to_c4_state_pool.buffer[:, req_idx].detach().clone() if self.n_c4 > 0 else None, + c4_indexer_state=self.req_to_c4_indexer_state.buffer[:, req_idx].detach().clone() + if self.n_c4 > 0 + else None, + c4_indexer_state_pool=self.req_to_c4_indexer_state_pool.buffer[:, req_idx].detach().clone() + if self.n_c4 > 0 + else None, + swa=swa_payload, + ) + + def detach_prompt_cache_payload_from_req(self, req_idx: int, payload: DeepseekV4PromptCachePayload): + if payload is not None and self.mem_manager is not None: + self.mem_manager.detach_swa_for_prompt_cache(req_idx, payload.swa) + return + + def free_prompt_cache_payload(self, payload: DeepseekV4PromptCachePayload): + if payload is None or self.mem_manager is None: + return + if payload.c4_slots is not None and len(payload.c4_slots) > 0: + self.mem_manager.free_c4(payload.c4_slots) + if payload.c128_slots is not None and len(payload.c128_slots) > 0: + self.mem_manager.free_c128(payload.c128_slots) + self.mem_manager.free_swa_prompt_cache(payload.swa) + return + + def release_prompt_cache_detached_swa( + self, + payload: DeepseekV4PromptCachePayload, + keep_payload: Optional[DeepseekV4PromptCachePayload] = None, + ): + if payload is None or payload.swa is None or self.mem_manager is None: + return + old_swa = payload.swa + if keep_payload is None or keep_payload.swa is None: + self.mem_manager.free_swa_prompt_cache(old_swa) + return + + old_slots = old_swa["swa_slots"].long() + keep_slots = 
keep_payload.swa["swa_slots"].long() + if old_slots.numel() == 0: + return + if keep_slots.numel() == 0: + self.mem_manager.free_swa_prompt_cache(old_swa) + return + + release_mask = ~torch.isin(old_slots, keep_slots) + if not release_mask.any(): + return + release_payload = { + "full_slots": old_swa["full_slots"][release_mask].clone(), + "swa_slots": old_swa["swa_slots"][release_mask].clone(), + } + self.mem_manager.free_swa_prompt_cache(release_payload) + return + + def _reset_c128_for_prompt_cache(self, req_idx: int): + if self.n_c128 > 0: + self._reset_compress_cache_req(self.req_to_c128_state, req_idx) + self._reset_state_pool_req(self.req_to_c128_state_pool, req_idx) + return + + def rebuild_runtime_state_for_req(self, req_idx: int): + state_map = self._runtime_states[req_idx] + state_map.clear() + for layer_index, ratio in enumerate(self.compress_rates): + if ratio == 4: + cstate_kv, cstate_score = self.get_compress_state_for_req(layer_index, req_idx) + idx_state = self.get_c4_indexer_compress_state(layer_index) + state_map[layer_index] = { + "cstate_kv": cstate_kv, + "cstate_score": cstate_score, + "idx_cstate_kv": idx_state[req_idx, 0], + "idx_cstate_score": idx_state[req_idx, 1], + } + elif ratio == 128: + cstate_kv, cstate_score = self.get_compress_state_for_req(layer_index, req_idx) + state_map[layer_index] = { + "cstate_kv": cstate_kv, + "cstate_score": cstate_score, + } + return + + def restore_prompt_cache_payload(self, req_idx: int, payload: DeepseekV4PromptCachePayload): + assert self.mem_manager is not None + cache_len = int(payload.cache_len) + c4_count = cache_len // 4 + c128_count = cache_len // 128 + if c4_count > 0: + assert payload.c4_slots is not None and len(payload.c4_slots) == c4_count + self.req_to_c4_indexs[req_idx, :c4_count] = payload.c4_slots.cuda(non_blocking=True) + if c128_count > 0: + assert payload.c128_slots is not None and len(payload.c128_slots) == c128_count + self.req_to_c128_indexs[req_idx, :c128_count] = 
payload.c128_slots.cuda(non_blocking=True) + self._c4_entry_counts[req_idx] = c4_count + self._c128_entry_counts[req_idx] = c128_count + + if self.n_c4 > 0: + if payload.c4_state is None or payload.c4_indexer_state is None: + raise RuntimeError("DeepSeek-V4 prompt cache hit is missing c4 running state") + self.req_to_c4_state.buffer[:, req_idx].copy_(payload.c4_state) + self.req_to_c4_indexer_state.buffer[:, req_idx].copy_(payload.c4_indexer_state) + if payload.c4_state_pool is not None: + self.req_to_c4_state_pool.buffer[:, req_idx].copy_(payload.c4_state_pool) + if payload.c4_indexer_state_pool is not None: + self.req_to_c4_indexer_state_pool.buffer[:, req_idx].copy_(payload.c4_indexer_state_pool) + self._reset_c128_for_prompt_cache(req_idx) + self.mem_manager.restore_swa_from_prompt_cache(payload.swa) + self.rebuild_runtime_state_for_req(req_idx) + return + + def pop_prompt_cache_free_compress_indices( + self, + req_idx: int, + keep_len: int, + duplicate_start_len: Optional[int] = None, + duplicate_end_len: Optional[int] = None, + ): + def collect(table, cur_count, ratio): + ranges = [] + if duplicate_start_len is not None and duplicate_end_len is not None: + dup_start = duplicate_start_len // ratio + dup_end = duplicate_end_len // ratio + if dup_end > dup_start: + ranges.append((dup_start, dup_end)) + keep_count = keep_len // ratio + if cur_count > keep_count: + ranges.append((keep_count, cur_count)) + parts = [table[req_idx, s:e].clone() for s, e in ranges if e > s] + return torch.cat(parts, dim=0) if parts else None + + c4 = collect(self.req_to_c4_indexs, self._c4_entry_counts[req_idx], 4) + c128 = collect(self.req_to_c128_indexs, self._c128_entry_counts[req_idx], 128) + if self._c4_entry_counts[req_idx] > 0: + self.req_to_c4_indexs[req_idx, : self._c4_entry_counts[req_idx]].fill_(0) + if self._c128_entry_counts[req_idx] > 0: + self.req_to_c128_indexs[req_idx, : self._c128_entry_counts[req_idx]].fill_(0) + self._c4_entry_counts[req_idx] = 0 + 
self._c128_entry_counts[req_idx] = 0 + return c4, c128 + def free( self, free_req_indexes, diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 98dee7fd8a..81b45299f9 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -593,19 +593,25 @@ def _indexer_topk(self, idx_q, idx_comp, idx_weight, positions_1based, offset, i if k == 0: return torch.empty((idx_q.shape[0], 0), device=idx_q.device, dtype=torch.long) - scores = torch.einsum("thd,nd->thn", idx_q.float(), idx_comp.float()) - scores = F.relu(scores) * self.indexer_score_scale - index_scores = (scores * idx_weight.unsqueeze(-1)).sum(dim=1) - if self.tp_world_size_ > 1: - all_reduce( - index_scores, - op=dist.ReduceOp.SUM, - group=infer_state.dist_group, - async_op=False, - ) - - causal_threshold = positions_1based // 4 - top = self._indexer_topk_kernel(index_scores, causal_threshold, k) + top_chunks = [] + heads = max(1, idx_q.shape[1]) + max_score_elems = 16 * 1024 * 1024 + chunk_size = max(1, min(idx_q.shape[0], max_score_elems // max(1, heads * ncomp))) + for start in range(0, idx_q.shape[0], chunk_size): + end = min(idx_q.shape[0], start + chunk_size) + scores = torch.einsum("thd,nd->thn", idx_q[start:end].float(), idx_comp.float()) + scores = F.relu(scores) * self.indexer_score_scale + index_scores = (scores * idx_weight[start:end].unsqueeze(-1)).sum(dim=1) + if self.tp_world_size_ > 1: + all_reduce( + index_scores, + op=dist.ReduceOp.SUM, + group=infer_state.dist_group, + async_op=False, + ) + causal_threshold = positions_1based[start:end] // 4 + top_chunks.append(self._indexer_topk_kernel(index_scores, causal_threshold, k)) + top = torch.cat(top_chunks, dim=0) valid = top >= 0 return torch.where(valid, top + offset, torch.full_like(top, -1)) diff --git a/lightllm/server/api_start.py 
b/lightllm/server/api_start.py index 5a92a339cb..4c23001bce 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -133,10 +133,6 @@ def normal_or_p_d_start(args): raise NotImplementedError("DeepSeek-V4 EP MoE is not supported yet; use TP for now.") if "prompt_cache_kv_buffer" in get_config_json(args.model_dir): raise NotImplementedError("DeepSeek-V4 prompt_cache_kv_buffer is not supported yet.") - if not args.disable_dynamic_prompt_cache: - logger.info("DeepSeek-V4 runtime state does not support radix prompt cache yet; disabling it.") - args.disable_dynamic_prompt_cache = True - args.use_dynamic_prompt_cache = False if args.enable_cpu_cache: # 生成一个用于创建cpu kv cache的共享内存id。 diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 21e26c5854..8be5198eb3 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -2,7 +2,7 @@ import torch import numpy as np import collections -from typing import Tuple, Dict, Set, List, Optional, Union +from typing import Any, Tuple, Dict, Set, List, Optional, Union from sortedcontainers import SortedSet from .shared_arr import SharedArray @@ -25,6 +25,7 @@ def __init__(self): self.parent: TreeNode = None self.token_id_key: torch.Tensor = None self.token_mem_index_value: torch.Tensor = None # 用于记录存储的 token_index 为每个元素在 token mem 中的index位置 + self.token_extra_value: Any = None self.ref_counter = 0 self.time_id = time_gen.generate_time_id() # 用于标识时间周期 @@ -34,14 +35,17 @@ def __init__(self): def get_compare_key(self): return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) - def split_node(self, prefix_len): + def split_node(self, prefix_len, child_key_fn=None, extra_value_ops=None): split_parent_node = TreeNode() split_parent_node.parent = self.parent - split_parent_node.parent.children[self.token_id_key[0].item()] = split_parent_node + 
split_parent_node.parent.children[child_key_fn(self.token_id_key)] = split_parent_node split_parent_node.token_id_key = self.token_id_key[0:prefix_len] split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len] + if self.token_extra_value is not None and extra_value_ops is not None: + split_parent_node.token_extra_value = extra_value_ops.slice(self.token_extra_value, 0, prefix_len) + self.token_extra_value = extra_value_ops.slice(self.token_extra_value, prefix_len, len(self.token_id_key)) split_parent_node.children = {} - split_parent_node.children[self.token_id_key[prefix_len].item()] = self + split_parent_node.children[child_key_fn(self.token_id_key[prefix_len:])] = self split_parent_node.ref_counter = self.ref_counter new_len = len(split_parent_node.token_mem_index_value) @@ -56,11 +60,12 @@ def split_node(self, prefix_len): self.node_prefix_total_len = self.parent.node_prefix_total_len + new_len return split_parent_node - def add_and_return_new_child(self, token_id_key, token_mem_index_value): + def add_and_return_new_child(self, token_id_key, token_mem_index_value, token_extra_value=None, child_key=None): child = TreeNode() child.token_id_key = token_id_key child.token_mem_index_value = token_mem_index_value - first_token_key = child.token_id_key[0].item() + child.token_extra_value = token_extra_value + first_token_key = child.token_id_key[0].item() if child_key is None else child_key assert first_token_key not in self.children.keys() self.children[first_token_key] = child child.parent = self @@ -71,9 +76,17 @@ def add_and_return_new_child(self, token_id_key, token_mem_index_value): return child def remove_child(self, child_node: "TreeNode"): - del self.children[child_node.token_id_key[0].item()] - child_node.parent = None - return + child_key = child_node.token_id_key[0].item() + if child_key in self.children: + del self.children[child_key] + child_node.parent = None + return + for key, value in list(self.children.items()): + if value 
is child_node: + del self.children[key] + child_node.parent = None + return + raise KeyError("child node not found") def update_time(self): self.time_id = time_gen.generate_time_id() @@ -103,12 +116,22 @@ class RadixCache: unique_name 主要用于解决单机,多实列部署时的shm冲突 """ - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + def __init__( + self, + unique_name, + total_token_num, + rank_in_node, + mem_manager=None, + page_size: int = 1, + extra_value_ops=None, + ): from lightllm.common.kv_cache_mem_manager import MemoryManager self.mem_manager: MemoryManager = mem_manager self._key_dtype = torch.int64 self._value_dtype = torch.int64 + self.page_size = max(1, int(page_size)) + self.extra_value_ops = extra_value_ops self.root_node = TreeNode() self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) @@ -125,30 +148,66 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None) ) self.tree_total_tokens_num.arr[0] = 0 - def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: + def _align_len(self, length: int) -> int: + if self.page_size <= 1: + return int(length) + return int(length) // self.page_size * self.page_size + + def align_len(self, length: int) -> int: + return self._align_len(length) + + def _child_key(self, key: torch.Tensor): + if self.page_size <= 1: + return key[0].item() + return tuple(key[: self.page_size].tolist()) + + def _match_len(self, key: torch.Tensor, node_key: torch.Tensor) -> int: + prefix_len = match(key, node_key) + return self._align_len(prefix_len) + + def _slice_extra(self, extra_value, start: int, end: int): + if extra_value is None: + return None + assert self.extra_value_ops is not None + return self.extra_value_ops.slice(extra_value, start, end) + + def _concat_extra(self, values: list): + values = [v for v in values if v is not None] + if len(values) == 0: + return None + assert self.extra_value_ops is not None + return 
self.extra_value_ops.concat(values) + + def insert(self, key, value=None, extra_value=None) -> Tuple[int, Optional[TreeNode]]: if value is None: value = key + align_len = self._align_len(len(key)) + key = key[:align_len] + value = value[:align_len] + if extra_value is not None: + extra_value = self._slice_extra(extra_value, 0, align_len) + assert len(key) == len(value) # and len(key) >= 1 if len(key) == 0: return 0, None - return self._insert_helper(self.root_node, key, value) + return self._insert_helper(self.root_node, key, value, extra_value) - def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]: + def _insert_helper(self, node: TreeNode, key, value, extra_value) -> Tuple[int, Optional[TreeNode]]: handle_stack = collections.deque() update_list = collections.deque() - handle_stack.append((node, key, value)) + handle_stack.append((node, key, value, extra_value)) ans_prefix_len = 0 ans_node = None while len(handle_stack) != 0: - node, key, value = handle_stack.popleft() - ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value) - if len(ans_tuple) == 4: - (_prefix_len, new_node, new_key, new_value) = ans_tuple + node, key, value, extra_value = handle_stack.popleft() + ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value, extra_value=extra_value) + if len(ans_tuple) == 5: + (_prefix_len, new_node, new_key, new_value, new_extra_value) = ans_tuple ans_prefix_len += _prefix_len - handle_stack.append((new_node, new_key, new_value)) + handle_stack.append((new_node, new_key, new_value, new_extra_value)) else: _prefix_len, ans_node = ans_tuple ans_prefix_len += _prefix_len @@ -166,15 +225,15 @@ def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[Tree return ans_prefix_len, ans_node def _insert_helper_no_recursion( - self, node: TreeNode, key: torch.Tensor, value: torch.Tensor - ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor]]: 
+ self, node: TreeNode, key: torch.Tensor, value: torch.Tensor, extra_value=None + ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor, Any]]: if node.is_leaf(): self.evict_tree_set.discard(node) - first_key_id = key[0].item() + first_key_id = self._child_key(key) if first_key_id in node.children.keys(): child: TreeNode = node.children[first_key_id] - prefix_len = match(key, child.token_id_key) + prefix_len = self._match_len(key, child.token_id_key) if prefix_len == len(key): if prefix_len == len(child.token_id_key): if child.is_leaf(): @@ -184,10 +243,14 @@ def _insert_helper_no_recursion( self.evict_tree_set.add(child) return prefix_len, child elif prefix_len < len(child.token_id_key): + if prefix_len == 0: + return 0, node if child.is_leaf(): self.evict_tree_set.discard(child) - split_parent_node = child.split_node(prefix_len) + split_parent_node = child.split_node( + prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops + ) if split_parent_node.is_leaf(): self.evict_tree_set.add(split_parent_node) @@ -199,13 +262,23 @@ def _insert_helper_no_recursion( assert False, "can not run to here" elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if prefix_len == 0: + return 0, node if child.is_leaf(): self.evict_tree_set.discard(child) + new_extra_value = self._slice_extra(extra_value, prefix_len, len(key)) key = key[prefix_len:] value = value[prefix_len:] - split_parent_node = child.split_node(prefix_len) - new_node = split_parent_node.add_and_return_new_child(key, value) + split_parent_node = child.split_node( + prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops + ) + new_node = split_parent_node.add_and_return_new_child( + key, + value, + token_extra_value=new_extra_value, + child_key=self._child_key(key), + ) # update total token num self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) @@ -218,12 +291,23 @@ def 
_insert_helper_no_recursion( self.evict_tree_set.add(child) return prefix_len, new_node elif prefix_len < len(key) and prefix_len == len(child.token_id_key): - return (prefix_len, child, key[prefix_len:], value[prefix_len:]) + return ( + prefix_len, + child, + key[prefix_len:], + value[prefix_len:], + self._slice_extra(extra_value, prefix_len, len(key)), + ) else: assert False, "can not run to here" else: - new_node = node.add_and_return_new_child(key, value) + new_node = node.add_and_return_new_child( + key, + value, + token_extra_value=extra_value, + child_key=first_key_id, + ) # update total token num self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) if new_node.is_leaf(): @@ -231,7 +315,9 @@ def _insert_helper_no_recursion( return 0, new_node def match_prefix(self, key, update_refs=False): - assert len(key) != 0 + key = key[: self._align_len(len(key))] + if len(key) == 0: + return None, 0, None ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) if tree_node != self.root_node: @@ -290,20 +376,24 @@ def _match_prefix_helper_no_recursion( if len(key) == 0: return node - first_key_id = key[0].item() + first_key_id = self._child_key(key) if first_key_id not in node.children.keys(): return node else: child = node.children[first_key_id] - prefix_len = match(key, child.token_id_key) + prefix_len = self._match_len(key, child.token_id_key) if prefix_len == len(child.token_id_key): ans_value_list.append(child.token_mem_index_value) return (child, key[prefix_len:]) elif prefix_len < len(child.token_id_key): + if prefix_len == 0: + return node if child.is_leaf(): self.evict_tree_set.discard(child) - split_parent_node = child.split_node(prefix_len) + split_parent_node = child.split_node( + prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops + ) ans_value_list.append(split_parent_node.token_mem_index_value) if update_refs: @@ -334,6 +424,8 @@ def evict(self, 
need_remove_tokens, evict_callback): ), "error evict tree node state" num_evicted += len(node.token_mem_index_value) evict_callback(node.token_mem_index_value) + if self.extra_value_ops is not None and node.token_extra_value is not None: + self.extra_value_ops.free(node.token_extra_value) # update total token num self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) parent_node: TreeNode = node.parent @@ -369,11 +461,12 @@ def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: child_node.token_mem_index_value = torch.cat( [parent_node.token_mem_index_value, child_node.token_mem_index_value] ) + child_node.token_extra_value = self._concat_extra([parent_node.token_extra_value, child_node.token_extra_value]) child_node.node_value_len = len(child_node.token_mem_index_value) child_node.time_id = max(parent_node.time_id, child_node.time_id) grandparent_node = parent_node.parent - key_in_grandparent = parent_node.token_id_key[0].item() + key_in_grandparent = self._child_key(parent_node.token_id_key) grandparent_node.children[key_in_grandparent] = child_node child_node.parent = grandparent_node @@ -469,6 +562,19 @@ def get_mem_index_value_by_node(self, node: TreeNode) -> Optional[torch.Tensor]: ans_list.reverse() return torch.concat(ans_list, dim=0) + def get_extra_value_by_node(self, node: TreeNode): + if node is None or self.extra_value_ops is None: + return None + + ans_list = [] + while node is not None: + if node.token_extra_value is not None: + ans_list.append(node.token_extra_value) + node = node.parent + + ans_list.reverse() + return self._concat_extra(ans_list) + def get_refed_tokens_num(self): return self.refed_tokens_num.arr[0] diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index abeb8d61e9..3fd6e0463a 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -131,7 +131,8 @@ def free_a_req_mem( 
free_c4_index: Optional[List] = None, free_c128_index: Optional[List] = None, ): - if hasattr(self.req_manager, "pop_compress_indices_for_req"): + is_dsv4_req_manager = hasattr(self.req_manager, "build_prompt_cache_payload") + if hasattr(self.req_manager, "pop_compress_indices_for_req") and not is_dsv4_req_manager: c4, c128 = self.req_manager.pop_compress_indices_for_req(req.req_idx) if c4 is not None and free_c4_index is not None: free_c4_index.append(c4) @@ -141,9 +142,24 @@ def free_a_req_mem( if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) + if is_dsv4_req_manager: + c4, c128 = self.req_manager.pop_compress_indices_for_req(req.req_idx) + if c4 is not None and free_c4_index is not None: + free_c4_index.append(c4) + if c128 is not None and free_c128_index is not None: + free_c128_index.append(c128) + self.req_manager.clear_runtime_state(req.req_idx) else: if not self.is_linear_att_mixed_model: - self._full_att_free_req(free_token_index=free_token_index, req=req) + if is_dsv4_req_manager: + self._dsv4_full_att_free_req( + free_token_index=free_token_index, + req=req, + free_c4_index=free_c4_index, + free_c128_index=free_c128_index, + ) + else: + self._full_att_free_req(free_token_index=free_token_index, req=req) else: self._linear_att_free_req(free_token_index=free_token_index, req=req) assert len(req.linear_att_len_to_big_page_id) == 0 @@ -151,6 +167,11 @@ def free_a_req_mem( req.shm_req.shm_cur_kv_len = req.cur_kv_len return + def _append_free_token_index(self, free_token_index: List, tensor: torch.Tensor): + if tensor.numel() > 0: + free_token_index.append(tensor) + return + def _full_att_free_req(self, free_token_index: List, req: "InferReq"): input_token_ids = req.get_input_token_ids() key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") @@ -166,6 +187,86 @@ def _full_att_free_req(self, free_token_index: List, req: "InferReq"): req.shared_kv_node = 
None return + def _dsv4_full_att_free_req( + self, + free_token_index: List, + req: "InferReq", + free_c4_index: Optional[List] = None, + free_c128_index: Optional[List] = None, + ): + if req.cur_kv_len == 0: + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0:0]) + return + + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len + cache_len = self.radix_cache.align_len(req.cur_kv_len) + inserted_len = old_prefix_len + duplicate_prefix_len = old_prefix_len + inserted_payload = None + pending_payload = getattr(req, "prompt_cache_snapshot_payload", None) + pending_cache_len = getattr(req, "prompt_cache_snapshot_len", 0) + + # The current V4 runtime state is only guaranteed to describe the current + # sequence end. Cache aligned current ends; leave unaligned tails uncached. + if pending_payload is not None and pending_cache_len > old_prefix_len: + cache_len = pending_cache_len + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0:cache_len], dtype=torch.int64, device="cpu") + value = self.req_manager.req_to_token_indexs[req.req_idx][:cache_len].detach().cpu() + duplicate_prefix_len, cache_node = self.radix_cache.insert(key, value, extra_value=pending_payload) + inserted_len = 0 if cache_node is None else cache_node.node_prefix_total_len + if inserted_len == cache_len: + inserted_payload = pending_payload + else: + self.req_manager.release_prompt_cache_detached_swa(pending_payload) + pending_payload = None + inserted_len = old_prefix_len + duplicate_prefix_len = old_prefix_len + elif cache_len == req.cur_kv_len and cache_len > old_prefix_len: + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0:cache_len], dtype=torch.int64, device="cpu") + value = self.req_manager.req_to_token_indexs[req.req_idx][:cache_len].detach().cpu() + payload = self.req_manager.build_prompt_cache_payload(req.req_idx, cache_len) + duplicate_prefix_len, cache_node = 
self.radix_cache.insert(key, value, extra_value=payload) + inserted_len = 0 if cache_node is None else cache_node.node_prefix_total_len + if inserted_len == cache_len: + inserted_payload = payload + self.req_manager.detach_prompt_cache_payload_from_req(req.req_idx, inserted_payload) + else: + inserted_len = old_prefix_len + duplicate_prefix_len = old_prefix_len + + if ( + pending_payload is not None + and inserted_payload is not pending_payload + and pending_cache_len <= old_prefix_len + ): + self.req_manager.release_prompt_cache_detached_swa(pending_payload) + req.prompt_cache_snapshot_payload = None + req.prompt_cache_snapshot_len = 0 + dense_row = self.req_manager.req_to_token_indexs[req.req_idx] + self._append_free_token_index(free_token_index, dense_row[old_prefix_len:duplicate_prefix_len]) + self._append_free_token_index(free_token_index, dense_row[inserted_len : req.cur_kv_len]) + if len(free_token_index) == 0: + free_token_index.append(dense_row[0:0]) + + c4, c128 = self.req_manager.pop_prompt_cache_free_compress_indices( + req.req_idx, + keep_len=inserted_len, + duplicate_start_len=old_prefix_len, + duplicate_end_len=duplicate_prefix_len, + ) + if c4 is not None and free_c4_index is not None: + free_c4_index.append(c4) + if c128 is not None and free_c128_index is not None: + free_c128_index.append(c128) + + if req.shared_kv_node is not None: + assert req.shared_kv_node.node_prefix_total_len <= max(inserted_len, old_prefix_len) + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + req.shared_kv_node = None + return + def _linear_att_free_req(self, free_token_index: List, req: "InferReq"): assert g_infer_context.is_linear_att_mixed_model is True args = get_env_start_args() @@ -637,6 +738,8 @@ def _init_all_state(self): g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self) if hasattr(g_infer_context.req_manager, "init_compress_state"): g_infer_context.req_manager.init_compress_state(req_idx=self.req_idx) + 
self.prompt_cache_snapshot_len = 0 + self.prompt_cache_snapshot_payload = None self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list() # token healing mode 才被使用的管理对象 @@ -672,6 +775,11 @@ def _match_radix_cache(self): ready_cache_len = share_node.node_prefix_total_len # 从 cpu 到 gpu 是流内阻塞操作 g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor + if hasattr(g_infer_context.req_manager, "restore_prompt_cache_payload"): + payload = g_infer_context.radix_cache.get_extra_value_by_node(share_node) + if payload is None: + raise RuntimeError("DeepSeek-V4 radix cache hit is missing prompt-cache payload") + g_infer_context.req_manager.restore_prompt_cache_payload(self.req_idx, payload) self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 @@ -843,6 +951,7 @@ def get_input_token_ids(self): def get_chuncked_input_token_ids(self): chunked_start = self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.args.chunked_prefill_size) + chunked_end = self._align_chuncked_end_for_prompt_cache(chunked_start, chunked_end) return self.shm_req.shm_prompt_ids.arr[0:chunked_end] def get_chuncked_input_token_ids_for_linear_att(self): @@ -863,6 +972,17 @@ def get_chuncked_input_token_ids_for_linear_att(self): def get_chuncked_input_token_len(self): chunked_start = self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.args.chunked_prefill_size) + return self._align_chuncked_end_for_prompt_cache(chunked_start, chunked_end) + + def _align_chuncked_end_for_prompt_cache(self, chunked_start: int, chunked_end: int): + radix_cache = g_infer_context.radix_cache + page_size = getattr(radix_cache, "page_size", 1) if radix_cache is not None else 1 + if page_size <= 1 or self.sampling_param.disable_prompt_cache: + return chunked_end + prompt_end = self.shm_req.input_len + next_page_end = 
((int(chunked_start) // page_size) + 1) * page_size + if int(chunked_start) < next_page_end < int(chunked_end) and next_page_end <= prompt_end: + return next_page_end return chunked_end def get_chuncked_input_token_len_for_linear_att(self): diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 0220dc87fb..74fdb1e87b 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -171,6 +171,8 @@ def init_model(self, kvargs): self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) self.is_linear_att_mixed_model = isinstance(self.model.req_manager, ReqManagerForMamba) + if hasattr(self.model.req_manager, "build_prompt_cache_payload"): + self.support_overlap = False if self.is_linear_att_mixed_model: self.linear_att_cache_manager = LinearAttCacheManager( @@ -182,6 +184,7 @@ def init_model(self, kvargs): if not self.use_dynamic_prompt_cache: self.radix_cache = None + setattr(self.args, "dynamic_prompt_cache_page_size", 1) else: if self.is_linear_att_mixed_model: self.radix_cache = LinearAttPagedRadixCache( @@ -193,12 +196,21 @@ def init_model(self, kvargs): kv_cache_mem_manager=self.model.mem_manager, linear_att_small_page_buffers=self.linear_att_cache_manager, ) + setattr(self.args, "dynamic_prompt_cache_page_size", 1) else: + radix_page_size = 1 + radix_extra_value_ops = None + if hasattr(self.model.req_manager, "get_prompt_cache_value_ops"): + radix_page_size = self.model.req_manager.get_prompt_cache_page_size() + radix_extra_value_ops = self.model.req_manager.get_prompt_cache_value_ops() + setattr(self.args, "dynamic_prompt_cache_page_size", radix_page_size) self.radix_cache = RadixCache( unique_name=get_unique_server_name(), total_token_num=self.model.mem_manager.size, rank_in_node=self.rank_in_node, mem_manager=self.model.mem_manager, + 
page_size=radix_page_size, + extra_value_ops=radix_extra_value_ops, ) if "prompt_cache_kv_buffer" in model_cfg: @@ -701,6 +713,31 @@ def _pre_handle_finished_reqs(self, finished_reqs: List[InferReq]): """ pass + def _maybe_capture_prompt_cache_payload(self, req_obj: InferReq): + if self.radix_cache is None: + return + req_manager = g_infer_context.req_manager + if not hasattr(req_manager, "build_prompt_cache_payload"): + return + if req_obj.sampling_param.disable_prompt_cache: + return + page_size = getattr(self.args, "dynamic_prompt_cache_page_size", 1) + cache_len = int(req_obj.cur_kv_len) + if page_size <= 1 or cache_len <= 0 or cache_len % page_size != 0: + return + if cache_len > req_obj.shm_req.input_len: + return + if getattr(req_obj, "prompt_cache_snapshot_len", 0) >= cache_len: + return + + payload = req_manager.build_prompt_cache_payload(req_obj.req_idx, cache_len, clone_swa=True) + old_payload = getattr(req_obj, "prompt_cache_snapshot_payload", None) + if old_payload is not None: + req_manager.release_prompt_cache_detached_swa(old_payload, keep_payload=payload) + req_obj.prompt_cache_snapshot_len = cache_len + req_obj.prompt_cache_snapshot_payload = payload + return + # 一些可以复用的通用功能函数 def _pre_post_handle(self, run_reqs: List[InferReq], is_chuncked_mode: bool) -> List[InferReqUpdatePack]: update_func_objs: List[InferReqUpdatePack] = [] @@ -748,6 +785,7 @@ def _post_handle( ): req_obj: InferReq = req_obj pack: InferReqUpdatePack = pack + self._maybe_capture_prompt_cache_payload(req_obj) pack.handle( next_token_id=next_token_id, next_token_logprob=next_token_logprob, From 61eed870b00927005103a99102ab5dc7ead157df Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 5 Jun 2026 09:08:32 +0000 Subject: [PATCH 04/30] support cudagraph --- lightllm/common/basemodel/basemodel.py | 44 +- .../deepseek4_mem_manager.py | 66 +++ lightllm/common/req_manager.py | 20 + .../deepseek_v4/layer_infer/attention.py | 47 +- 
.../deepseek_v4/layer_infer/compressor.py | 64 +++ .../layer_infer/transformer_layer_infer.py | 473 +++++++++++++----- lightllm/models/deepseek_v4/model.py | 30 +- 7 files changed, 615 insertions(+), 129 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index d785991808..8e352519c0 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -577,6 +577,12 @@ def _decode( model_input = self._create_padded_decode_model_input( model_input=model_input, new_batch_size=infer_batch_size ) + if hasattr(self.mem_manager, "prepare_decode_swa_slots"): + self.mem_manager.prepare_decode_swa_slots( + model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes + ) + if hasattr(self.req_manager, "prepare_decode_compress_slots"): + self.req_manager.prepare_decode_compress_slots(model_input.b_req_idx, model_input.b_seq_len) infer_state = self._create_inferstate(model_input) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -598,6 +604,12 @@ def _decode( model_input = self._create_padded_decode_model_input( model_input=model_input, new_batch_size=infer_batch_size ) + if hasattr(self.mem_manager, "prepare_decode_swa_slots"): + self.mem_manager.prepare_decode_swa_slots( + model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes + ) + if hasattr(self.req_manager, "prepare_decode_compress_slots"): + self.req_manager.prepare_decode_compress_slots(model_input.b_req_idx, model_input.b_seq_len) infer_state = self._create_inferstate(model_input) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -633,7 +645,13 @@ def prefill_func(input_tensors, infer_state): handle_token_num = infer_state.input_ids.shape[0] - if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num): + can_run_prefill_graph = self.prefill_graph is not None and self.prefill_graph.can_run( + handle_token_num=handle_token_num + ) + if 
can_run_prefill_graph and hasattr(self, "_can_run_prefill_cudagraph"): + can_run_prefill_graph = self._can_run_prefill_cudagraph(infer_state, handle_token_num) + + if can_run_prefill_graph: finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num( handle_token_num=handle_token_num ) @@ -846,6 +864,20 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode # 一致,需要按照较高 batch size 进行graph的寻找,同时,进行有效的恢复。 padded_model_input0 = self._create_padded_decode_model_input(model_input0, infer_batch_size) padded_model_input1 = self._create_padded_decode_model_input(model_input1, infer_batch_size) + if hasattr(self.mem_manager, "prepare_decode_swa_slots"): + self.mem_manager.prepare_decode_swa_slots( + padded_model_input0.b_req_idx, padded_model_input0.b_seq_len, padded_model_input0.mem_indexes + ) + self.mem_manager.prepare_decode_swa_slots( + padded_model_input1.b_req_idx, padded_model_input1.b_seq_len, padded_model_input1.mem_indexes + ) + if hasattr(self.req_manager, "prepare_decode_compress_slots"): + self.req_manager.prepare_decode_compress_slots( + padded_model_input0.b_req_idx, padded_model_input0.b_seq_len + ) + self.req_manager.prepare_decode_compress_slots( + padded_model_input1.b_req_idx, padded_model_input1.b_seq_len + ) infer_state0 = self._create_inferstate(padded_model_input0, 0) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -887,6 +919,16 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode else: model_input0 = self._create_padded_decode_model_input(model_input0, infer_batch_size) model_input1 = self._create_padded_decode_model_input(model_input1, infer_batch_size) + if hasattr(self.mem_manager, "prepare_decode_swa_slots"): + self.mem_manager.prepare_decode_swa_slots( + model_input0.b_req_idx, model_input0.b_seq_len, model_input0.mem_indexes + ) + self.mem_manager.prepare_decode_swa_slots( + model_input1.b_req_idx, model_input1.b_seq_len, model_input1.mem_indexes + 
) + if hasattr(self.req_manager, "prepare_decode_compress_slots"): + self.req_manager.prepare_decode_compress_slots(model_input0.b_req_idx, model_input0.b_seq_len) + self.req_manager.prepare_decode_compress_slots(model_input1.b_req_idx, model_input1.b_seq_len) infer_state0 = self._create_inferstate(model_input0, 0) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py index 6735b2deed..dc708e0790 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py @@ -561,6 +561,37 @@ def ensure_swa_slots(self, req_idx: int, positions: torch.Tensor, full_slots: to out[i] = swa return out + def prepare_decode_swa_slots( + self, + b_req_idx: torch.Tensor, + b_seq_len: torch.Tensor, + mem_index: torch.Tensor, + ) -> None: + if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None: + return + + reqs = b_req_idx.detach().cpu().tolist() + seqs = b_seq_len.detach().cpu().tolist() + fulls = mem_index.detach().cpu().tolist() + hold = self.swa_pool.HOLD_TOKEN_MEMINDEX + for req_idx, seq_len, full in zip(reqs, seqs, fulls): + req_idx = int(req_idx) + full = int(full) + if req_idx == self.max_request_num or full == self.HOLD_TOKEN_MEMINDEX: + continue + ring_pos = (int(seq_len) - 1) % int(self.sliding_window) + old_swa = int(self.req_to_swa_indexs[req_idx, ring_pos].item()) + old_full = int(self.req_to_swa_full_indexs[req_idx, ring_pos].item()) + if old_swa == hold: + old_swa = int(self.swa_allocator.alloc(1)[0].item()) + if old_full >= 0 and old_full != full: + self.full_to_swa_indexs[old_full] = -1 + self.req_to_swa_indexs[req_idx, ring_pos] = old_swa + self.req_to_swa_full_indexs[req_idx, ring_pos] = full + self.full_to_swa_indexs[full] = old_swa + self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = hold + return + def 
_reserve_prefill_swa_slots( self, req_idx: int, @@ -839,6 +870,41 @@ def pack_mla_kv_to_cache( return self.swa_pool.write(layer_index, swa_slots, packed) + def pack_decode_mla_kv_to_cache( + self, + layer_index: int, + b_req_idx: torch.Tensor, + b_seq_len: torch.Tensor, + mem_index: torch.Tensor, + kv: torch.Tensor, + ): + if kv.shape[0] == 0: + return + packed = self._pack_mla_kv(kv) + if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None: + swa_slots = self._identity_swa_slots(mem_index).to(kv.device) + else: + req = b_req_idx.long() + ring = ((b_seq_len.long() - 1) % int(self.sliding_window)).long() + swa_slots = self.req_to_swa_indexs[req, ring].long() + + old_full = self.req_to_swa_full_indexs[req, ring].long() + full_slots = mem_index.long() + old_full = torch.where(old_full >= 0, old_full, full_slots) + self.full_to_swa_indexs[old_full] = torch.full( + old_full.shape, + -1, + dtype=self.full_to_swa_indexs.dtype, + device=old_full.device, + ) + + self.req_to_swa_full_indexs[req, ring] = full_slots.to(torch.int32) + self.full_to_swa_indexs[full_slots] = swa_slots.to(torch.int32) + self.swa_pool.write(layer_index, swa_slots.to(kv.device), packed) + + def gather_mla_kv_from_swa_slots(self, layer_index: int, swa_slots: torch.Tensor) -> torch.Tensor: + return self._unpack_mla_kv(self.swa_pool.read(layer_index, swa_slots.to(self.kv_buffer.device))) + def pack_compressed_kv_to_cache(self, layer_index: int, slots: torch.Tensor, comp: torch.Tensor): if comp.shape[0] == 0: return diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 606469d48e..ca027a63c8 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -542,6 +542,26 @@ def ensure_compress_slots(self, layer_index: int, req_idx: int, entry_start: int return self.ensure_c128_slots(req_idx, entry_start, entry_count) raise AssertionError(f"layer {layer_index} is not a compressed attention layer") + def prepare_decode_compress_slots(self, 
b_req_idx: torch.Tensor, b_seq_len: torch.Tensor) -> None: + req_list = b_req_idx.detach().cpu().tolist() + seq_list = b_seq_len.detach().cpu().tolist() + for req_idx, seq_len in zip(req_list, seq_list): + req_idx = int(req_idx) + if req_idx == self.HOLD_REQUEST_ID: + continue + seq_len = int(seq_len) + if self.n_c4 > 0: + required_c4 = seq_len // 4 + old_c4 = self._c4_entry_counts[req_idx] + if required_c4 > old_c4: + self.ensure_c4_slots(req_idx, old_c4, required_c4 - old_c4) + if self.n_c128 > 0: + required_c128 = seq_len // 128 + old_c128 = self._c128_entry_counts[req_idx] + if required_c128 > old_c128: + self.ensure_c128_slots(req_idx, old_c128, required_c128 - old_c128) + return + def pop_compress_indices_for_req(self, req_idx: int): c4_count = self._c4_entry_counts[req_idx] if c4_count > 0: diff --git a/lightllm/models/deepseek_v4/layer_infer/attention.py b/lightllm/models/deepseek_v4/layer_infer/attention.py index a24949696f..8a7428f0dd 100644 --- a/lightllm/models/deepseek_v4/layer_infer/attention.py +++ b/lightllm/models/deepseek_v4/layer_infer/attention.py @@ -46,9 +46,13 @@ def _pad_heads_for_flashmla(q, attn_sink): def _torch_sparse_attn(q, kv, attn_sink, topk_idxs, scale): - q0 = q[0].float() - kv0 = kv[0].float() - indices = topk_idxs[0].long() + return _torch_sparse_attn_flat(q[0], kv[0], attn_sink, topk_idxs[0], scale).unsqueeze(0) + + +def _torch_sparse_attn_flat(q, kv, attn_sink, topk_idxs, scale): + q0 = q.float() + kv0 = kv.float() + indices = topk_idxs.long() valid = (indices >= 0) & (indices < kv0.shape[0]) safe_indices = torch.where(valid, indices, torch.zeros_like(indices)) kv_sel = kv0[safe_indices] @@ -60,7 +64,7 @@ def _torch_sparse_attn(q, kv, attn_sink, topk_idxs, scale): exp_sink = torch.exp(sink - max_scores) denom = exp_scores.sum(dim=-1) + exp_sink out = torch.einsum("mhk,mkd->mhd", exp_scores / denom.unsqueeze(-1), kv_sel) - return out.unsqueeze(0).to(q.dtype) + return out.to(q.dtype) def vllm_sparse_attn(q, kv, attn_sink, 
topk_idxs, scale): @@ -77,15 +81,40 @@ def vllm_sparse_attn(q, kv, attn_sink, topk_idxs, scale): if q.dtype != torch.bfloat16 or kv.dtype != torch.bfloat16: raise RuntimeError(f"DeepSeek-V4 FlashMLA sparse attention requires bf16 q/kv, got {q.dtype}/{kv.dtype}") + return vllm_sparse_attn_flat(q[0], kv[0], attn_sink, topk_idxs[0], scale).unsqueeze(0) + + +def vllm_sparse_attn_flat(q, kv, attn_sink, topk_idxs, scale, already_compact=False): + """FlashMLA sparse attention over a flat KV arena. + + q:[m,h,d], kv:[n,d], topk_idxs:[m,K] int. Indices are global offsets into + the flat kv tensor, so callers can concatenate per-request KV candidates and + run one FlashMLA call for the whole batch. When already_compact=True, each + row must place all valid indices before invalid (-1) entries. + """ + m, h, d = q.shape + if d != 512: + raise RuntimeError(f"DeepSeek-V4 FlashMLA sparse attention requires head_dim=512, got {d}") + if q.dtype != torch.bfloat16 or kv.dtype != torch.bfloat16: + raise RuntimeError(f"DeepSeek-V4 FlashMLA sparse attention requires bf16 q/kv, got {q.dtype}/{kv.dtype}") + if q.shape[0] == 0: + return q.new_empty((0, h, d)) + if DSV4_DEBUG_TORCH_SPARSE_ATTN: - return _torch_sparse_attn(q, kv, attn_sink, topk_idxs, scale) + return _torch_sparse_attn_flat(q, kv, attn_sink, topk_idxs, scale) from vllm.third_party.flashmla.flash_mla_interface import flash_mla_sparse_fwd - q_pad, sink_pad, real_heads = _pad_heads_for_flashmla(q[0], attn_sink) - indices, topk_lens = _compact_topk_indices(topk_idxs[0].to(torch.int32), kv.shape[1]) + q_pad, sink_pad, real_heads = _pad_heads_for_flashmla(q, attn_sink) + topk_idxs = topk_idxs.to(torch.int32) + if already_compact: + valid = (topk_idxs >= 0) & (topk_idxs < kv.shape[0]) + indices = topk_idxs.contiguous() + topk_lens = valid.sum(dim=-1).to(torch.int32).contiguous() + else: + indices, topk_lens = _compact_topk_indices(topk_idxs, kv.shape[0]) indices = _pad_topk_for_flashmla(indices).unsqueeze(1) - kv_flat = 
kv[0].unsqueeze(1).contiguous() + kv_flat = kv.unsqueeze(1).contiguous() out, _, _ = flash_mla_sparse_fwd( q=q_pad, kv=kv_flat, @@ -95,7 +124,7 @@ def vllm_sparse_attn(q, kv, attn_sink, topk_idxs, scale): topk_length=topk_lens, out=None, ) - return out[:, :real_heads].unsqueeze(0).to(q.dtype) + return out[:, :real_heads].to(q.dtype) def build_prefill_topk_idxs(seqlen, window, ratio, n_window, device): diff --git a/lightllm/models/deepseek_v4/layer_infer/compressor.py b/lightllm/models/deepseek_v4/layer_infer/compressor.py index c91799f9ee..f51f73829c 100644 --- a/lightllm/models/deepseek_v4/layer_infer/compressor.py +++ b/lightllm/models/deepseek_v4/layer_infer/compressor.py @@ -459,3 +459,67 @@ def compressor_decode_step( eps, dtype, ) + + +def compressor_decode_step_batch( + x_new, + wkv_w, + wgate_w, + norm_w, + ape, + ratio, + head_dim, + rope_dim, + cos_table, + sin_table, + eps, + state_all, + b_req_idx, + start_pos, +): + """Graph-safe batch decode compressor step. + + Mutates ``state_all`` for the selected request rows and returns one candidate + entry per batch row plus a boolean mask telling which rows closed a + compression window. 
+ """ + overlap = ratio == 4 + d = head_dim + dtype = x_new.dtype + req = b_req_idx.long() + pos = start_pos.long() + pos_mod = pos % ratio + + xf = x_new.float() + kv = F.linear(xf, wkv_w.float()) + score = F.linear(xf, wgate_w.float()) + ape.float().index_select(0, pos_mod) + + kv_state = state_all[req, 0].clone() + score_state = state_all[req, 1].clone() + row = pos_mod + (ratio if overlap else 0) + batch_ids = torch.arange(x_new.shape[0], device=x_new.device) + kv_state[batch_ids, row] = kv + score_state[batch_ids, row] = score + + should_compress = ((pos + 1) % ratio) == 0 + if overlap: + kv_cat = torch.cat([kv_state[:, :ratio, :d], kv_state[:, ratio:, d:]], dim=1) + score_cat = torch.cat([score_state[:, :ratio, :d], score_state[:, ratio:, d:]], dim=1) + entry = (kv_cat * torch.softmax(score_cat, dim=1)).sum(dim=1) + shifted_kv_state = kv_state.clone() + shifted_score_state = score_state.clone() + shifted_kv_state[:, :ratio] = kv_state[:, ratio:] + shifted_score_state[:, :ratio] = score_state[:, ratio:] + kv_state = torch.where(should_compress.view(-1, 1, 1), shifted_kv_state, kv_state) + score_state = torch.where(should_compress.view(-1, 1, 1), shifted_score_state, score_state) + else: + entry = (kv_state * torch.softmax(score_state, dim=1)).sum(dim=1) + + state_all[req, 0] = kv_state + state_all[req, 1] = score_state + + entry = _rmsnorm(entry.to(dtype), norm_w, eps) + comp_pos = torch.clamp(pos + 1 - ratio, min=0) + entry_rope = apply_rotary_emb(entry[:, -rope_dim:], cos_table[comp_pos], sin_table[comp_pos]) + entry = torch.cat([entry[:, :-rope_dim], entry_rope], dim=1) + return entry, should_compress diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 81b45299f9..864104fd32 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -1,19 +1,14 @@ -import os - 
import torch import torch.nn.functional as F import torch.distributed as dist from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.distributed.communication_op import all_reduce from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor from .hyper_connection import hc_pre, hc_post from ..triton_kernel.rotary_emb import apply_rotary_emb -from .compressor import compressor_prefill_state, compressor_decode_step -from .attention import vllm_sparse_attn - - -DSV4_DEBUG_DIRECT_PREFILL_COMP = os.getenv("DSV4_DEBUG_DIRECT_PREFILL_COMP", "0") == "1" -DSV4_DEBUG_DISABLE_COMP_ATTN = os.getenv("DSV4_DEBUG_DISABLE_COMP_ATTN", "0") == "1" +from .compressor import compressor_prefill_state, compressor_decode_step, compressor_decode_step_batch +from .attention import vllm_sparse_attn_flat class DeepseekV4TransformerLayerInfer(TransformerLayerInferTpl): @@ -54,13 +49,18 @@ def __init__(self, layer_num, network_config): self.tp_q_heads = self.n_heads // self.tp_world_size_ self.tp_index_heads = self.index_n_heads // self.tp_world_size_ self.tp_groups = self.o_groups // self.tp_world_size_ + self.tp_q_head_num_ = self.tp_q_heads + self.tp_k_head_num_ = 1 + self.tp_v_head_num_ = 1 + self.tp_o_head_num_ = self.tp_q_heads + self.head_dim_ = self.head_dim self.embed_dim_ = self.hc_mult * self.hidden self.enable_ep_moe = get_env_start_args().enable_ep_moe self.indexer_score_scale = self.index_head_dim ** -0.5 self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5 # ------------------------------------------------------------------ forward (HC-wrapped) - def _hc_block(self, streams, infer_state, lw, attn_fn): + def _hc_forward(self, streams, infer_state, lw, attn_forward): residual = streams collapsed, post, comb = hc_pre( streams, @@ -72,7 +72,7 @@ def _hc_block(self, streams, infer_state, lw, attn_fn): self.hc_eps, self.sinkhorn_iters, ) - o = attn_fn(lw.attn_norm_(collapsed, 
eps=self.eps_), infer_state, lw) + o = attn_forward(self._att_norm(collapsed, infer_state, lw), infer_state, lw) streams = hc_post(o, residual, post, comb, self.hc_mult, self.hidden) residual = streams @@ -86,17 +86,29 @@ def _hc_block(self, streams, infer_state, lw, attn_fn): self.hc_eps, self.sinkhorn_iters, ) - f = self._moe_ffn(lw.ffn_norm_(collapsed, eps=self.eps_), infer_state, lw) + f = self._ffn(self._ffn_norm(collapsed, infer_state, lw), infer_state, lw) return hc_post(f, residual, post, comb, self.hc_mult, self.hidden) def context_forward(self, streams, infer_state, lw): - return self._hc_block(streams, infer_state, lw, self._attention_prefill) + return self._hc_forward(streams, infer_state, lw, self.context_attention_forward) def token_forward(self, streams, infer_state, lw): - return self._hc_block(streams, infer_state, lw, self._attention_decode) + return self._hc_forward(streams, infer_state, lw, self.token_attention_forward) + + def _att_norm(self, x, infer_state, lw): + return lw.attn_norm_(x, eps=self.eps_) + + def _ffn_norm(self, x, infer_state, lw): + return lw.ffn_norm_(x, eps=self.eps_) - # ------------------------------------------------------------------ shared projections - def _qkv(self, x, cos_tok, sin_tok, lw): + # ------------------------------------------------------------------ shared projections / cache + def _select_rope(self, infer_state): + if self.compress_ratio: + return infer_state.position_cos_compress, infer_state.position_sin_compress + return infer_state.position_cos_sliding, infer_state.position_sin_sliding + + def _get_qkv(self, x, infer_state, lw): + cos_tok, sin_tok = self._select_rope(infer_state) T = x.shape[0] qa = lw.q_norm_(lw.wq_a_.mm(x), eps=self.eps_) q = lw.wq_b_.mm(qa).view(T, self.tp_q_heads, self.head_dim).float() @@ -116,10 +128,10 @@ def _qkv(self, x, cos_tok, sin_tok, lw): ], dim=1, ) - return q, kv, qa + return q, kv, qa, cos_tok, sin_tok - def _out_proj(self, o, infer_state, lw): - # o: [T, tp_q_heads, 
head_dim] -> inverse rope -> grouped low-rank O -> [T, hidden] + def _get_o(self, o, infer_state, lw): + # o: [T, tp_q_heads, head_dim] after inverse rope -> grouped low-rank O -> [T, hidden] T = o.shape[0] o = o.reshape(T, self.tp_groups, -1).transpose(0, 1).contiguous() # [groups, T, per_group_in] o = lw.wo_a_.bmm(o).transpose(0, 1).reshape(T, -1) # [T, groups*o_lora] @@ -142,22 +154,36 @@ def _inv_rope(self, o, cos_tok, sin_tok): dim=-1, ) - def _post_dense_kv(self, infer_state, req, start_pos, mem_index, kv): + def _post_cache_kv(self, cache_kv, infer_state, lw, req_idx=None, start_pos=None, mem_index=None): + if req_idx is None or start_pos is None or mem_index is None: + raise RuntimeError("DeepSeek-V4 cache write requires req_idx, start_pos, and mem_index") positions = torch.arange( start_pos, - start_pos + kv.shape[0], + start_pos + cache_kv.shape[0], device=mem_index.device, dtype=torch.long, ) infer_state.mem_manager.pack_mla_kv_to_cache( layer_index=self.layer_num_, mem_index=mem_index, - kv=kv.reshape(kv.shape[0], 1, kv.shape[-1]), - req_idx=req, + kv=cache_kv.reshape(cache_kv.shape[0], 1, cache_kv.shape[-1]), + req_idx=req_idx, positions=positions, ) return + def _get_compressor_state(self, infer_state, req): + cstate_kv, cstate_score = infer_state.req_manager.get_compress_state_for_req(self.layer_num_, req) + state = { + "cstate_kv": cstate_kv, + "cstate_score": cstate_score, + } + if self.compress_ratio == 4: + idx_state = infer_state.req_manager.get_c4_indexer_compress_state(self.layer_num_) + state["idx_cstate_kv"] = idx_state[req, 0] + state["idx_cstate_score"] = idx_state[req, 1] + return state + def _write_compressed_kv(self, infer_state, req, entry_start, comp): slots = infer_state.req_manager.ensure_compress_slots(self.layer_num_, req, entry_start, comp.shape[0]) if comp.shape[0] == 0: @@ -192,20 +218,61 @@ def _c4_indexer_k_from_cache(self, infer_state, req, ncomp): slots = infer_state.req_manager.req_to_c4_indexs[req, :ncomp].long() return 
infer_state.mem_manager.gather_c4_indexer_k(self.layer_num_, slots) + def _run_sparse_attention_batch(self, q_chunks, kv_chunks, index_chunks, sink): + q_flat = torch.cat(q_chunks, dim=0) + kv_flat = torch.cat(kv_chunks, dim=0) + max_topk = max(t.shape[-1] for t in index_chunks) + topk = torch.full( + (q_flat.shape[0], max_topk), + -1, + dtype=torch.int32, + device=q_flat.device, + ) + offset = 0 + for idx in index_chunks: + rows = idx.shape[0] + topk[offset : offset + rows, : idx.shape[1]] = idx.to(torch.int32) + offset += rows + return vllm_sparse_attn_flat(q_flat, kv_flat, sink, topk, self.softmax_scale) + # ------------------------------------------------------------------ attention (prefill) - def _attention_prefill(self, x, infer_state, lw): + def context_attention_forward(self, x, infer_state, lw): + q, cache_kv, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, lw) + o = self._context_attention_wrapper_run(q, cache_kv, q_lora, x, infer_state, lw) + return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw) + + def _context_attention_wrapper_run(self, q, cache_kv, q_lora, x, infer_state, lw): + if torch.cuda.is_current_stream_capturing(): + q = q.contiguous() + cache_kv = cache_kv.contiguous() + q_lora = q_lora.contiguous() + x = x.contiguous() + _q = tensor_to_no_ref_tensor(q) + _cache_kv = tensor_to_no_ref_tensor(cache_kv) + _q_lora = tensor_to_no_ref_tensor(q_lora) + _x = tensor_to_no_ref_tensor(x) + + pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph() + pre_capture_graph.__exit__(None, None, None) + + infer_state.prefill_cuda_graph_create_graph_obj() + infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__() + o = torch.empty((q.shape[0], self.tp_q_heads, self.head_dim), dtype=q.dtype, device=q.device) + _o = tensor_to_no_ref_tensor(o) + + def att_func(new_infer_state): + tmp_o = self._context_attention_kernel(_q, _cache_kv, _q_lora, _x, new_infer_state, lw) + assert tmp_o.shape == _o.shape + 
_o.copy_(tmp_o) + return + + infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=att_func, after_graph=pre_capture_graph) + return o + + return self._context_attention_kernel(q, cache_kv, q_lora, x, infer_state, lw) + + def _context_attention_kernel(self, q, cache_kv, q_lora, x, infer_state, lw): T = x.shape[0] - if self.compress_ratio: - cos_tok, sin_tok = ( - infer_state.position_cos_compress, - infer_state.position_sin_compress, - ) - else: - cos_tok, sin_tok = ( - infer_state.position_cos_sliding, - infer_state.position_sin_sliding, - ) - q, kv, qa = self._qkv(x, cos_tok, sin_tok, lw) sink = lw.attn_sink_.weight o = x.new_empty(T, self.tp_q_heads, self.head_dim) b_req = infer_state.b_req_idx.tolist() @@ -214,17 +281,28 @@ def _attention_prefill(self, x, infer_state, lw): ready_lens = infer_state.b_ready_cache_len.tolist() idx_q, idx_weight = self._indexer_q_weight( x, - qa, + q_lora, infer_state.position_cos_compress, infer_state.position_sin_compress, lw, ) + q_chunks = [] + kv_chunks = [] + index_chunks = [] + out_ranges = [] + kv_offset = 0 + hold_req = infer_state.req_manager.HOLD_REQUEST_ID for req, st, ln, ready_len in zip(b_req, starts, lens, ready_lens): - q_r, kv_r, x_r = q[st : st + ln], kv[st : st + ln], x[st : st + ln] + if req == hold_req: + o[st : st + ln].zero_() + continue + q_r = q[st : st + ln] + cache_kv_r = cache_kv[st : st + ln] + x_r = x[st : st + ln] idx_q_r = None if idx_q is None else idx_q[st : st + ln] idx_weight_r = None if idx_weight is None else idx_weight[st : st + ln] kv_all, dense_base, n_window, ncomp, idx_comp = self._gather_prefill( - x_r, kv_r, req, ready_len, lw, infer_state + x_r, cache_kv_r, req, ready_len, lw, infer_state ) ti = self._topk_idxs_prefill( ln, @@ -237,16 +315,28 @@ def _attention_prefill(self, x, infer_state, lw): idx_comp, idx_weight_r, infer_state, - ) - o[st : st + ln] = vllm_sparse_attn(q_r.unsqueeze(0), kv_all.unsqueeze(0), sink, ti, self.softmax_scale)[0] - self._post_dense_kv( + )[0] + ti = 
torch.where(ti >= 0, ti + kv_offset, ti).to(torch.int32) + q_chunks.append(q_r) + kv_chunks.append(kv_all) + index_chunks.append(ti) + out_ranges.append((st, ln)) + kv_offset += kv_all.shape[0] + self._post_cache_kv( + cache_kv_r, infer_state, - req, - ready_len, - infer_state.mem_index[st : st + ln], - kv_r, + lw, + req_idx=req, + start_pos=ready_len, + mem_index=infer_state.mem_index[st : st + ln], ) - return self._out_proj(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw) + if q_chunks: + attn_out = self._run_sparse_attention_batch(q_chunks, kv_chunks, index_chunks, sink) + out_offset = 0 + for st, ln in out_ranges: + o[st : st + ln] = attn_out[out_offset : out_offset + ln] + out_offset += ln + return o def _gather_prefill(self, x_r, kv_r, req, ready_len, lw, infer_state): ln = kv_r.shape[0] @@ -271,16 +361,9 @@ def _gather_prefill(self, x_r, kv_r, req, ready_len, lw, infer_state): state_pool=cstate_pool, ) comp_slots = self._write_compressed_kv(infer_state, req, 0, comp) - ( - cstate_kv, - cstate_score, - ) = infer_state.req_manager.get_compress_state_for_req(self.layer_num_, req) + cstate_kv, cstate_score = infer_state.req_manager.get_compress_state_for_req(self.layer_num_, req) cstate_kv.copy_(ks) cstate_score.copy_(ss) - state = { - "cstate_kv": cstate_kv, - "cstate_score": cstate_score, - } if self.compress_ratio == 4: idx_cstate_pool = infer_state.req_manager.get_c4_indexer_state_pool_for_req(self.layer_num_, req) idx_comp, idx_ks, idx_ss, idx_cstate_pool = compressor_prefill_state( @@ -304,35 +387,15 @@ def _gather_prefill(self, x_r, kv_r, req, ready_len, lw, infer_state): idx_cstate_score = idx_state[req, 1] idx_cstate_kv.copy_(idx_ks) idx_cstate_score.copy_(idx_ss) - state.update( - { - "idx_cstate_kv": idx_cstate_kv, - "idx_cstate_score": idx_cstate_score, - } - ) - infer_state.req_manager.set_runtime_state( - req, - self.layer_num_, - state, - ) ncomp = comp.shape[0] - if DSV4_DEBUG_DISABLE_COMP_ATTN: - return kv_r, 0, ln, 0, None - if not 
DSV4_DEBUG_DIRECT_PREFILL_COMP: - comp = self._compressed_kv_from_cache(infer_state, req, ncomp) - idx_comp = self._c4_indexer_k_from_cache(infer_state, req, ncomp) + comp = self._compressed_kv_from_cache(infer_state, req, ncomp) + idx_comp = self._c4_indexer_k_from_cache(infer_state, req, ncomp) return torch.cat([kv_r, comp], dim=0), 0, ln, ncomp, idx_comp return kv_r, 0, ln, 0, None def _gather_prefill_extend(self, x_r, kv_r, req, ready_len, lw, infer_state): if self.compress_ratio: - try: - state = infer_state.req_manager.get_runtime_state(req, self.layer_num_) - except KeyError as exc: - raise RuntimeError( - "DeepSeek-V4 prefill chunk is missing runtime state; radix prompt cache " - "must stay disabled until V4 managed token cache is implemented." - ) from exc + state = self._get_compressor_state(infer_state, req) cstate_pool = infer_state.req_manager.get_compress_state_pool_for_req(self.layer_num_, req) idx_cstate_pool = ( infer_state.req_manager.get_c4_indexer_state_pool_for_req(self.layer_num_, req) @@ -392,8 +455,6 @@ def _gather_prefill_extend(self, x_r, kv_r, req, ready_len, lw, infer_state): dense = torch.cat([cached_dense, kv_r], dim=0) comp = self._compressed_kv_from_cache(infer_state, req, ncomp) idx_comp = self._c4_indexer_k_from_cache(infer_state, req, ncomp) - if DSV4_DEBUG_DISABLE_COMP_ATTN: - return dense, dense_base, dense.shape[0], 0, None return ( torch.cat([dense, comp], dim=0), dense_base, @@ -444,23 +505,201 @@ def _topk_idxs_prefill( return torch.cat([win, comp], dim=1).int().unsqueeze(0) return win.int().unsqueeze(0) + def _decode_dense_kv_graph(self, infer_state): + req = infer_state.b_req_idx.long() + seq = infer_state.b_seq_len.long() + B = req.shape[0] + device = infer_state.b_seq_len.device + offsets = torch.arange(self.window, device=device, dtype=torch.long) + win_len = torch.minimum(seq, torch.full_like(seq, self.window)) + start = seq - win_len + pos = start.unsqueeze(1) + offsets.unsqueeze(0) + valid = offsets.unsqueeze(0) < 
win_len.unsqueeze(1) + hold = infer_state.mem_manager.swa_pool.HOLD_TOKEN_MEMINDEX + safe_pos = torch.where(valid, pos, torch.zeros_like(pos)).long() + full_slots = infer_state.req_manager.req_to_token_indexs[req.unsqueeze(1), safe_pos].long() + swa_slots = infer_state.mem_manager.full_to_swa_indexs[full_slots].long() + slot_valid = valid & (swa_slots >= 0) + swa_slots = torch.where(slot_valid, swa_slots, torch.full_like(swa_slots, hold)) + kv = infer_state.mem_manager.gather_mla_kv_from_swa_slots(self.layer_num_, swa_slots.reshape(-1)) + return kv.view(B, self.window, self.head_dim), valid + + def _decode_all_compressed_kv_graph(self, infer_state, ratio): + req = infer_state.b_req_idx.long() + seq = infer_state.b_seq_len.long() + B = req.shape[0] + device = infer_state.b_seq_len.device + max_comp = max(1, infer_state.max_kv_seq_len // ratio) + offsets = torch.arange(max_comp, device=device, dtype=torch.long) + ncomp = torch.div(seq, ratio, rounding_mode="floor") + valid = offsets.unsqueeze(0) < ncomp.unsqueeze(1) + safe_offsets = torch.where(valid, offsets.unsqueeze(0), torch.zeros_like(offsets).unsqueeze(0)) + if ratio == 4: + table = infer_state.req_manager.req_to_c4_indexs + hold = infer_state.mem_manager.c4_pool.HOLD_TOKEN_MEMINDEX + else: + table = infer_state.req_manager.req_to_c128_indexs + hold = infer_state.mem_manager.c128_pool.HOLD_TOKEN_MEMINDEX + slots = table[req.unsqueeze(1), safe_offsets].long() + slots = torch.where(valid, slots, torch.full_like(slots, hold)) + kv = infer_state.mem_manager.gather_compressed_kv(self.layer_num_, slots.reshape(-1)) + kv = kv.view(B, max_comp, self.head_dim) + if ratio != 4: + return kv, None, valid, ncomp + idx_k = infer_state.mem_manager.gather_c4_indexer_k(self.layer_num_, slots.reshape(-1)) + idx_k = idx_k.view(B, max_comp, self.index_head_dim) + return kv, idx_k, valid, ncomp + + def _decode_c4_topk_graph(self, idx_q, idx_weight, idx_comp, valid_comp, ncomp, infer_state): + scores = torch.einsum("bhd,bnd->bhn", 
idx_q.float(), idx_comp.float()) + scores = F.relu(scores) * self.indexer_score_scale + index_scores = (scores * idx_weight.unsqueeze(-1)).sum(dim=1) + if self.tp_world_size_ > 1: + all_reduce(index_scores, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + index_scores = index_scores.masked_fill(~valid_comp, float("-inf")) + top = index_scores.topk(self.index_topk, dim=-1).indices + valid = top < ncomp.unsqueeze(1) + return torch.where(valid, top, torch.zeros_like(top)), valid + + def _decode_compressed_candidates_graph(self, idx_q, idx_weight, infer_state): + if self.compress_ratio == 4: + _, idx_comp, valid_all, ncomp = self._decode_all_compressed_kv_graph(infer_state, 4) + top, valid = self._decode_c4_topk_graph(idx_q, idx_weight, idx_comp, valid_all, ncomp, infer_state) + req = infer_state.b_req_idx.long() + slots = infer_state.req_manager.req_to_c4_indexs[req.unsqueeze(1), top].long() + hold = infer_state.mem_manager.c4_pool.HOLD_TOKEN_MEMINDEX + slots = torch.where(valid, slots, torch.full_like(slots, hold)) + comp = infer_state.mem_manager.gather_compressed_kv(self.layer_num_, slots.reshape(-1)) + return comp.view(req.shape[0], self.index_topk, self.head_dim), valid + comp, _, valid, _ = self._decode_all_compressed_kv_graph(infer_state, 128) + return comp, valid + + def _write_decode_compressed_entry_graph(self, x, infer_state, lw, ratio): + req = infer_state.b_req_idx + start_pos = infer_state.b_seq_len.long() - 1 + if ratio == 4: + state_all = infer_state.req_manager.get_c4_compress_state(self.layer_num_) + table = infer_state.req_manager.req_to_c4_indexs + hold = infer_state.mem_manager.c4_pool.HOLD_TOKEN_MEMINDEX + else: + state_all = infer_state.req_manager.get_c128_compress_state(self.layer_num_) + table = infer_state.req_manager.req_to_c128_indexs + hold = infer_state.mem_manager.c128_pool.HOLD_TOKEN_MEMINDEX + + entry, should = compressor_decode_step_batch( + x, + lw.compressor_wkv_.mm_param.weight, + 
lw.compressor_wgate_.mm_param.weight, + lw.compressor_norm_.weight, + lw.compressor_ape_.weight, + ratio, + self.head_dim, + self.rope_dim, + infer_state.cos_compress_table, + infer_state.sin_compress_table, + self.eps_, + state_all, + req, + start_pos, + ) + entry_idx = torch.clamp(torch.div(infer_state.b_seq_len.long(), ratio, rounding_mode="floor") - 1, min=0) + slots = table[req.long(), entry_idx].long() + slots = torch.where(should, slots, torch.full_like(slots, hold)) + infer_state.mem_manager.pack_compressed_kv_to_cache(self.layer_num_, slots, entry) + + if ratio == 4: + idx_state_all = infer_state.req_manager.get_c4_indexer_compress_state(self.layer_num_) + idx_entry, idx_should = compressor_decode_step_batch( + x, + lw.idx_cmp_wkv_.mm_param.weight, + lw.idx_cmp_wgate_.mm_param.weight, + lw.idx_cmp_norm_.weight, + lw.idx_cmp_ape_.weight, + 4, + self.index_head_dim, + self.rope_dim, + infer_state.cos_compress_table, + infer_state.sin_compress_table, + self.eps_, + idx_state_all, + req, + start_pos, + ) + idx_slots = torch.where(idx_should, slots, torch.full_like(slots, hold)) + infer_state.mem_manager.pack_c4_indexer_k_to_cache(self.layer_num_, idx_slots, idx_entry) + return + # ------------------------------------------------------------------ attention (decode) - def _attention_decode(self, x, infer_state, lw): - B = x.shape[0] # one new token per request + def token_attention_forward(self, x, infer_state, lw): + q, cache_kv, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, lw) + if infer_state.is_cuda_graph: + o = self._token_attention_kernel_cuda_graph(q, cache_kv, q_lora, x, infer_state, lw) + else: + o = self._token_attention_kernel(q, cache_kv, q_lora, x, infer_state, lw) + return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw) + + def _token_attention_kernel_cuda_graph(self, q, cache_kv, q_lora, x, infer_state, lw): + sink = lw.attn_sink_.weight + infer_state.mem_manager.pack_decode_mla_kv_to_cache( + self.layer_num_, + 
infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.mem_index, + cache_kv.reshape(cache_kv.shape[0], 1, cache_kv.shape[-1]), + ) + idx_q, idx_weight = self._indexer_q_weight( + x, + q_lora, + infer_state.position_cos_compress, + infer_state.position_sin_compress, + lw, + ) if self.compress_ratio: - cos_tok, sin_tok = ( - infer_state.position_cos_compress, - infer_state.position_sin_compress, - ) + self._write_decode_compressed_entry_graph(x, infer_state, lw, self.compress_ratio) + + dense_kv, dense_valid = self._decode_dense_kv_graph(infer_state) + B = q.shape[0] + device = q.device + if self.compress_ratio: + comp_kv, comp_valid = self._decode_compressed_candidates_graph(idx_q, idx_weight, infer_state) + kv_all = torch.cat([dense_kv, comp_kv], dim=1) + comp_offsets = torch.arange(comp_kv.shape[1], device=device, dtype=torch.int32) else: - cos_tok, sin_tok = ( - infer_state.position_cos_sliding, - infer_state.position_sin_sliding, + kv_all = dense_kv + comp_valid = None + comp_offsets = None + + total_k = kv_all.shape[1] + base = torch.arange(B, device=device, dtype=torch.int32).unsqueeze(1) * total_k + dense_offsets = torch.arange(self.window, device=device, dtype=torch.int32) + dense_topk = torch.where( + dense_valid, + base + dense_offsets.unsqueeze(0), + torch.full((B, self.window), -1, device=device, dtype=torch.int32), + ) + if self.compress_ratio: + comp_topk = torch.where( + comp_valid, + base + self.window + comp_offsets.unsqueeze(0), + torch.full((B, comp_kv.shape[1]), -1, device=device, dtype=torch.int32), ) - q, kv, qa = self._qkv(x, cos_tok, sin_tok, lw) # [B, heads, hd], [B, hd] + topk = torch.cat([dense_topk, comp_topk], dim=1) + else: + topk = dense_topk + return vllm_sparse_attn_flat( + q, + kv_all.reshape(-1, self.head_dim), + sink, + topk, + self.softmax_scale, + already_compact=True, + ) + + def _token_attention_kernel(self, q, cache_kv, q_lora, x, infer_state, lw): + B = x.shape[0] # one new token per request idx_q, idx_weight = 
self._indexer_q_weight( x, - qa, + q_lora, infer_state.position_cos_compress, infer_state.position_sin_compress, lw, @@ -469,23 +708,27 @@ def _attention_decode(self, x, infer_state, lw): b_req = infer_state.b_req_idx.tolist() seqlens = infer_state.b_seq_len.tolist() o = x.new_empty(B, self.tp_q_heads, self.head_dim) + hold_req = infer_state.req_manager.HOLD_REQUEST_ID + q_chunks = [] + kv_chunks = [] + index_chunks = [] + out_rows = [] + kv_offset = 0 for i, (req, seq) in enumerate(zip(b_req, seqlens)): + if req == hold_req: + o[i].zero_() + continue start_pos = seq - 1 - self._post_dense_kv( + self._post_cache_kv( + cache_kv[i : i + 1], infer_state, - req, - start_pos, - infer_state.mem_index[i : i + 1], - kv[i : i + 1], + lw, + req_idx=req, + start_pos=start_pos, + mem_index=infer_state.mem_index[i : i + 1], ) if self.compress_ratio: - try: - stt = infer_state.req_manager.get_runtime_state(req, self.layer_num_) - except KeyError as exc: - raise RuntimeError( - "DeepSeek-V4 decode is missing runtime state; radix prompt cache " - "must stay disabled until V4 managed token cache is implemented." 
- ) from exc + stt = self._get_compressor_state(infer_state, req) cstate_pool = infer_state.req_manager.get_compress_state_pool_for_req(self.layer_num_, req) e = compressor_decode_step( x[i], @@ -538,12 +781,7 @@ def _attention_decode(self, x, infer_state, lw): win_kv = self._dense_kv_from_cache(infer_state, req, win_start, seq) comp_kv = self._compressed_kv_from_cache(infer_state, req, seq // self.compress_ratio) idx_comp = self._c4_indexer_k_from_cache(infer_state, req, comp_kv.shape[0]) - if DSV4_DEBUG_DISABLE_COMP_ATTN: - comp_kv = None - idx_comp = None - kv_all = win_kv - else: - kv_all = torch.cat([win_kv, comp_kv], dim=0) + kv_all = torch.cat([win_kv, comp_kv], dim=0) else: win_start = max(0, seq - self.window) win_kv = self._dense_kv_from_cache(infer_state, req, win_start, seq) @@ -559,15 +797,18 @@ def _attention_decode(self, x, infer_state, lw): seq, x.device, infer_state, - ) - o[i] = vllm_sparse_attn( - q[i].view(1, 1, self.tp_q_heads, self.head_dim), - kv_all.unsqueeze(0), - sink, - ti, - self.softmax_scale, )[0, 0] - return self._out_proj(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw) + ti = torch.where(ti >= 0, ti + kv_offset, ti).view(1, -1).to(torch.int32) + q_chunks.append(q[i : i + 1]) + kv_chunks.append(kv_all) + index_chunks.append(ti) + out_rows.append(i) + kv_offset += kv_all.shape[0] + if q_chunks: + attn_out = self._run_sparse_attention_batch(q_chunks, kv_chunks, index_chunks, sink) + for row, row_out in zip(out_rows, attn_out): + o[row] = row_out + return o def _indexer_q_weight(self, x, qa, cos_tok, sin_tok, lw): if self.compress_ratio != 4: @@ -705,7 +946,7 @@ def _fp4_experts_marlin(self, x, weights, indices, experts): clamp_limit=float(self.swiglu_limit), ) - def _moe_ffn(self, x, infer_state, lw): + def _ffn(self, x, infer_state, lw): gw = lw.gate_weight_.mm_param.weight logits = F.linear(x.float(), gw.float()).contiguous() weights, indices = self._select_experts(logits, infer_state, lw) diff --git 
a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py index c87f2fdebd..915e45b9c9 100644 --- a/lightllm/models/deepseek_v4/model.py +++ b/lightllm/models/deepseek_v4/model.py @@ -37,12 +37,13 @@ from lightllm.distributed.communication_op import dist_group_manager logger = init_logger(__name__) +DSV4_DECODE_CUDAGRAPH_MAX_LEN = 8192 class DeepseekV4DirectSparseAttBackend(BaseAttBackend): """Lifecycle placeholder for V4 direct attention. - V4 attention is currently driven inside the layer by `vllm_sparse_attn()`, not by the generic + V4 attention is currently driven inside the layer, not by the generic `infer_state.prefill_att_state.prefill_att()` / `decode_att()` backend selector. """ @@ -58,7 +59,7 @@ def init_state(self): return def prefill_att(self, *args, **kwargs): - raise RuntimeError("DeepSeek-V4 attention is executed directly by vllm_sparse_attn() in layer_infer.") + raise RuntimeError("DeepSeek-V4 attention is executed directly in layer_infer.") class DeepseekV4DirectSparseDecodeAttState(BaseDecodeAttState): @@ -66,7 +67,7 @@ def init_state(self): return def decode_att(self, *args, **kwargs): - raise RuntimeError("DeepSeek-V4 attention is executed directly by vllm_sparse_attn() in layer_infer.") + raise RuntimeError("DeepSeek-V4 attention is executed directly in layer_infer.") @ModelRegistry("deepseek_v4") @@ -79,6 +80,7 @@ class DeepseekV4TpPartModel(LlamaTpPartModel): transformer_layer_infer_class = DeepseekV4TransformerLayerInfer infer_state_class = DeepseekV4InferStateInfo + _logged_prefill_graph_prefix_skip = False def _verify_params(self): assert self.load_way == "HF", "only support HF format weights" @@ -137,6 +139,28 @@ def _init_mem_manager(self): self.req_manager.bind_mem_manager(self.mem_manager) return + def _init_cudagraph(self): + if not self.disable_cudagraph and self.graph_max_len_in_batch > DSV4_DECODE_CUDAGRAPH_MAX_LEN: + logger.info( + "DeepSeek-V4 caps decode cudagraph max_len_in_batch from %s to %s for the 
current " + "graph-safe sparse-attention fallback; longer decode batches run eager.", + self.graph_max_len_in_batch, + DSV4_DECODE_CUDAGRAPH_MAX_LEN, + ) + self.graph_max_len_in_batch = DSV4_DECODE_CUDAGRAPH_MAX_LEN + return super()._init_cudagraph() + + def _can_run_prefill_cudagraph(self, infer_state, handle_token_num): + if infer_state.prefix_total_token_num == 0: + return True + if not self._logged_prefill_graph_prefix_skip: + logger.info( + "DeepSeek-V4 skips prefill cudagraph for prompt-cache extension batches; " + "no-prefix prefill batches still use prefill cudagraph." + ) + self._logged_prefill_graph_prefix_skip = True + return False + def _init_att_backend(self): self.prefill_att_backend = DeepseekV4DirectSparseAttBackend(model=self) self.decode_att_backend = DeepseekV4DirectSparseAttBackend(model=self) From 19866d02ab1a59afc6a65deccd86d05eb9d93459 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 8 Jun 2026 01:57:00 +0000 Subject: [PATCH 05/30] refact tokenizer --- lightllm/models/deepseek3_2/model.py | 71 +++++++++++++++------------- lightllm/models/deepseek_v4/model.py | 61 ++++++++++++++++++++++-- 2 files changed, 95 insertions(+), 37 deletions(-) diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index cd33386666..5831044311 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -1,20 +1,14 @@ import copy from lightllm.models.registry import ModelRegistry from lightllm.models.deepseek2.model import Deepseek2TpPartModel -from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import ( - Deepseek3_2TransformerLayerWeight, -) -from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import ( - Deepseek3_2TransformerLayerInfer, -) -from lightllm.common.basemodel.attention import ( - get_nsa_prefill_att_backend_class, - get_nsa_decode_att_backend_class, -) +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight 
import Deepseek3_2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer +from lightllm.common.basemodel.attention import get_nsa_prefill_att_backend_class, get_nsa_decode_att_backend_class @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): + # weight class transformer_weight_class = Deepseek3_2TransformerLayerWeight @@ -27,11 +21,24 @@ def _init_att_backend(self): return -class DeepSeekChatTokenizerBase: +class DeepSeekV32Tokenizer: + """Tokenizer wrapper for DeepSeek-V3.2 that uses the Python-based + encoding_dsv32 module instead of Jinja chat templates. + + DeepSeek-V3.2's tokenizer_config.json does not ship with a Jinja chat + template, so ``apply_chat_template`` would fail without either a manually + supplied ``--chat_template`` file or this wrapper. + """ + def __init__(self, tokenizer): self.tokenizer = tokenizer + # Cache added vocabulary for performance (HuggingFace can be slow). self._added_vocab = None + # ------------------------------------------------------------------ + # Attribute delegation – everything not overridden goes to the inner + # tokenizer so that encode/decode/vocab_size/eos_token_id/… all work. + # ------------------------------------------------------------------ def __getattr__(self, name): return getattr(self.tokenizer, name) @@ -40,9 +47,9 @@ def get_added_vocab(self): self._added_vocab = self.tokenizer.get_added_vocab() return self._added_vocab - def _encode_messages(self, msgs, thinking_mode, kwargs): - raise NotImplementedError("subclass must provide DeepSeek encode_messages") - + # ------------------------------------------------------------------ + # Core override: route apply_chat_template through encode_messages. 
+ # ------------------------------------------------------------------ def apply_chat_template( self, conversation=None, @@ -51,16 +58,27 @@ def apply_chat_template( tokenize=False, add_generation_prompt=True, thinking=None, - enable_thinking=None, **kwargs, ): + from lightllm.models.deepseek3_2.encoding_dsv32 import encode_messages, render_tools + msgs = conversation if conversation is not None else messages if msgs is None: raise ValueError("Either 'conversation' or 'messages' must be provided") + # Deep copy to avoid mutating the caller's messages. msgs = copy.deepcopy(msgs) + # Determine thinking mode. + thinking_mode = "thinking" if thinking else "chat" + + # Inject tools into the first system message (or create one) so that + # encode_messages / render_message picks them up. if tools: + # build_prompt passes tools as bare function dicts: + # [{"name": "f", "description": "...", "parameters": {...}}] + # encoding_dsv32's render_message expects OpenAI wrapper format: + # [{"type": "function", "function": {...}}] wrapped_tools = [] for t in tools: if "function" in t: @@ -77,27 +95,16 @@ def apply_chat_template( break if not injected: + # Prepend a system message that carries the tools. 
msgs.insert(0, {"role": "system", "content": "", "tools": wrapped_tools}) - if thinking is None: - thinking = bool(enable_thinking) if enable_thinking is not None else False - thinking_mode = "thinking" if thinking else "chat" - prompt = self._encode_messages(msgs, thinking_mode, kwargs) - - if tokenize: - return self.tokenizer.encode(prompt, add_special_tokens=False) - return prompt - - -class DeepSeekV32Tokenizer(DeepSeekChatTokenizerBase): - """Tokenizer wrapper for DeepSeek-V3.2's Python-based encoding_dsv32 module.""" - - def _encode_messages(self, msgs, thinking_mode, kwargs): - from lightllm.models.deepseek3_2.encoding_dsv32 import encode_messages - - return encode_messages( + prompt = encode_messages( msgs, thinking_mode=thinking_mode, drop_thinking=kwargs.get("drop_thinking", True), add_default_bos_token=kwargs.get("add_default_bos_token", True), ) + + if tokenize: + return self.tokenizer.encode(prompt, add_special_tokens=False) + return prompt diff --git a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py index 915e45b9c9..914804fe86 100644 --- a/lightllm/models/deepseek_v4/model.py +++ b/lightllm/models/deepseek_v4/model.py @@ -1,3 +1,4 @@ +import copy import importlib.util import os @@ -27,7 +28,6 @@ DeepseekV4TransformerLayerInfer, ) from lightllm.models.deepseek_v4.infer_struct import DeepseekV4InferStateInfo -from lightllm.models.deepseek3_2.model import DeepSeekChatTokenizerBase from lightllm.models.llama.yarn_rotary_utils import ( find_correction_range, linear_ramp_mask, @@ -212,13 +212,22 @@ def build(base, factor, orig_max): return -class DeepSeekV4Tokenizer(DeepSeekChatTokenizerBase): +class DeepSeekV4Tokenizer: """Tokenizer wrapper for DeepSeek-V4's Python prompt encoding.""" def __init__(self, tokenizer, model_dir): - super().__init__(tokenizer) + self.tokenizer = tokenizer self.model_dir = model_dir self._encoding_module = None + self._added_vocab = None + + def __getattr__(self, name): + return 
getattr(self.tokenizer, name) + + def get_added_vocab(self): + if self._added_vocab is None: + self._added_vocab = self.tokenizer.get_added_vocab() + return self._added_vocab def _get_encoding_module(self): if self._encoding_module is not None: @@ -236,12 +245,54 @@ def _get_encoding_module(self): self._encoding_module = module return module - def _encode_messages(self, msgs, thinking_mode, kwargs): + def apply_chat_template( + self, + conversation=None, + messages=None, + tools=None, + tokenize=False, + add_generation_prompt=True, + thinking=None, + enable_thinking=None, + **kwargs, + ): + msgs = conversation if conversation is not None else messages + if msgs is None: + raise ValueError("Either 'conversation' or 'messages' must be provided") + + msgs = copy.deepcopy(msgs) + + if tools: + wrapped_tools = [] + for tool in tools: + if "function" in tool: + wrapped_tools.append(tool) + else: + wrapped_tools.append({"type": "function", "function": tool}) + + injected = False + for msg in msgs: + if msg.get("role") == "system": + existing = msg.get("tools") or [] + msg["tools"] = existing + wrapped_tools + injected = True + break + + if not injected: + msgs.insert(0, {"role": "system", "content": "", "tools": wrapped_tools}) + + if thinking is None: + thinking = bool(enable_thinking) if enable_thinking is not None else False + thinking_mode = "thinking" if thinking else "chat" encoding = self._get_encoding_module() - return encoding.encode_messages( + prompt = encoding.encode_messages( msgs, thinking_mode=thinking_mode, drop_thinking=kwargs.get("drop_thinking", True), add_default_bos_token=kwargs.get("add_default_bos_token", True), reasoning_effort=kwargs.get("reasoning_effort"), ) + + if tokenize: + return self.tokenizer.encode(prompt, add_special_tokens=False) + return prompt From 29c6082a897485f0d4d4cbcddd1ec6b1f2f01d3b Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 8 Jun 2026 04:58:53 +0000 Subject: [PATCH 06/30] add statement --- 
lightllm/models/deepseek_v4/infer_struct.py | 5 ++ .../layer_infer/post_layer_infer.py | 3 +- .../layer_infer/pre_layer_infer.py | 7 +- .../layer_infer/transformer_layer_infer.py | 77 ++++++++++--------- lightllm/models/deepseek_v4/model.py | 24 ++---- 5 files changed, 59 insertions(+), 57 deletions(-) diff --git a/lightllm/models/deepseek_v4/infer_struct.py b/lightllm/models/deepseek_v4/infer_struct.py index 6bc402cd28..d0c2745161 100644 --- a/lightllm/models/deepseek_v4/infer_struct.py +++ b/lightllm/models/deepseek_v4/infer_struct.py @@ -1,8 +1,13 @@ import torch from lightllm.common.basemodel import InferStateInfo +from lightllm.common.req_manager import DeepseekV4ReqManager +from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager class DeepseekV4InferStateInfo(InferStateInfo): + req_manager: DeepseekV4ReqManager + mem_manager: DeepseekV4MemoryManager + """Per-token interleaved-rope cos/sin for the two rope variants (sliding / compressed), following the gemma4 two-variant convention (_cos_cached_* -> position_cos_*). 
Also exposes the full compressed cos/sin tables, which the KV compressor indexes at window positions (not per-token).""" diff --git a/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py index 87951e7360..c23d03afb7 100644 --- a/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py @@ -1,11 +1,12 @@ from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from .hyper_connection import hc_head +from ..infer_struct import DeepseekV4InferStateInfo class DeepseekV4PostLayerInfer(LlamaPostLayerInfer): """Collapse the hc_mult residual streams (hc_head) to [T, hidden], then final norm + lm_head.""" - def token_forward(self, input_embdings, infer_state, layer_weight): + def token_forward(self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight): cfg = layer_weight.network_config_ collapsed = hc_head( input_embdings, diff --git a/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py index 0be99ecbab..d83e3082b8 100644 --- a/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py @@ -2,12 +2,13 @@ import torch.distributed as dist from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.distributed.communication_op import all_reduce +from ..infer_struct import DeepseekV4InferStateInfo class DeepseekV4PreLayerInfer(LlamaPreLayerInfer): """Token embedding, then expand to the hc_mult parallel residual streams [T, hc_mult*hidden].""" - def _embed_and_expand(self, input_ids, infer_state, layer_weight): + def _embed_and_expand(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight): emb = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) # [T, hidden] if self.tp_world_size_ > 1: 
all_reduce(emb, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) @@ -15,8 +16,8 @@ def _embed_and_expand(self, input_ids, infer_state, layer_weight): t, hidden = emb.shape return emb.unsqueeze(1).expand(t, hc_mult, hidden).reshape(t, hc_mult * hidden).contiguous() - def context_forward(self, input_ids, infer_state, layer_weight): + def context_forward(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight): return self._embed_and_expand(input_ids, infer_state, layer_weight) - def token_forward(self, input_ids, infer_state, layer_weight): + def token_forward(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight): return self._embed_and_expand(input_ids, infer_state, layer_weight) diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 864104fd32..11209d39fd 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -7,6 +7,7 @@ from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor from .hyper_connection import hc_pre, hc_post from ..triton_kernel.rotary_emb import apply_rotary_emb +from ..infer_struct import DeepseekV4InferStateInfo from .compressor import compressor_prefill_state, compressor_decode_step, compressor_decode_step_batch from .attention import vllm_sparse_attn_flat @@ -60,7 +61,7 @@ def __init__(self, layer_num, network_config): self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5 # ------------------------------------------------------------------ forward (HC-wrapped) - def _hc_forward(self, streams, infer_state, lw, attn_forward): + def _hc_forward(self, streams, infer_state: DeepseekV4InferStateInfo, lw, attn_forward): residual = streams collapsed, post, comb = hc_pre( streams, @@ -89,25 +90,25 @@ def _hc_forward(self, streams, infer_state, lw, attn_forward): f = 
self._ffn(self._ffn_norm(collapsed, infer_state, lw), infer_state, lw) return hc_post(f, residual, post, comb, self.hc_mult, self.hidden) - def context_forward(self, streams, infer_state, lw): + def context_forward(self, streams, infer_state: DeepseekV4InferStateInfo, lw): return self._hc_forward(streams, infer_state, lw, self.context_attention_forward) - def token_forward(self, streams, infer_state, lw): + def token_forward(self, streams, infer_state: DeepseekV4InferStateInfo, lw): return self._hc_forward(streams, infer_state, lw, self.token_attention_forward) - def _att_norm(self, x, infer_state, lw): + def _att_norm(self, x, infer_state: DeepseekV4InferStateInfo, lw): return lw.attn_norm_(x, eps=self.eps_) - def _ffn_norm(self, x, infer_state, lw): + def _ffn_norm(self, x, infer_state: DeepseekV4InferStateInfo, lw): return lw.ffn_norm_(x, eps=self.eps_) # ------------------------------------------------------------------ shared projections / cache - def _select_rope(self, infer_state): + def _select_rope(self, infer_state: DeepseekV4InferStateInfo): if self.compress_ratio: return infer_state.position_cos_compress, infer_state.position_sin_compress return infer_state.position_cos_sliding, infer_state.position_sin_sliding - def _get_qkv(self, x, infer_state, lw): + def _get_qkv(self, x, infer_state: DeepseekV4InferStateInfo, lw): cos_tok, sin_tok = self._select_rope(infer_state) T = x.shape[0] qa = lw.q_norm_(lw.wq_a_.mm(x), eps=self.eps_) @@ -130,7 +131,7 @@ def _get_qkv(self, x, infer_state, lw): ) return q, kv, qa, cos_tok, sin_tok - def _get_o(self, o, infer_state, lw): + def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, lw): # o: [T, tp_q_heads, head_dim] after inverse rope -> grouped low-rank O -> [T, hidden] T = o.shape[0] o = o.reshape(T, self.tp_groups, -1).transpose(0, 1).contiguous() # [groups, T, per_group_in] @@ -154,7 +155,9 @@ def _inv_rope(self, o, cos_tok, sin_tok): dim=-1, ) - def _post_cache_kv(self, cache_kv, infer_state, lw, 
req_idx=None, start_pos=None, mem_index=None): + def _post_cache_kv( + self, cache_kv, infer_state: DeepseekV4InferStateInfo, lw, req_idx=None, start_pos=None, mem_index=None + ): if req_idx is None or start_pos is None or mem_index is None: raise RuntimeError("DeepSeek-V4 cache write requires req_idx, start_pos, and mem_index") positions = torch.arange( @@ -172,7 +175,7 @@ def _post_cache_kv(self, cache_kv, infer_state, lw, req_idx=None, start_pos=None ) return - def _get_compressor_state(self, infer_state, req): + def _get_compressor_state(self, infer_state: DeepseekV4InferStateInfo, req): cstate_kv, cstate_score = infer_state.req_manager.get_compress_state_for_req(self.layer_num_, req) state = { "cstate_kv": cstate_kv, @@ -184,26 +187,26 @@ def _get_compressor_state(self, infer_state, req): state["idx_cstate_score"] = idx_state[req, 1] return state - def _write_compressed_kv(self, infer_state, req, entry_start, comp): + def _write_compressed_kv(self, infer_state: DeepseekV4InferStateInfo, req, entry_start, comp): slots = infer_state.req_manager.ensure_compress_slots(self.layer_num_, req, entry_start, comp.shape[0]) if comp.shape[0] == 0: return slots infer_state.mem_manager.pack_compressed_kv_to_cache(self.layer_num_, slots, comp) return slots - def _write_c4_indexer_k(self, infer_state, slots, idx_comp): + def _write_c4_indexer_k(self, infer_state: DeepseekV4InferStateInfo, slots, idx_comp): if idx_comp is None or idx_comp.shape[0] == 0: return infer_state.mem_manager.pack_c4_indexer_k_to_cache(self.layer_num_, slots, idx_comp) return - def _dense_kv_from_cache(self, infer_state, req, start_pos, end_pos): + def _dense_kv_from_cache(self, infer_state: DeepseekV4InferStateInfo, req, start_pos, end_pos): if end_pos <= start_pos: return torch.empty((0, self.head_dim), dtype=infer_state.mem_manager.dtype, device="cuda") slots = infer_state.req_manager.req_to_token_indexs[req, start_pos:end_pos].long() return infer_state.mem_manager.gather_mla_kv(self.layer_num_, 
slots) - def _compressed_kv_from_cache(self, infer_state, req, ncomp): + def _compressed_kv_from_cache(self, infer_state: DeepseekV4InferStateInfo, req, ncomp): if ncomp == 0: return torch.empty((0, self.head_dim), dtype=infer_state.mem_manager.dtype, device="cuda") if self.compress_ratio == 4: @@ -212,7 +215,7 @@ def _compressed_kv_from_cache(self, infer_state, req, ncomp): slots = infer_state.req_manager.req_to_c128_indexs[req, :ncomp].long() return infer_state.mem_manager.gather_compressed_kv(self.layer_num_, slots) - def _c4_indexer_k_from_cache(self, infer_state, req, ncomp): + def _c4_indexer_k_from_cache(self, infer_state: DeepseekV4InferStateInfo, req, ncomp): if self.compress_ratio != 4 or ncomp == 0: return None slots = infer_state.req_manager.req_to_c4_indexs[req, :ncomp].long() @@ -236,12 +239,12 @@ def _run_sparse_attention_batch(self, q_chunks, kv_chunks, index_chunks, sink): return vllm_sparse_attn_flat(q_flat, kv_flat, sink, topk, self.softmax_scale) # ------------------------------------------------------------------ attention (prefill) - def context_attention_forward(self, x, infer_state, lw): + def context_attention_forward(self, x, infer_state: DeepseekV4InferStateInfo, lw): q, cache_kv, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, lw) o = self._context_attention_wrapper_run(q, cache_kv, q_lora, x, infer_state, lw) return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw) - def _context_attention_wrapper_run(self, q, cache_kv, q_lora, x, infer_state, lw): + def _context_attention_wrapper_run(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, lw): if torch.cuda.is_current_stream_capturing(): q = q.contiguous() cache_kv = cache_kv.contiguous() @@ -260,7 +263,7 @@ def _context_attention_wrapper_run(self, q, cache_kv, q_lora, x, infer_state, lw o = torch.empty((q.shape[0], self.tp_q_heads, self.head_dim), dtype=q.dtype, device=q.device) _o = tensor_to_no_ref_tensor(o) - def att_func(new_infer_state): + 
def att_func(new_infer_state: DeepseekV4InferStateInfo): tmp_o = self._context_attention_kernel(_q, _cache_kv, _q_lora, _x, new_infer_state, lw) assert tmp_o.shape == _o.shape _o.copy_(tmp_o) @@ -271,7 +274,7 @@ def att_func(new_infer_state): return self._context_attention_kernel(q, cache_kv, q_lora, x, infer_state, lw) - def _context_attention_kernel(self, q, cache_kv, q_lora, x, infer_state, lw): + def _context_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, lw): T = x.shape[0] sink = lw.attn_sink_.weight o = x.new_empty(T, self.tp_q_heads, self.head_dim) @@ -338,7 +341,7 @@ def _context_attention_kernel(self, q, cache_kv, q_lora, x, infer_state, lw): out_offset += ln return o - def _gather_prefill(self, x_r, kv_r, req, ready_len, lw, infer_state): + def _gather_prefill(self, x_r, kv_r, req, ready_len, lw, infer_state: DeepseekV4InferStateInfo): ln = kv_r.shape[0] idx_comp = None if ready_len > 0: @@ -393,7 +396,7 @@ def _gather_prefill(self, x_r, kv_r, req, ready_len, lw, infer_state): return torch.cat([kv_r, comp], dim=0), 0, ln, ncomp, idx_comp return kv_r, 0, ln, 0, None - def _gather_prefill_extend(self, x_r, kv_r, req, ready_len, lw, infer_state): + def _gather_prefill_extend(self, x_r, kv_r, req, ready_len, lw, infer_state: DeepseekV4InferStateInfo): if self.compress_ratio: state = self._get_compressor_state(infer_state, req) cstate_pool = infer_state.req_manager.get_compress_state_pool_for_req(self.layer_num_, req) @@ -484,7 +487,7 @@ def _topk_idxs_prefill( idx_q, idx_comp, idx_weight, - infer_state, + infer_state: DeepseekV4InferStateInfo, ): t = torch.arange(seqlen, device=device) abs_pos = t + base_pos @@ -505,7 +508,7 @@ def _topk_idxs_prefill( return torch.cat([win, comp], dim=1).int().unsqueeze(0) return win.int().unsqueeze(0) - def _decode_dense_kv_graph(self, infer_state): + def _decode_dense_kv_graph(self, infer_state: DeepseekV4InferStateInfo): req = infer_state.b_req_idx.long() seq = 
infer_state.b_seq_len.long() B = req.shape[0] @@ -524,7 +527,7 @@ def _decode_dense_kv_graph(self, infer_state): kv = infer_state.mem_manager.gather_mla_kv_from_swa_slots(self.layer_num_, swa_slots.reshape(-1)) return kv.view(B, self.window, self.head_dim), valid - def _decode_all_compressed_kv_graph(self, infer_state, ratio): + def _decode_all_compressed_kv_graph(self, infer_state: DeepseekV4InferStateInfo, ratio): req = infer_state.b_req_idx.long() seq = infer_state.b_seq_len.long() B = req.shape[0] @@ -550,7 +553,9 @@ def _decode_all_compressed_kv_graph(self, infer_state, ratio): idx_k = idx_k.view(B, max_comp, self.index_head_dim) return kv, idx_k, valid, ncomp - def _decode_c4_topk_graph(self, idx_q, idx_weight, idx_comp, valid_comp, ncomp, infer_state): + def _decode_c4_topk_graph( + self, idx_q, idx_weight, idx_comp, valid_comp, ncomp, infer_state: DeepseekV4InferStateInfo + ): scores = torch.einsum("bhd,bnd->bhn", idx_q.float(), idx_comp.float()) scores = F.relu(scores) * self.indexer_score_scale index_scores = (scores * idx_weight.unsqueeze(-1)).sum(dim=1) @@ -561,7 +566,7 @@ def _decode_c4_topk_graph(self, idx_q, idx_weight, idx_comp, valid_comp, ncomp, valid = top < ncomp.unsqueeze(1) return torch.where(valid, top, torch.zeros_like(top)), valid - def _decode_compressed_candidates_graph(self, idx_q, idx_weight, infer_state): + def _decode_compressed_candidates_graph(self, idx_q, idx_weight, infer_state: DeepseekV4InferStateInfo): if self.compress_ratio == 4: _, idx_comp, valid_all, ncomp = self._decode_all_compressed_kv_graph(infer_state, 4) top, valid = self._decode_c4_topk_graph(idx_q, idx_weight, idx_comp, valid_all, ncomp, infer_state) @@ -574,7 +579,7 @@ def _decode_compressed_candidates_graph(self, idx_q, idx_weight, infer_state): comp, _, valid, _ = self._decode_all_compressed_kv_graph(infer_state, 128) return comp, valid - def _write_decode_compressed_entry_graph(self, x, infer_state, lw, ratio): + def _write_decode_compressed_entry_graph(self, x, 
infer_state: DeepseekV4InferStateInfo, lw, ratio): req = infer_state.b_req_idx start_pos = infer_state.b_seq_len.long() - 1 if ratio == 4: @@ -630,7 +635,7 @@ def _write_decode_compressed_entry_graph(self, x, infer_state, lw, ratio): return # ------------------------------------------------------------------ attention (decode) - def token_attention_forward(self, x, infer_state, lw): + def token_attention_forward(self, x, infer_state: DeepseekV4InferStateInfo, lw): q, cache_kv, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, lw) if infer_state.is_cuda_graph: o = self._token_attention_kernel_cuda_graph(q, cache_kv, q_lora, x, infer_state, lw) @@ -638,7 +643,7 @@ def token_attention_forward(self, x, infer_state, lw): o = self._token_attention_kernel(q, cache_kv, q_lora, x, infer_state, lw) return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw) - def _token_attention_kernel_cuda_graph(self, q, cache_kv, q_lora, x, infer_state, lw): + def _token_attention_kernel_cuda_graph(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, lw): sink = lw.attn_sink_.weight infer_state.mem_manager.pack_decode_mla_kv_to_cache( self.layer_num_, @@ -695,7 +700,7 @@ def _token_attention_kernel_cuda_graph(self, q, cache_kv, q_lora, x, infer_state already_compact=True, ) - def _token_attention_kernel(self, q, cache_kv, q_lora, x, infer_state, lw): + def _token_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, lw): B = x.shape[0] # one new token per request idx_q, idx_weight = self._indexer_q_weight( x, @@ -828,7 +833,9 @@ def _indexer_q_weight(self, x, qa, cos_tok, sin_tok, lw): idx_weight = lw.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale return idx_q, idx_weight - def _indexer_topk(self, idx_q, idx_comp, idx_weight, positions_1based, offset, infer_state): + def _indexer_topk( + self, idx_q, idx_comp, idx_weight, positions_1based, offset, infer_state: DeepseekV4InferStateInfo + ): ncomp = 
idx_comp.shape[0] k = min(self.index_topk, ncomp) if k == 0: @@ -896,7 +903,7 @@ def _topk_idxs_decode( idx_weight, seq_len, device, - infer_state, + infer_state: DeepseekV4InferStateInfo, ): win = torch.arange(win_len, device=device, dtype=torch.long) if comp_kv is None or comp_kv.shape[0] == 0: @@ -946,7 +953,7 @@ def _fp4_experts_marlin(self, x, weights, indices, experts): clamp_limit=float(self.swiglu_limit), ) - def _ffn(self, x, infer_state, lw): + def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, lw): gw = lw.gate_weight_.mm_param.weight logits = F.linear(x.float(), gw.float()).contiguous() weights, indices = self._select_experts(logits, infer_state, lw) @@ -968,10 +975,10 @@ def _ffn(self, x, infer_state, lw): all_reduce(out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return out - def _select_experts(self, logits, infer_state, lw): + def _select_experts(self, logits, infer_state: DeepseekV4InferStateInfo, lw): return self._select_experts_vllm(logits, infer_state, lw) - def _select_experts_vllm(self, logits, infer_state, lw): + def _select_experts_vllm(self, logits, infer_state: DeepseekV4InferStateInfo, lw): from vllm import _custom_ops as ops M = logits.shape[0] diff --git a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py index 914804fe86..687d5f46f0 100644 --- a/lightllm/models/deepseek_v4/model.py +++ b/lightllm/models/deepseek_v4/model.py @@ -47,10 +47,10 @@ class DeepseekV4DirectSparseAttBackend(BaseAttBackend): `infer_state.prefill_att_state.prefill_att()` / `decode_att()` backend selector. 
""" - def create_att_prefill_state(self, infer_state): + def create_att_prefill_state(self, infer_state: DeepseekV4InferStateInfo): return DeepseekV4DirectSparsePrefillAttState(backend=self, infer_state=infer_state) - def create_att_decode_state(self, infer_state): + def create_att_decode_state(self, infer_state: DeepseekV4InferStateInfo): return DeepseekV4DirectSparseDecodeAttState(backend=self, infer_state=infer_state) @@ -72,6 +72,9 @@ def decode_att(self, *args, **kwargs): @ModelRegistry("deepseek_v4") class DeepseekV4TpPartModel(LlamaTpPartModel): + req_manager: DeepseekV4ReqManager + mem_manager: DeepseekV4MemoryManager + pre_and_post_weight_class = DeepseekV4PreAndPostLayerWeight transformer_weight_class = DeepseekV4TransformerLayerWeight @@ -80,7 +83,6 @@ class DeepseekV4TpPartModel(LlamaTpPartModel): transformer_layer_infer_class = DeepseekV4TransformerLayerInfer infer_state_class = DeepseekV4InferStateInfo - _logged_prefill_graph_prefix_skip = False def _verify_params(self): assert self.load_way == "HF", "only support HF format weights" @@ -89,11 +91,6 @@ def _verify_params(self): assert self.config["index_n_heads"] % self.tp_world_size_ == 0 return - def _init_some_value(self): - super()._init_some_value() - self.head_dim_ = self.config["head_dim"] - return - def _init_req_manager(self): create_max_seq_len = 0 if self.batch_max_tokens is not None: @@ -115,9 +112,6 @@ def _init_req_manager(self): def _get_compress_rates(self, layer_num): rates = list(self.config["compress_ratios"]) - assert ( - len(rates) >= layer_num - ), f"DeepSeek-V4 compress_ratios length {len(rates)} is shorter than layer_num {layer_num}" return rates[:layer_num] def _init_mem_manager(self): @@ -150,15 +144,9 @@ def _init_cudagraph(self): self.graph_max_len_in_batch = DSV4_DECODE_CUDAGRAPH_MAX_LEN return super()._init_cudagraph() - def _can_run_prefill_cudagraph(self, infer_state, handle_token_num): + def _can_run_prefill_cudagraph(self, infer_state: DeepseekV4InferStateInfo, 
handle_token_num): if infer_state.prefix_total_token_num == 0: return True - if not self._logged_prefill_graph_prefix_skip: - logger.info( - "DeepSeek-V4 skips prefill cudagraph for prompt-cache extension batches; " - "no-prefix prefill batches still use prefill cudagraph." - ) - self._logged_prefill_graph_prefix_skip = True return False def _init_att_backend(self): From ffafdbf1ed8eec2530f8e541d5978fbc0161d952 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 8 Jun 2026 05:05:10 +0000 Subject: [PATCH 07/30] format --- lightllm/common/req_manager.py | 37 ++++++++-------------------------- 1 file changed, 8 insertions(+), 29 deletions(-) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index ca027a63c8..7b56129c3f 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -163,13 +163,11 @@ def __init__(self, max_request_num): ) elif self.penalty_counter_mode == "pin_mem_counter": self.req_to_out_token_id_counter = torch.zeros( - (max_request_num + 1, self.vocab_size), - dtype=torch.int32, - device="cpu", - pin_memory=True, + (max_request_num + 1, self.vocab_size), dtype=torch.int32, device="cpu", pin_memory=True ) def init_req_sampling_params(self, req: "InferReq"): + shm_param = req.sampling_param.shm_param self.req_to_next_token_ids[req.req_idx][0:1].fill_(req.get_last_gen_token()) self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty) @@ -199,18 +197,14 @@ def init_req_sampling_params(self, req: "InferReq"): dtype=torch.int32, ).cuda(non_blocking=True) token_id_counter( - prompt_ids=prompt_ids, - out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx], + prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx] ) torch.cuda.current_stream().synchronize() return def update_reqs_out_token_counter_gpu( - self, - b_req_idx: torch.Tensor, - next_token_ids: torch.Tensor, - mask: torch.Tensor = None, + self, b_req_idx: torch.Tensor, 
next_token_ids: torch.Tensor, mask: torch.Tensor = None ): if self.penalty_counter_mode not in ["gpu_counter", "pin_mem_counter"]: return @@ -226,10 +220,7 @@ def update_reqs_out_token_counter_gpu( return def update_reqs_token_counter( - self, - req_objs: List["InferReq"], - next_token_ids: List[int], - accept_mark: Optional[List[List[bool]]] = None, + self, req_objs: List["InferReq"], next_token_ids: List[int], accept_mark: Optional[List[List[bool]]] = None ): if self.penalty_counter_mode != "cpu_counter": return @@ -271,13 +262,7 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List["InferReq"]): class ReqManagerForMamba(ReqManager): - def __init__( - self, - max_request_num, - max_sequence_length, - mem_manager, - linear_config: LinearAttCacheConfig, - ): + def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_config: LinearAttCacheConfig): super().__init__(max_request_num, max_sequence_length, mem_manager) self.mtp_step = get_env_start_args().mtp_step self.big_page_token_num = ( @@ -322,6 +307,7 @@ def get_mamba_cache(self, layer_idx_in_all: int): return conv_states, ssm_states def copy_big_page_buffer_to_linear_att_state(self, big_page_buffer_idx: int, req: "InferReq"): + from .linear_att_cache_manager import LinearAttCacheManager big_page_buffers: LinearAttCacheManager = self.mem_manager.linear_att_big_page_buffers @@ -375,15 +361,8 @@ def __init__( indexer_head_dim: Optional[int] = None, ): super().__init__(max_request_num, max_sequence_length, mem_manager) - if mem_manager is not None: - assert isinstance(mem_manager, DeepseekV4MemoryManager) - compress_rates = mem_manager.compress_rates - head_dim = mem_manager.head_dim - indexer_head_dim = mem_manager.indexer_head_dim - assert compress_rates is not None, "DeepSeek-V4 req manager requires compress_rates" - assert head_dim is not None, "DeepSeek-V4 req manager requires head_dim" - assert indexer_head_dim is not None, "DeepSeek-V4 req manager requires indexer_head_dim" 
+ self.mem_manager = mem_manager self.compress_rates = list(compress_rates) self.n_c4 = sum(1 for r in self.compress_rates if r == 4) self.n_c128 = sum(1 for r in self.compress_rates if r == 128) From e8009cb3e053ffe7dbe465c027e4fe6a676181c8 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 11 Jun 2026 02:15:12 +0000 Subject: [PATCH 08/30] pass gsm8k but need review --- .../attention/nsa/fp8_flashmla_sparse.py | 401 +++++- lightllm/common/basemodel/basemodel.py | 80 +- .../fused_moe/fused_moe_weight.py | 36 +- .../meta_weights/fused_moe/impl/__init__.py | 6 + .../meta_weights/fused_moe/impl/mxfp4_impl.py | 44 + .../fused_moe/impl/triton_impl.py | 19 + .../fused_moe/grouped_fused_moe_ep.py | 2 +- .../deepseek4_mem_manager.py | 1218 +++++++--------- lightllm/common/quantization/__init__.py | 9 +- lightllm/common/quantization/deepgemm.py | 78 + lightllm/common/req_manager.py | 693 ++++----- lightllm/models/deepseek_v4/infer_struct.py | 8 +- .../deepseek_v4/layer_infer/attention.py | 149 -- .../deepseek_v4/layer_infer/compressor.py | 645 ++++----- .../layer_infer/hyper_connection.py | 75 +- .../layer_infer/post_layer_infer.py | 8 +- .../layer_infer/pre_layer_infer.py | 19 +- .../layer_infer/transformer_layer_infer.py | 1251 ++++++----------- .../layer_weights/transformer_layer_weight.py | 331 +---- lightllm/models/deepseek_v4/mem_manager.py | 12 - lightllm/models/deepseek_v4/model.py | 86 +- .../destindex_copy_indexer_k_dsv4.py | 92 ++ .../destindex_copy_kv_flashmla_dsv4.py | 121 ++ .../triton_kernel/quant_convert.py | 77 - lightllm/server/api_cli.py | 7 +- .../router/dynamic_prompt/radix_cache.py | 57 + .../server/router/model_infer/infer_batch.py | 167 +-- .../model_infer/mode_backend/base_backend.py | 29 +- 28 files changed, 2591 insertions(+), 3129 deletions(-) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py delete mode 100644 lightllm/models/deepseek_v4/layer_infer/attention.py delete mode 
100644 lightllm/models/deepseek_v4/mem_manager.py create mode 100644 lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py create mode 100644 lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py index 539ade769e..0570adea83 100644 --- a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py @@ -1,4 +1,5 @@ import dataclasses +import inspect import torch from typing import TYPE_CHECKING, Tuple @@ -9,6 +10,245 @@ from lightllm.common.basemodel.infer_struct import InferStateInfo +FLASHMLA_INDEX_ALIGN = 64 +# this flash_mla extra-cache fork only instantiates h_q in {64, 128}; pad TP-split q heads up +# to the nearest supported count (zero heads are discarded from the output slice). +FLASHMLA_SUPPORTED_HEADS = (64, 128) + + +def _pad_q_heads(q_4d: torch.Tensor, attn_sink: torch.Tensor): + h_q = q_4d.shape[2] + if h_q in FLASHMLA_SUPPORTED_HEADS: + return q_4d, attn_sink, h_q + target = next((h for h in FLASHMLA_SUPPORTED_HEADS if h >= h_q), None) + assert target is not None, f"num q heads {h_q} exceeds flash_mla support {FLASHMLA_SUPPORTED_HEADS}" + q_pad = torch.nn.functional.pad(q_4d, (0, 0, 0, target - h_q)) + sink_pad = torch.nn.functional.pad(attn_sink, (0, target - h_q)) + return q_pad, sink_pad, h_q + + +class DeepseekV4MissingOperatorError(RuntimeError): + pass + + +def _missing_attention_op(feature: str) -> None: + raise DeepseekV4MissingOperatorError( + f"DeepSeek-V4 {feature} has no production batch operator. The flashmla_kvcache path " + f"(packed swa/c4/c128 pools + paged compressor + indexer top-k) is the supported route; " + f"this legacy/non-flashmla entry point was never wired and is fenced on purpose." 
+ ) + + +def _pad_last_dim(x: torch.Tensor, multiple: int = FLASHMLA_INDEX_ALIGN, value: int = -1) -> torch.Tensor: + pad = (-x.shape[-1]) % multiple + if pad == 0: + return x.contiguous() + out = torch.full((*x.shape[:-1], x.shape[-1] + pad), value, dtype=x.dtype, device=x.device) + out[..., : x.shape[-1]] = x + return out.contiguous() + + +def _view_dsv4_flashmla_cache(layer_buffer: torch.Tensor, page_size: int) -> torch.Tensor: + from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_MLA_BYTES_PER_TOKEN + + usable = page_size * DSV4_MLA_BYTES_PER_TOKEN + return layer_buffer[:, :usable].view(layer_buffer.shape[0], page_size, 1, DSV4_MLA_BYTES_PER_TOKEN) + + +def _load_flash_mla_with_extra(): + try: + import flash_mla + except Exception as exc: + raise DeepseekV4MissingOperatorError( + "DeepSeek-V4 packed FlashMLA requires the flash_mla package with compiled CUDA extension. " + f"Import failed with: {type(exc).__name__}: {exc}" + ) from exc + + fn = getattr(flash_mla, "flash_mla_with_kvcache", None) + get_mla_metadata = getattr(flash_mla, "get_mla_metadata", None) + missing_symbols = [] + if fn is None: + missing_symbols.append("flash_mla_with_kvcache") + if get_mla_metadata is None: + missing_symbols.append("get_mla_metadata") + if missing_symbols: + raise DeepseekV4MissingOperatorError( + "DeepSeek-V4 requires flash_mla.flash_mla_with_kvcache extra-cache wrapper. " + f"Current module={getattr(flash_mla, '__file__', '')} " + f"is missing symbols {missing_symbols}." + ) + + sig = inspect.signature(fn) + required = { + "attn_sink", + "extra_k_cache", + "extra_indices_in_kvcache", + "topk_length", + "extra_topk_length", + } + missing = sorted(required.difference(sig.parameters)) + if missing: + raise DeepseekV4MissingOperatorError( + "DeepSeek-V4 requires flash_mla.flash_mla_with_kvcache with extra-cache arguments. " + f"Current module={getattr(flash_mla, '__file__', '')} is missing {missing}." 
+ ) + return flash_mla + + +def _build_dsv4_repeated_prefill_reqs(infer_state) -> torch.Tensor: + return torch.repeat_interleave(infer_state.b_req_idx, infer_state.b_q_seq_len.long()) + + +def _build_dsv4_prefill_positions(infer_state) -> torch.Tensor: + total = infer_state.total_token_num - infer_state.prefix_total_token_num + token_offsets = torch.arange(total, dtype=torch.int32, device=infer_state.b_q_seq_len.device) + req_ids = torch.repeat_interleave( + torch.arange(infer_state.batch_size, dtype=torch.long, device=infer_state.b_q_seq_len.device), + infer_state.b_q_seq_len.long(), + ) + local_offsets = token_offsets - infer_state.b_q_start_loc[req_ids] + return infer_state.b_ready_cache_len[req_ids] + local_offsets + + +def _build_dsv4_swa_indices( + req_manager, + mem_manager, + req_idx: torch.Tensor, + positions: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + window = int(mem_manager.sliding_window) + offsets = positions[:, None] - torch.arange(window, dtype=positions.dtype, device=positions.device)[None, :] + valid_pos = offsets >= 0 + safe_offsets = offsets.clamp_min(0).long() + full_slots = req_manager.req_to_token_indexs[req_idx.long()[:, None], safe_offsets] + swa_slots = mem_manager.full_to_swa_indexs[full_slots.long()].to(torch.int32) + indices = torch.where(valid_pos, swa_slots, torch.full_like(swa_slots, -1)) + lengths = torch.clamp(positions + 1, min=1, max=window).to(torch.int32) + return _pad_last_dim(indices.to(torch.int32)).unsqueeze(1), lengths.contiguous() + + +def _gather_dsv4_compress_slots( + infer_state, + mapping: torch.Tensor, + req_idx: torch.Tensor, + valid: torch.Tensor, + offsets: torch.Tensor, + ratio: int, +) -> torch.Tensor: + """条目 g 的压缩槽 = full_to_c*[req_to_token[req, (g+1)*ratio-1]](组末 token 的 full 槽位)。 + 无效条目(超出因果长度/HOLD 行)用位置 0 安全 gather 后由调用方按 valid 掩掉。""" + end_pos = offsets[None, :] * ratio + (ratio - 1) + safe_pos = torch.where(valid, end_pos, torch.zeros_like(end_pos)) + full_slots = 
infer_state.req_manager.req_to_token_indexs[req_idx.long()[:, None], safe_pos] + return mapping[full_slots.long()].to(torch.int32) + + +def _build_dsv4_c128_indices( + infer_state, + req_idx: torch.Tensor, + positions: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + raw_lengths = (positions + 1) // 128 + lengths = torch.clamp(raw_lengths, min=1).to(torch.int32) + max_len = max(1, int(infer_state.max_kv_seq_len) // 128) + offsets = torch.arange(max_len, dtype=torch.long, device=positions.device) + valid = offsets[None, :] < raw_lengths[:, None] + slots = _gather_dsv4_compress_slots( + infer_state, infer_state.mem_manager.full_to_c128_indexs, req_idx, valid, offsets, 128 + ) + indices = torch.where(valid, slots, torch.full_like(slots, -1)) + return _pad_last_dim(indices).unsqueeze(1), lengths.contiguous() + + +def _build_dsv4_c4_indices( + infer_state, + layer_index: int, + req_idx: torch.Tensor, + positions: torch.Tensor, + nsa_dict: dict, +) -> Tuple[torch.Tensor, torch.Tensor]: + """c4(CSA) extra indices: causal all-entries when the entry space fits index_topk, + otherwise Lightning-Indexer scored top-k. 
Pure tensor ops (decode runs inside cuda graphs).""" + import torch.distributed as dist + import torch.nn.functional as F + from lightllm.distributed.communication_op import all_reduce + + mem_manager = infer_state.mem_manager + raw_lengths = (positions + 1) // 4 + max_entries = max(1, int(infer_state.max_kv_seq_len) // 4) + index_topk = int(nsa_dict["index_topk"]) + offsets = torch.arange(max_entries, dtype=torch.long, device=positions.device) + valid = offsets[None, :] < raw_lengths[:, None] + slots = _gather_dsv4_compress_slots(infer_state, mem_manager.full_to_c4_indexs, req_idx, valid, offsets, 4) + + if max_entries <= index_topk: + lengths = torch.clamp(raw_lengths, min=1).to(torch.int32) + indices = torch.where(valid, slots, torch.full_like(slots, -1)) + return _pad_last_dim(indices).unsqueeze(1), lengths.contiguous() + + idx_q = nsa_dict["idx_q"] # [T, H, index_head_dim], rope applied + idx_weight = nsa_dict["idx_weight"] # [T, H] fp32, weight scale applied + score_scale = float(nsa_dict["indexer_score_scale"]) + hold_slot = mem_manager.c4_indexer_pool.HOLD_TOKEN_MEMINDEX + safe_slots = torch.where(valid, slots.long(), torch.full_like(slots.long(), hold_slot)) + k = mem_manager.gather_indexer_k(layer_index, safe_slots.reshape(-1)).view(positions.shape[0], max_entries, -1) + + num_tokens, num_heads = idx_q.shape[0], idx_q.shape[1] + score_chunks = [] + chunk = max(1, min(num_tokens, (16 * 1024 * 1024) // max(1, num_heads * max_entries))) + for start in range(0, num_tokens, chunk): + end = min(num_tokens, start + chunk) + scores = torch.einsum("thd,tnd->thn", idx_q[start:end].float(), k[start:end].float()) + scores = F.relu(scores) * score_scale + score_chunks.append((scores * idx_weight[start:end].unsqueeze(-1)).sum(dim=1)) + index_scores = torch.cat(score_chunks, dim=0) + if int(nsa_dict.get("tp_world_size", 1)) > 1: + all_reduce(index_scores, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + index_scores = 
index_scores.masked_fill(~valid, float("-inf")) + top = index_scores.topk(index_topk, dim=-1).indices + top_valid = torch.gather(valid, 1, top) + top_slots = torch.gather(slots.long(), 1, top).to(torch.int32) + indices = torch.where(top_valid, top_slots, torch.full_like(top_slots, -1)) + lengths = torch.clamp(torch.minimum(raw_lengths, torch.full_like(raw_lengths, index_topk)), min=1) + return _pad_last_dim(indices).unsqueeze(1), lengths.to(torch.int32).contiguous() + + +def _build_dsv4_extra_metadata( + infer_state, + layer_index: int, + compress_ratio: int, + req_idx: torch.Tensor, + positions: torch.Tensor, + swa_indices: torch.Tensor, + swa_lengths: torch.Tensor, + nsa_dict: dict, +) -> "_Dsv4Metadata": + from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_C128_PAGE_SIZE, DSV4_C4_PAGE_SIZE + + if compress_ratio == 0: + return _Dsv4Metadata(swa_indices, swa_lengths) + if compress_ratio == 4: + extra_indices, extra_lengths = _build_dsv4_c4_indices(infer_state, layer_index, req_idx, positions, nsa_dict) + extra_buffer = infer_state.mem_manager.get_compressed_kv_buffer(layer_index) + extra_cache = _view_dsv4_flashmla_cache(extra_buffer, DSV4_C4_PAGE_SIZE) + return _Dsv4Metadata(swa_indices, swa_lengths, extra_cache, extra_indices, extra_lengths) + if compress_ratio == 128: + extra_indices, extra_lengths = _build_dsv4_c128_indices(infer_state, req_idx, positions) + extra_buffer = infer_state.mem_manager.get_compressed_kv_buffer(layer_index) + extra_cache = _view_dsv4_flashmla_cache(extra_buffer, DSV4_C128_PAGE_SIZE) + return _Dsv4Metadata(swa_indices, swa_lengths, extra_cache, extra_indices, extra_lengths) + raise AssertionError(f"invalid DeepSeek-V4 compress ratio {compress_ratio}") + + +@dataclasses.dataclass +class _Dsv4Metadata: + swa_indices: torch.Tensor + swa_lengths: torch.Tensor + extra_cache: torch.Tensor = None + extra_indices: torch.Tensor = None + extra_lengths: torch.Tensor = None + + class 
NsaFlashMlaFp8SparseAttBackend(BaseAttBackend): def __init__(self, model): super().__init__(model=model) @@ -17,6 +257,7 @@ def __init__(self, model): torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device) for _ in range(2) ] + self._flash_mla = None def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState": return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state) @@ -24,6 +265,11 @@ def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMl def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState": return NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state) + def flash_mla(self): + if self._flash_mla is None: + self._flash_mla = _load_flash_mla_with_extra() + return self._flash_mla + @dataclasses.dataclass class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState): @@ -62,6 +308,12 @@ def prefill_att( ) -> torch.Tensor: assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" + if att_control.nsa_prefill_dict.get("flashmla_kvcache"): + return self._flashmla_kvcache_prefill_att( + q=q, + packed_kv=k, + nsa_dict=att_control.nsa_prefill_dict, + ) return self._nsa_prefill_att(q=q, packed_kv=k, att_control=att_control) def _nsa_prefill_att( @@ -78,6 +330,8 @@ def _nsa_prefill_att( kv_lora_rank = nsa_dict["kv_lora_rank"] topk_mem_indices = nsa_dict["topk_mem_indices"] prefill_cache_kv = nsa_dict["prefill_cache_kv"] + attn_sink = nsa_dict.get("attn_sink") + topk_length = nsa_dict.get("topk_length") if self.infer_state.prefix_total_token_num > 0: # 当前推理生成的token kv部分从 prefill_cache_kv 中获取,历史 @@ -101,9 +355,72 @@ def _nsa_prefill_att( indices=topk_indices, sm_scale=softmax_scale, d_v=kv_lora_rank, + attn_sink=attn_sink, + topk_length=topk_length, ) return 
mla_out + def _build_flashmla_kvcache_prefill_metadata(self, nsa_dict: dict) -> _Dsv4Metadata: + infer_state = self.infer_state + req_idx = _build_dsv4_repeated_prefill_reqs(infer_state) + positions = _build_dsv4_prefill_positions(infer_state) + swa_indices, swa_lengths = _build_dsv4_swa_indices( + infer_state.req_manager, + infer_state.mem_manager, + req_idx, + positions, + ) + return _build_dsv4_extra_metadata( + infer_state, + nsa_dict["layer_index"], + nsa_dict["compress_ratio"], + req_idx, + positions, + swa_indices, + swa_lengths, + nsa_dict, + ) + + def _flashmla_kvcache_prefill_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor: + attn_sink = nsa_dict["attn_sink"].to(torch.float32).contiguous() + metadata = self._build_flashmla_kvcache_prefill_metadata(nsa_dict) + return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict) + + def _flashmla_kvcache_att( + self, + q: torch.Tensor, + packed_kv: torch.Tensor, + metadata: _Dsv4Metadata, + attn_sink: torch.Tensor, + nsa_dict: dict, + ) -> torch.Tensor: + flash_mla = self.backend.flash_mla() + from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE + + q_4d = q.unsqueeze(1).contiguous() + q_4d, attn_sink, num_real_heads = _pad_q_heads(q_4d, attn_sink) + k_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE) + sched_meta, _ = flash_mla.get_mla_metadata() + out, _ = flash_mla.flash_mla_with_kvcache( + q=q_4d, + k_cache=k_cache, + block_table=None, + cache_seqlens=None, + head_dim_v=nsa_dict["head_dim_v"], + tile_scheduler_metadata=sched_meta, + num_splits=None, + softmax_scale=nsa_dict["softmax_scale"], + causal=False, + is_fp8_kvcache=True, + indices=metadata.swa_indices, + attn_sink=attn_sink, + topk_length=metadata.swa_lengths, + extra_k_cache=metadata.extra_cache, + extra_indices_in_kvcache=metadata.extra_indices, + extra_topk_length=metadata.extra_lengths, + ) + return out[:, 0, :num_real_heads].contiguous() + 
@dataclasses.dataclass class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState): @@ -141,9 +458,10 @@ def init_state(self): ragged_mem_index=self.ragged_mem_index, hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, ) - import flash_mla - - self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata() + flash_mla = self.backend.flash_mla() + # one sched_meta per layer type: the lazy config locks extra-cache geometry (page size, + # presence) on first invocation, so swa-only/c4/c128 layers must not share one object. + self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)} return def decode_att( @@ -156,6 +474,12 @@ def decode_att( ) -> torch.Tensor: assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" + if att_control.nsa_decode_dict.get("flashmla_kvcache"): + return self._flashmla_kvcache_decode_att( + q=q, + packed_kv=k, + nsa_dict=att_control.nsa_decode_dict, + ) return self._nsa_decode_att(q=q, packed_kv=k, att_control=att_control) def _nsa_decode_att( @@ -170,6 +494,11 @@ def _nsa_decode_att( topk_mem_indices = nsa_dict["topk_mem_indices"] softmax_scale = nsa_dict["softmax_scale"] kv_lora_rank = nsa_dict["kv_lora_rank"] + attn_sink = nsa_dict.get("attn_sink") + topk_length = nsa_dict.get("topk_length") + extra_k_cache = nsa_dict.get("extra_k_cache") + extra_indices = nsa_dict.get("extra_indices_in_kvcache") + extra_topk_length = nsa_dict.get("extra_topk_length") if topk_mem_indices.ndim == 2: topk_mem_indices = topk_mem_indices.unsqueeze(1) @@ -189,10 +518,74 @@ def _nsa_decode_att( block_table=None, cache_seqlens=None, head_dim_v=kv_lora_rank, - tile_scheduler_metadata=self.flashmla_sched_meta, + tile_scheduler_metadata=self.flashmla_sched_meta[0], softmax_scale=softmax_scale, causal=False, is_fp8_kvcache=True, indices=topk_mem_indices, + attn_sink=attn_sink, + topk_length=topk_length, + 
extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices, + extra_topk_length=extra_topk_length, ) return o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d] + + def _build_flashmla_kvcache_decode_metadata(self, nsa_dict: dict) -> _Dsv4Metadata: + infer_state = self.infer_state + positions = infer_state.b_seq_len.to(torch.int32) - 1 + swa_indices, swa_lengths = _build_dsv4_swa_indices( + infer_state.req_manager, + infer_state.mem_manager, + infer_state.b_req_idx, + positions, + ) + return _build_dsv4_extra_metadata( + infer_state, + nsa_dict["layer_index"], + nsa_dict["compress_ratio"], + infer_state.b_req_idx, + positions, + swa_indices, + swa_lengths, + nsa_dict, + ) + + def _flashmla_kvcache_decode_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor: + attn_sink = nsa_dict["attn_sink"].to(torch.float32).contiguous() + metadata = self._build_flashmla_kvcache_decode_metadata(nsa_dict) + return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict) + + def _flashmla_kvcache_att( + self, + q: torch.Tensor, + packed_kv: torch.Tensor, + metadata: _Dsv4Metadata, + attn_sink: torch.Tensor, + nsa_dict: dict, + ) -> torch.Tensor: + flash_mla = self.backend.flash_mla() + from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE + + q_4d = q.unsqueeze(1).contiguous() + q_4d, attn_sink, num_real_heads = _pad_q_heads(q_4d, attn_sink) + k_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE) + out, _ = flash_mla.flash_mla_with_kvcache( + q=q_4d, + k_cache=k_cache, + block_table=None, + cache_seqlens=None, + head_dim_v=nsa_dict["head_dim_v"], + tile_scheduler_metadata=self.flashmla_sched_meta[nsa_dict["compress_ratio"]], + num_splits=None, + softmax_scale=nsa_dict["softmax_scale"], + causal=False, + is_fp8_kvcache=True, + indices=metadata.swa_indices, + attn_sink=attn_sink, + topk_length=metadata.swa_lengths, + extra_k_cache=metadata.extra_cache, + 
extra_indices_in_kvcache=metadata.extra_indices, + extra_topk_length=metadata.extra_lengths, + ) + return out[:, 0, :num_real_heads].contiguous() diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 8e352519c0..986802e760 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -527,13 +527,17 @@ def _prefill( alloc_mem_index=infer_state.mem_index, max_q_seq_len=infer_state.max_q_seq_len, ) - if hasattr(self.mem_manager, "prepare_prefill_swa_slots"): - self.mem_manager.prepare_prefill_swa_slots( + if hasattr(self.req_manager, "prepare_prefill_swa"): + self.req_manager.prepare_prefill_swa( b_req_idx=infer_state.b_req_idx, + b_ready_cache_len=infer_state.b_ready_cache_len, b_seq_len=infer_state.b_seq_len, + ) + if hasattr(self.req_manager, "prepare_prefill_compress_slots"): + self.req_manager.prepare_prefill_compress_slots( + b_req_idx=infer_state.b_req_idx, b_ready_cache_len=infer_state.b_ready_cache_len, - b_start_loc=model_input.b_prefill_start_loc, - mem_index=infer_state.mem_index, + b_seq_len=infer_state.b_seq_len, ) prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() @@ -541,8 +545,6 @@ def _prefill( infer_state.init_some_extra_state(self) infer_state.init_att_state() model_output = self._context_forward(infer_state) - if hasattr(self.mem_manager, "commit_prefill_swa_slots"): - self.mem_manager.commit_prefill_swa_slots() model_output = self._create_unpad_prefill_model_output( padded_model_output=model_output, @@ -577,12 +579,14 @@ def _decode( model_input = self._create_padded_decode_model_input( model_input=model_input, new_batch_size=infer_batch_size ) - if hasattr(self.mem_manager, "prepare_decode_swa_slots"): - self.mem_manager.prepare_decode_swa_slots( + if hasattr(self.req_manager, "prepare_decode_swa"): + self.req_manager.prepare_decode_swa( model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes ) if 
hasattr(self.req_manager, "prepare_decode_compress_slots"): - self.req_manager.prepare_decode_compress_slots(model_input.b_req_idx, model_input.b_seq_len) + self.req_manager.prepare_decode_compress_slots( + model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes + ) infer_state = self._create_inferstate(model_input) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -604,12 +608,14 @@ def _decode( model_input = self._create_padded_decode_model_input( model_input=model_input, new_batch_size=infer_batch_size ) - if hasattr(self.mem_manager, "prepare_decode_swa_slots"): - self.mem_manager.prepare_decode_swa_slots( + if hasattr(self.req_manager, "prepare_decode_swa"): + self.req_manager.prepare_decode_swa( model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes ) if hasattr(self.req_manager, "prepare_decode_compress_slots"): - self.req_manager.prepare_decode_compress_slots(model_input.b_req_idx, model_input.b_seq_len) + self.req_manager.prepare_decode_compress_slots( + model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes + ) infer_state = self._create_inferstate(model_input) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -775,13 +781,17 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod alloc_mem_index=infer_state0.mem_index, max_q_seq_len=infer_state0.max_q_seq_len, ) - if hasattr(self.mem_manager, "prepare_prefill_swa_slots"): - self.mem_manager.prepare_prefill_swa_slots( + if hasattr(self.req_manager, "prepare_prefill_swa"): + self.req_manager.prepare_prefill_swa( b_req_idx=infer_state0.b_req_idx, + b_ready_cache_len=infer_state0.b_ready_cache_len, b_seq_len=infer_state0.b_seq_len, + ) + if hasattr(self.req_manager, "prepare_prefill_compress_slots"): + self.req_manager.prepare_prefill_compress_slots( + b_req_idx=infer_state0.b_req_idx, b_ready_cache_len=infer_state0.b_ready_cache_len, - b_start_loc=model_input0.b_prefill_start_loc, - 
mem_index=infer_state0.mem_index, + b_seq_len=infer_state0.b_seq_len, ) infer_state0.init_some_extra_state(self) infer_state0.init_att_state() @@ -796,13 +806,17 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod alloc_mem_index=infer_state1.mem_index, max_q_seq_len=infer_state1.max_q_seq_len, ) - if hasattr(self.mem_manager, "prepare_prefill_swa_slots"): - self.mem_manager.prepare_prefill_swa_slots( + if hasattr(self.req_manager, "prepare_prefill_swa"): + self.req_manager.prepare_prefill_swa( b_req_idx=infer_state1.b_req_idx, + b_ready_cache_len=infer_state1.b_ready_cache_len, b_seq_len=infer_state1.b_seq_len, + ) + if hasattr(self.req_manager, "prepare_prefill_compress_slots"): + self.req_manager.prepare_prefill_compress_slots( + b_req_idx=infer_state1.b_req_idx, b_ready_cache_len=infer_state1.b_ready_cache_len, - b_start_loc=model_input1.b_prefill_start_loc, - mem_index=infer_state1.mem_index, + b_seq_len=infer_state1.b_seq_len, ) infer_state1.init_some_extra_state(self) infer_state1.init_att_state() @@ -811,8 +825,6 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod prefill_mem_indexes_ready_event.record() model_output0, model_output1 = self._overlap_tpsp_context_forward(infer_state0, infer_state1=infer_state1) - if hasattr(self.mem_manager, "commit_prefill_swa_slots"): - self.mem_manager.commit_prefill_swa_slots() model_output0 = self._create_unpad_prefill_model_output( padded_model_output=model_output0, @@ -864,19 +876,19 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode # 一致,需要按照较高 batch size 进行graph的寻找,同时,进行有效的恢复。 padded_model_input0 = self._create_padded_decode_model_input(model_input0, infer_batch_size) padded_model_input1 = self._create_padded_decode_model_input(model_input1, infer_batch_size) - if hasattr(self.mem_manager, "prepare_decode_swa_slots"): - self.mem_manager.prepare_decode_swa_slots( + if hasattr(self.req_manager, "prepare_decode_swa"): + 
self.req_manager.prepare_decode_swa( padded_model_input0.b_req_idx, padded_model_input0.b_seq_len, padded_model_input0.mem_indexes ) - self.mem_manager.prepare_decode_swa_slots( + self.req_manager.prepare_decode_swa( padded_model_input1.b_req_idx, padded_model_input1.b_seq_len, padded_model_input1.mem_indexes ) if hasattr(self.req_manager, "prepare_decode_compress_slots"): self.req_manager.prepare_decode_compress_slots( - padded_model_input0.b_req_idx, padded_model_input0.b_seq_len + padded_model_input0.b_req_idx, padded_model_input0.b_seq_len, padded_model_input0.mem_indexes ) self.req_manager.prepare_decode_compress_slots( - padded_model_input1.b_req_idx, padded_model_input1.b_seq_len + padded_model_input1.b_req_idx, padded_model_input1.b_seq_len, padded_model_input1.mem_indexes ) infer_state0 = self._create_inferstate(padded_model_input0, 0) copy_kv_index_to_req( @@ -919,16 +931,20 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode else: model_input0 = self._create_padded_decode_model_input(model_input0, infer_batch_size) model_input1 = self._create_padded_decode_model_input(model_input1, infer_batch_size) - if hasattr(self.mem_manager, "prepare_decode_swa_slots"): - self.mem_manager.prepare_decode_swa_slots( + if hasattr(self.req_manager, "prepare_decode_swa"): + self.req_manager.prepare_decode_swa( model_input0.b_req_idx, model_input0.b_seq_len, model_input0.mem_indexes ) - self.mem_manager.prepare_decode_swa_slots( + self.req_manager.prepare_decode_swa( model_input1.b_req_idx, model_input1.b_seq_len, model_input1.mem_indexes ) if hasattr(self.req_manager, "prepare_decode_compress_slots"): - self.req_manager.prepare_decode_compress_slots(model_input0.b_req_idx, model_input0.b_seq_len) - self.req_manager.prepare_decode_compress_slots(model_input1.b_req_idx, model_input1.b_seq_len) + self.req_manager.prepare_decode_compress_slots( + model_input0.b_req_idx, model_input0.b_seq_len, model_input0.mem_indexes + ) + 
self.req_manager.prepare_decode_compress_slots( + model_input1.b_req_idx, model_input1.b_seq_len, model_input1.mem_indexes + ) infer_state0 = self._create_inferstate(model_input0, 0) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index fca9b80fcf..24842ed383 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -68,12 +68,13 @@ def __init__( auto_update_redundancy_expert=self.auto_update_redundancy_expert, ) self.lock = threading.Lock() + self._moe_weight_finalized = False self._create_weight() def _init_config(self, network_config: Dict[str, Any]): self.n_group = network_config.get("n_group", 0) self.use_grouped_topk = self.n_group > 0 - self.norm_topk_prob = network_config["norm_topk_prob"] + self.norm_topk_prob = network_config.get("norm_topk_prob", False) self.topk_group = network_config.get("topk_group", 0) self.num_experts_per_tok = network_config["num_experts_per_tok"] self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) @@ -136,6 +137,7 @@ def experts( is_prefill: Optional[bool] = None, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" + self._finalize_moe_weight() return self.fuse_moe_impl( input_tensor=input_tensor, router_logits=router_logits, @@ -152,6 +154,25 @@ def experts( per_expert_scale=self.per_expert_scale, ) + def experts_with_preselected( + self, + input_tensor: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_prefill: Optional[bool] = None, + clamp_limit: Optional[float] = None, + ) -> torch.Tensor: + self._finalize_moe_weight() + return self.fuse_moe_impl.fused_experts_with_topk( + input_tensor=input_tensor, + 
w13=self.w13, + w2=self.w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + is_prefill=is_prefill, + clamp_limit=clamp_limit, + ) + def low_latency_dispatch( self, hidden_states: torch.Tensor, @@ -280,7 +301,18 @@ def verify_load(self): e_score_correction_bias_load_ok = ( True if self.e_score_correction_bias is None else getattr(self.e_score_correction_bias, "load_ok", False) ) - return weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok + load_ok = weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok + if load_ok: + self._finalize_moe_weight() + return load_ok + + def _finalize_moe_weight(self): + if self._moe_weight_finalized: + return + finalize = getattr(self.quant_method, "finalize_moe_weight", None) + if finalize is not None: + finalize(self) + self._moe_weight_finalized = True def _create_weight(self): intermediate_size = self.split_inter_size diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py index 67bb90e4ef..282c0abdce 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py @@ -2,9 +2,15 @@ from .triton_impl import FuseMoeTriton from .marlin_impl import FuseMoeMarlin from .deepgemm_impl import FuseMoeDeepGEMM +from .mxfp4_impl import FuseMoeMXFP4 def select_fuse_moe_impl(quant_method: QuantizationMethod, enable_ep_moe: bool): + if quant_method.method_name == "marlin-mxfp4w4a16-b32": + if enable_ep_moe: + raise RuntimeError("marlin-mxfp4w4a16-b32 does not support enable_ep_moe yet") + return FuseMoeMXFP4 + if enable_ep_moe: return FuseMoeDeepGEMM diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py new file mode 100644 index 
0000000000..97cf238116 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py @@ -0,0 +1,44 @@ +import torch +from typing import Optional + +from lightllm.common.quantization.quantize_method import WeightPack +from .triton_impl import FuseMoeTriton + + +class FuseMoeMXFP4(FuseMoeTriton): + def create_workspace(self): + return None + + def _fused_experts( + self, + input_tensor: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + router_logits: Optional[torch.Tensor] = None, + is_prefill: Optional[bool] = None, + clamp_limit: Optional[float] = None, + ): + try: + from vllm.model_executor.layers.fused_moe.activation import MoEActivation + from vllm.model_executor.layers.fused_moe.experts.marlin_moe import fused_marlin_moe + from vllm.scalar_type import scalar_types + except Exception as e: + raise RuntimeError(f"MXFP4 fused MoE requires vLLM fused kernels, error={repr(e)}") from e + + return fused_marlin_moe( + hidden_states=input_tensor.contiguous(), + w1=w13.weight, + w2=w2.weight, + bias1=None, + bias2=None, + w1_scale=w13.weight_scale, + w2_scale=w2.weight_scale, + topk_weights=topk_weights.to(torch.float32).contiguous(), + topk_ids=topk_ids.to(torch.long).contiguous(), + quant_type_id=scalar_types.float4_e2m1f.id, + global_num_experts=self.n_routed_experts, + activation=MoEActivation.SILU, + clamp_limit=clamp_limit, + ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index a0d30547a3..09ce88e3fd 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -114,6 +114,25 @@ def _fused_experts( ) return input_tensor + def fused_experts_with_topk( + self, + input_tensor: torch.Tensor, + w13: WeightPack, + 
w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_prefill: Optional[bool] = None, + clamp_limit: Optional[float] = None, + ): + return self._fused_experts( + input_tensor=input_tensor, + w13=w13, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + is_prefill=is_prefill, + ) + def __call__( self, input_tensor: torch.Tensor, diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index cb2e370cb9..28fe6e4304 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -49,7 +49,7 @@ def check_ep_expert_dtype(quant_method: Any): "EP MoE requires --expert_dtype to be one of ['fp8', 'fp4'], " f"but the resolved fused_moe quant method is `{expert_dtype}`. " "Please start with --expert_dtype fp8 or --expert_dtype fp4. " - "Note that --expert_dtype fp4 is only supported on SM100 GPUs." + "Note that --expert_dtype fp4 with EP MoE is only supported on SM100 GPUs." 
) if expert_dtype == "deepgemm-fp4fp8-b32" and not is_sm100_gpu(): raise RuntimeError( diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py index dc708e0790..47fdf76fd9 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py @@ -1,7 +1,7 @@ import torch import torch.distributed as dist -from typing import Dict, List, Optional -from .deepseek2_mem_manager import Deepseek2MemoryManager +from typing import List, Optional, Union +from .mem_manager import MemoryManager from .operator import DeepseekV4MemOperator from .allocator import KvCacheAllocator from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node @@ -12,36 +12,46 @@ logger = init_logger(__name__) +# fp8_ds_mla packed-latent byte layout (ABI shared with the flash_mla extra-cache fork and +# sglang/vllm): 448B NoPE fp8 + 64*2B RoPE bf16 + 7B ue8m0 scale + 1B pad = 584B per token, +# stored in page slabs whose tail carries the per-token scale bytes. 
DSV4_MLA_NOPE_DIM = 448 DSV4_MLA_ROPE_DIM = 64 DSV4_MLA_HEAD_DIM = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM DSV4_MLA_QUANT_GROUP_SIZE = 64 DSV4_MLA_SCALE_BYTES = DSV4_MLA_NOPE_DIM // DSV4_MLA_QUANT_GROUP_SIZE + 1 DSV4_MLA_BYTES_PER_TOKEN = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM * 2 + DSV4_MLA_SCALE_BYTES +DSV4_MLA_DATA_BYTES_PER_TOKEN = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM * 2 +DSV4_MLA_PAGE_ALIGN_BYTES = DSV4_MLA_DATA_BYTES_PER_TOKEN DSV4_INDEXER_HEAD_DIM = 128 -DSV4_INDEXER_BYTES_PER_TOKEN = DSV4_INDEXER_HEAD_DIM + 4 +DSV4_INDEXER_SCALE_BYTES = 4 +DSV4_INDEXER_BYTES_PER_TOKEN = DSV4_INDEXER_HEAD_DIM + DSV4_INDEXER_SCALE_BYTES DSV4_FP8_E4M3_MAX = 448.0 DSV4_FP8_SCALE_MIN = 1e-4 -DSV4_MLA_DATA_BYTES_PER_TOKEN = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM * 2 -DSV4_MLA_SCALE_TAIL_BYTES = DSV4_MLA_SCALE_BYTES -DSV4_MLA_PAGE_ALIGN_BYTES = DSV4_MLA_DATA_BYTES_PER_TOKEN DSV4_SWA_PAGE_SIZE = 128 DSV4_C4_PAGE_SIZE = 64 DSV4_C128_PAGE_SIZE = 2 +# c4 compressor state ring(overlap 对: 每页 2 个分组槽 × ratio 4 行)。c128 state 在 128 边界 +# 自然归零(在线聚合),无缓存常驻需求,保持 req 键控,不进 swa 派生池。 +DSV4_C4_STATE_RING = 8 DSV4_PROFILE_MAX_FULL_TOKENS = 1_500_000 +# swa 池占 full token 空间的比例下限(sglang swa_full_tokens_ratio=0.1 的对应物)。 +# lightllm 的调度准入只看 full 池,prefill 优先的波次会让"已 prefill 未 decode"的请求整段 +# prompt 占住 swa 槽(首次 decode prep 才批量出窗回收),峰值≈准入波次 prompt 总和。在 +# v5 的 swa 压力阀/准入耦合落地前,用比 sglang 更宽的 0.3 兜住该瞬时峰值。 +DSV4_SWA_FULL_TOKENS_RATIO = 0.3 def _ceil_div(a: int, b: int) -> int: return (a + b - 1) // b -class _PageSlabMlaPool: - """SGLang-compatible fp8_ds_mla page-slab storage with token-slot addressing. +class PackedPagePool: + """fp8_ds_mla 风格的 page-slab 存储: 每页前段连续放 token 的 data 字节,页尾放 per-token scale 字节。 - The public loc is still a LightLLM token slot. 
Internally each page stores all - 576B NoPE+RoPE payloads first and the 8B scale records at the page tail: - data_offset = page * bytes_per_page + token_in_page * 576 - scale_offset = page * bytes_per_page + page_size * 576 + token_in_page * 8 + 寻址是纯 token 槽位 (page = slot // page_size),page 只是 scale-tail/对齐的物理打包技巧, + 不存在页粒度的分配。``write``/``read`` 是 torch 参考实现(单测 oracle);生产写入走 + triton packed writer(destindex_copy_kv_flashmla_dsv4 等),kernel 直接消费 ``buffer``。 """ def __init__( @@ -49,27 +59,26 @@ def __init__( size: int, page_size: int, layer_num: int, + data_bytes: int, + scale_bytes: int, + align_bytes: int = 1, device: str = "cuda", ): self.size = size self.page_size = page_size self.layer_num = layer_num - self.dtype = torch.uint8 - self.data_bytes_per_token = DSV4_MLA_DATA_BYTES_PER_TOKEN - self.scale_bytes_per_token = DSV4_MLA_SCALE_TAIL_BYTES - self.bytes_per_token = DSV4_MLA_BYTES_PER_TOKEN + self.data_bytes_per_token = data_bytes + self.scale_bytes_per_token = scale_bytes + self.bytes_per_token = data_bytes + scale_bytes self.num_pages = _ceil_div(size + 1, page_size) - self.bytes_per_page = ( - _ceil_div(page_size * self.bytes_per_token, DSV4_MLA_PAGE_ALIGN_BYTES) * DSV4_MLA_PAGE_ALIGN_BYTES - ) - self.scale_offset_in_page = page_size * self.data_bytes_per_token - self.kv_buffer = torch.zeros( - (layer_num, self.num_pages, self.bytes_per_page), - dtype=torch.uint8, - device=device, - ) + self.bytes_per_page = _ceil_div(page_size * self.bytes_per_token, align_bytes) * align_bytes + self.scale_offset_in_page = page_size * data_bytes + self.buffer = torch.zeros((layer_num, self.num_pages, self.bytes_per_page), dtype=torch.uint8, device=device) self.HOLD_TOKEN_MEMINDEX = size + def get_layer_buffer(self, layer_index: int) -> torch.Tensor: + return self.buffer[layer_index] + def _loc_offsets(self, loc: torch.Tensor): loc = loc.long() page = torch.div(loc, self.page_size, rounding_mode="floor") @@ -82,24 +91,21 @@ def _loc_offsets(self, loc: torch.Tensor): def 
write(self, layer_index: int, loc: torch.Tensor, packed: torch.Tensor) -> None: if loc.numel() == 0: return - loc = loc.long() - packed = packed.reshape(-1, DSV4_MLA_BYTES_PER_TOKEN).contiguous() - flat = self.kv_buffer[layer_index].view(-1) + loc = loc.reshape(-1) + packed = packed.reshape(-1, self.bytes_per_token).contiguous() + flat = self.buffer[layer_index].view(-1) data_offsets, scale_offsets = self._loc_offsets(loc) - - data = packed[:, : self.data_bytes_per_token].contiguous() - scale = packed[:, self.data_bytes_per_token : self.bytes_per_token].contiguous() data_range = torch.arange(self.data_bytes_per_token, device=loc.device) scale_range = torch.arange(self.scale_bytes_per_token, device=loc.device) - flat[data_offsets.unsqueeze(1) + data_range.unsqueeze(0)] = data - flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)] = scale + flat[data_offsets.unsqueeze(1) + data_range.unsqueeze(0)] = packed[:, : self.data_bytes_per_token] + flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)] = packed[:, self.data_bytes_per_token :] return def read(self, layer_index: int, loc: torch.Tensor) -> torch.Tensor: - loc = loc.long() + loc = loc.reshape(-1) if loc.numel() == 0: - return torch.empty((0, DSV4_MLA_BYTES_PER_TOKEN), dtype=torch.uint8, device=self.kv_buffer.device) - flat = self.kv_buffer[layer_index].view(-1) + return torch.empty((0, self.bytes_per_token), dtype=torch.uint8, device=self.buffer.device) + flat = self.buffer[layer_index].view(-1) data_offsets, scale_offsets = self._loc_offsets(loc) data_range = torch.arange(self.data_bytes_per_token, device=loc.device) scale_range = torch.arange(self.scale_bytes_per_token, device=loc.device) @@ -107,150 +113,24 @@ def read(self, layer_index: int, loc: torch.Tensor) -> torch.Tensor: scale = flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)] return torch.cat([data, scale], dim=1).contiguous() - def get_layer_buffer(self, layer_index: int) -> torch.Tensor: - return self.kv_buffer[layer_index] - - 
-class _PageSlabIndexerPool: - """C4 indexer-K storage: page tail stores per-token fp32 scales.""" - - def __init__( - self, - size: int, - page_size: int, - layer_num: int, - device: str = "cuda", - ): - self.size = size - self.page_size = page_size - self.layer_num = layer_num - self.head_dim = DSV4_INDEXER_HEAD_DIM - self.scale_bytes = 4 - self.bytes_per_token = DSV4_INDEXER_BYTES_PER_TOKEN - self.num_pages = _ceil_div(size + 1, page_size) - self.bytes_per_page = page_size * self.bytes_per_token - self.scale_offset_in_page = page_size * self.head_dim - self.index_k_buffer = torch.zeros( - (layer_num, self.num_pages, self.bytes_per_page), - dtype=torch.uint8, - device=device, - ) - self.HOLD_TOKEN_MEMINDEX = size - - def _loc_offsets(self, loc: torch.Tensor): - loc = loc.long() - page = torch.div(loc, self.page_size, rounding_mode="floor") - token = loc % self.page_size - page_base = page * self.bytes_per_page - k_offsets = page_base + token * self.head_dim - scale_offsets = page_base + self.scale_offset_in_page + token * self.scale_bytes - return k_offsets, scale_offsets - - def write(self, layer_index: int, loc: torch.Tensor, packed: torch.Tensor) -> None: - if loc.numel() == 0: - return - loc = loc.long() - packed = packed.reshape(-1, self.bytes_per_token).contiguous() - flat = self.index_k_buffer[layer_index].view(-1) - k_offsets, scale_offsets = self._loc_offsets(loc) - k_range = torch.arange(self.head_dim, device=loc.device) - scale_range = torch.arange(self.scale_bytes, device=loc.device) - flat[k_offsets.unsqueeze(1) + k_range.unsqueeze(0)] = packed[:, : self.head_dim] - flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)] = packed[:, self.head_dim :] - return - - def read(self, layer_index: int, loc: torch.Tensor) -> torch.Tensor: - loc = loc.long() - if loc.numel() == 0: - return torch.empty((0, self.bytes_per_token), dtype=torch.uint8, device=self.index_k_buffer.device) - flat = self.index_k_buffer[layer_index].view(-1) - k_offsets, 
scale_offsets = self._loc_offsets(loc) - k_range = torch.arange(self.head_dim, device=loc.device) - scale_range = torch.arange(self.scale_bytes, device=loc.device) - k = flat[k_offsets.unsqueeze(1) + k_range.unsqueeze(0)] - scale = flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)] - return torch.cat([k, scale], dim=1).contiguous() - - def get_layer_buffer(self, layer_index: int) -> torch.Tensor: - return self.index_k_buffer[layer_index] - - -class _SubKvPool: - """Compressed c4/c128 KV pool with token-slot allocator and page-slab backing.""" - - def __init__( - self, - size: int, - page_size: int, - layer_num: int, - with_indexer: bool = False, - shared_name: Optional[str] = None, - device: str = "cuda", - ): - self.size = size - self.dtype = torch.uint8 - self.layer_num = layer_num - self.page_size = page_size - self.mla_pool = _PageSlabMlaPool(size=size, page_size=page_size, layer_num=layer_num, device=device) - self.kv_buffer = self.mla_pool.kv_buffer - if with_indexer: - self.indexer_pool = _PageSlabIndexerPool( - size=size, - page_size=page_size, - layer_num=layer_num, - device=device, - ) - self.index_k_buffer = self.indexer_pool.index_k_buffer - else: - self.indexer_pool = None - self.index_k_buffer = None - - self.allocator = KvCacheAllocator(size, shared_name=shared_name) - self.HOLD_TOKEN_MEMINDEX = size - - def alloc(self, need_size) -> torch.Tensor: - return self.allocator.alloc(need_size) - - def free(self, free_index) -> None: - self.allocator.free(free_index) - - def free_all(self) -> None: - self.allocator.free_all() - - def get_kv_buffer(self, layer_index: int) -> torch.Tensor: - return self.mla_pool.get_layer_buffer(layer_index) - - def get_index_k_buffer(self, layer_index: int) -> torch.Tensor: - assert self.indexer_pool is not None, "this sub pool has no indexer-K buffer" - return self.indexer_pool.get_layer_buffer(layer_index) - - def write_kv(self, layer_index: int, slots: torch.Tensor, packed: torch.Tensor) -> None: - 
self.mla_pool.write(layer_index, slots, packed) - - def read_kv(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor: - return self.mla_pool.read(layer_index, slots) - - def write_indexer_k(self, layer_index: int, slots: torch.Tensor, packed: torch.Tensor) -> None: - assert self.indexer_pool is not None - self.indexer_pool.write(layer_index, slots, packed) - - def read_indexer_k(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor: - assert self.indexer_pool is not None - return self.indexer_pool.read(layer_index, slots) +class DeepseekV4MemoryManager(MemoryManager): + """DeepSeek-V4 KV cache: 窗口 latent(全层) + c4/c128 压缩 latent(压实层) + c4 indexer-K。 -class DeepseekV4MemoryManager(Deepseek2MemoryManager): - """DeepSeek-V4 token-slot KV 管理(584B packed cache + bf16 workspace)。 + 与兄弟 manager 一致的 token-slot 设计;req 索引的表都在 DeepseekV4ReqManager。 - - dense/SWA latent: 主 ``kv_buffer`` 仍是 LightLLM 的 token-slot cache,不分页;物理格式改为 - SGLang/vLLM 的 ``fp8_ds_mla``: 448B NoPE fp8 + 64*2B RoPE bf16 + 7B scale + 1B pad = 584B。 - - c4_pool / c128_pool: 两个独立 ``_SubKvPool``(window 粒度,1-token 分配),compressed KV 同样 - 存 584B packed。c4 池附带 132B/token 的 packed indexer-K。 - - 读取时先用 torch reference dequant/gather 回 bf16 workspace,供现有 vLLM sparse FlashMLA wrapper - 消费;下一步可把这些 pack/dequant helper 替换成 fused/triton 版本。 - - 容量: 用闭式 ``get_cell_size()``(= 每个 dense token 在所有池上的 packed 总字节)让基类 - ``profile_size`` 直接得到 full_token = dense 池大小,再按 1/4、1/128 派生压缩池大小。 - - compressor 递归状态放 DeepseekV4ReqManager。 + - ``swa_pool``: 584B packed latent,所有层。池子小于 full token 空间;prep 阶段 + ``alloc_swa_prefill/decode`` 按**页**(128 槽,位置对齐: slot(p)=page_base+p%128)分配, + 映射记录到 ``full_to_swa_indexs``(以 full token 槽位为键)。出窗槽位由 DeepseekV4ReqManager + 在 prep 阶段批量惰性回收(``evict_swa``,页存活计数减到 0 才整页归还);full 槽位释放时 + ``free`` 级联回收对应 swa 槽,所以 radix 驱逐/请求释放/暂停无需任何额外协议。 + 页 allocator 触底时先走压力阀(radix 对 ref==0 节点回收)再 assert。 + 没有 ring buffer,prefill chunk 大小不受 sliding_window 限制。 + - ``c4_pool``/``c128_pool``: 压缩 latent,按 qwen3next 
的层号压实手法只为压缩层建层; + c4 另带 packed indexer-K 池。槽位映射(``full_to_c4/c128_indexs``)以组末 token 的 full + 槽位为键(prep 阶段分配/scatter),``free`` 级联回收,与 swa 完全同构。 + - 写入走标准 operator 路径(``pack_mla_kv_to_cache``),内部为 triton packed writer; + torch codecs 保留为 ABI 的可执行规格(单测 oracle)。 """ operator_class = DeepseekV4MemOperator @@ -275,6 +155,7 @@ def __init__( indexer_head_dim: int = 128, max_request_num: Optional[int] = None, sliding_window: Optional[int] = None, + swa_extra_token_num: int = 0, always_copy=False, mem_fraction=0.9, ): @@ -290,15 +171,15 @@ def __init__( self.n_c4 = sum(1 for r in self.compress_rates if r == 4) self.n_c128 = sum(1 for r in self.compress_rates if r == 128) self.indexer_head_dim = indexer_head_dim - self.prefill_dtype = dtype - self.cache_dtype = torch.uint8 self.max_request_num = max_request_num self.sliding_window = sliding_window - self._pending_prefill_swa: Dict[int, Dict[str, torch.Tensor]] = {} + # 活跃窗口(max_request_num * sliding_window)之外的余量: 在途 prefill chunk 的瞬时占用 + # (出窗槽位要到下一次 prep 才回收) + radix cache 持有的窗口尾部。 + self.swa_extra_token_num = int(swa_extra_token_num) # 全局层号 -> 各压缩池内的压实层号(同 qwen3next 的层号压实手法) - self.layer_to_c4_idx: Dict[int, int] = {} - self.layer_to_c128_idx: Dict[int, int] = {} + self.layer_to_c4_idx = {} + self.layer_to_c128_idx = {} c4 = c128 = 0 for lid, r in enumerate(self.compress_rates): if r == 4: @@ -310,21 +191,56 @@ def __init__( super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + # ------------------------------------------------------------------ sizing + def _swa_per_req_budget(self) -> int: + # 活跃请求保留 window + 一个 radix 页(req_manager._swa_retain_len: 让最近完成的 + # 128 边界的结尾页恒驻留,prompt cache 插入门才能放行),即 v5 §2 的「活跃窗口跨页 ≤2」。 + return int(self.sliding_window) + DSV4_SWA_PAGE_SIZE + def _planned_swa_size(self, full_size: int) -> int: + # swa 池按页分配(页 = 128 = sliding_window = radix 页),容量向上取整到整页。 if self.max_request_num is None or self.sliding_window is None: - return full_size - window_cap = max(1, 
int(self.max_request_num) * int(self.sliding_window)) - return max(1, min(full_size, window_cap)) + return _ceil_div(full_size, DSV4_SWA_PAGE_SIZE) * DSV4_SWA_PAGE_SIZE + cap = int(self.max_request_num) * self._swa_per_req_budget() + self.swa_extra_token_num + cap = max(cap, int(full_size * DSV4_SWA_FULL_TOKENS_RATIO)) + cap = max(1, min(full_size, cap)) + return _ceil_div(cap, DSV4_SWA_PAGE_SIZE) * DSV4_SWA_PAGE_SIZE + + @staticmethod + def _slab_bytes_per_slot(page_size: int, data_bytes: int, scale_bytes: int, align_bytes: int = 1) -> float: + bytes_per_page = _ceil_div(page_size * (data_bytes + scale_bytes), align_bytes) * align_bytes + return bytes_per_page / page_size + + def _c4_state_bytes_per_swa_slot(self) -> float: + """c4 compressor state(attention + indexer,swa 页派生寻址)摊到每个 swa 槽的字节数。""" + if self.n_c4 == 0: + return 0.0 + per_page = DSV4_C4_STATE_RING * (4 * self.head_dim + 4 * self.indexer_head_dim) * 4 # fp32 + return per_page * self.n_c4 / DSV4_SWA_PAGE_SIZE + + def _swa_slot_bytes(self) -> float: + per_layer = self._slab_bytes_per_slot( + DSV4_SWA_PAGE_SIZE, DSV4_MLA_DATA_BYTES_PER_TOKEN, self.mla_scale_bytes, DSV4_MLA_PAGE_ALIGN_BYTES + ) + return per_layer * self.layer_num + self._c4_state_bytes_per_swa_slot() - def _dense_cell_size(self): - return self.head_num * self.mla_bytes_per_token * self.layer_num + def _compressed_cell_size(self) -> float: + """每个 full token 摊到压缩池上的精确字节数(按 page-slab 对齐后)。""" + c4_latent = self._slab_bytes_per_slot( + DSV4_C4_PAGE_SIZE, DSV4_MLA_DATA_BYTES_PER_TOKEN, self.mla_scale_bytes, DSV4_MLA_PAGE_ALIGN_BYTES + ) + c128_latent = self._slab_bytes_per_slot( + DSV4_C128_PAGE_SIZE, DSV4_MLA_DATA_BYTES_PER_TOKEN, self.mla_scale_bytes, DSV4_MLA_PAGE_ALIGN_BYTES + ) + c4_indexer = self._slab_bytes_per_slot(DSV4_C4_PAGE_SIZE, self.indexer_head_dim, DSV4_INDEXER_SCALE_BYTES) + return (c4_latent + c4_indexer) * self.n_c4 / 4 + c128_latent * self.n_c128 / 128 - def _compressed_cell_size(self): - latent_bytes = self.head_num * 
self.mla_bytes_per_token - c4 = latent_bytes * self.n_c4 / 4 - c128 = latent_bytes * self.n_c128 / 128 - indexer = self.indexer_bytes_per_token * self.n_c4 / 4 - return c4 + c128 + indexer + def get_cell_size(self): + compressed = self._compressed_cell_size() + if self.size is None: + return self._swa_slot_bytes() + compressed + swa_ratio = self._planned_swa_size(self.size) / max(1, self.size) + return self._swa_slot_bytes() * swa_ratio + compressed def profile_size(self, mem_fraction): if self.size is not None: @@ -334,19 +250,26 @@ def profile_size(self, mem_fraction): world_size = dist.get_world_size() available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction) available_bytes = available_memory * 1024 ** 3 - dense_cell = self._dense_cell_size() + swa_slot_bytes = self._swa_slot_bytes() compressed_cell = self._compressed_cell_size() if self.max_request_num is not None and self.sliding_window is not None and compressed_cell > 0: - swa_cap = max(1, int(self.max_request_num) * int(self.sliding_window)) - full_cell = dense_cell + compressed_cell - bytes_until_swa_cap = full_cell * swa_cap - if available_bytes <= bytes_until_swa_cap: + swa_budget = int(self.max_request_num) * self._swa_per_req_budget() + self.swa_extra_token_num + full_cell = swa_slot_bytes + compressed_cell + if available_bytes <= full_cell * swa_budget: + # 小显存: full token 数还到不了 swa 预算,swa 池跟随 full token 数(每 token 一个 swa 槽)。 self.size = max(1, int(available_bytes / full_cell)) else: - self.size = max(1, int((available_bytes - dense_cell * swa_cap) / compressed_cell)) + size_budget = max(1, int((available_bytes - swa_slot_bytes * swa_budget) / compressed_cell)) + if size_budget * DSV4_SWA_FULL_TOKENS_RATIO > swa_budget: + # 比例下限生效(_planned_swa_size 会取 ratio*full),按该机制反解 full。 + self.size = max( + 1, int(available_bytes / (swa_slot_bytes * DSV4_SWA_FULL_TOKENS_RATIO + compressed_cell)) + ) + else: + self.size = size_budget else: - self.size = max(1, 
int(available_bytes / (dense_cell + compressed_cell))) + self.size = max(1, int(available_bytes / (swa_slot_bytes + compressed_cell))) if world_size > 1: tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}") @@ -362,93 +285,350 @@ def profile_size(self, mem_fraction): logger.info( f"{str(available_memory)} GB space is available after load the model weight\n" - f"{str((dense_cell + compressed_cell) / 1024 ** 2)} MB is the conservative size of one token kv cache\n" + f"{str(self.get_cell_size() / 1024 ** 2)} MB is the conservative size of one token kv cache\n" f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n" ) return - def get_cell_size(self): - dense = self._dense_cell_size() - compressed = self._compressed_cell_size() - if self.size is None: - return dense + compressed - swa_ratio = self._planned_swa_size(self.size) / max(1, self.size) - return dense * swa_ratio + compressed - + # ------------------------------------------------------------------ buffers def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + rank_in_node = get_current_rank_in_node() + server = get_unique_server_name() + self.swa_size = self._planned_swa_size(size) - self.swa_pool = _PageSlabMlaPool( + assert self.swa_size % DSV4_SWA_PAGE_SIZE == 0 + self.swa_pool = PackedPagePool( size=self.swa_size, page_size=DSV4_SWA_PAGE_SIZE, layer_num=layer_num, - device="cuda", + data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN, + scale_bytes=self.mla_scale_bytes, + align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES, ) - self.kv_buffer = self.swa_pool.kv_buffer - self._init_swa_mapping(size) - self._init_compressed_pools(size, head_num) - - def _init_swa_mapping(self, size): - rank_in_node = get_current_rank_in_node() - server = get_unique_server_name() - self.swa_allocator = KvCacheAllocator( - self.swa_size, - shared_name=f"{server}_dsv4_swa_can_use_token_num_{rank_in_node}", + # 注意: 该别名是 page 索引([layer, num_pages, 
bytes_per_page])而非 token 索引, + # 只允许 get_att_input_params 的消费者使用;token 索引语义的继承接口已显式 fence。 + self.kv_buffer = self.swa_pool.buffer + # 页粒度分配(页 = 128 槽,位置对齐): 槽位不变式 slot(p) = page_base + p%128。 + # swa_size 整页对齐 ⇒ HOLD 槽(swa_size)独占池子最后一个物理页,永不参与分配。 + self.swa_num_pages = self.swa_size // DSV4_SWA_PAGE_SIZE + self.swa_page_allocator = KvCacheAllocator( + self.swa_num_pages, shared_name=f"{server}_dsv4_swa_can_use_page_num_{rank_in_node}" ) + # 页存活计数 = 指向该页的有效 full_to_swa 行数;减到 0 归还 allocator(出窗逐 token + # 回收下,「部分出窗页」计数 > 0 自然受保护)。下标含 HOLD 页(只读不增减)。 + self.swa_page_live_count = torch.zeros((self.swa_pool.num_pages,), dtype=torch.int32, device="cuda") + # swa 压力阀(可选): 页 allocator 触底时回调(radix 对 ref==0 节点回收 swa 页), + # 由 backend 在 radix cache 创建后注入;assert 仍是最后防线。 + self._swa_pressure_valve = None self.full_to_swa_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda") self.full_to_swa_indexs[size] = self.swa_pool.HOLD_TOKEN_MEMINDEX - if self.max_request_num is None or self.sliding_window is None: - self.req_to_swa_indexs = None - self.req_to_swa_full_indexs = None - return - - self.req_to_swa_indexs = torch.full( - (self.max_request_num + 1, self.sliding_window), - self.swa_pool.HOLD_TOKEN_MEMINDEX, - dtype=torch.int32, - device="cuda", - ) - self.req_to_swa_full_indexs = torch.full( - (self.max_request_num + 1, self.sliding_window), - -1, - dtype=torch.int32, - device="cuda", - ) - - def _init_compressed_pools(self, size, head_num): - rank_in_node = get_current_rank_in_node() - server = get_unique_server_name() - - self.c4_size = (size + 4 - 1) // 4 - self.c128_size = (size + 128 - 1) // 128 - self.c4_pool: Optional[_SubKvPool] = None - self.c128_pool: Optional[_SubKvPool] = None + self.c4_size = _ceil_div(size, 4) + self.c128_size = _ceil_div(size, 128) + self.c4_pool: Optional[PackedPagePool] = None + self.c4_indexer_pool: Optional[PackedPagePool] = None + self.c4_allocator: Optional[KvCacheAllocator] = None + self.c128_pool: Optional[PackedPagePool] 
= None + self.c128_allocator: Optional[KvCacheAllocator] = None + # 压缩槽映射: 键 = 组末 token(位置 (g+1)%ratio==0)的 full 槽位,值 = 压缩池槽位。 + # 与 full_to_swa_indexs 同构: radix 持有 full 槽 => 映射行存活,free 级联回收。 + self.full_to_c4_indexs: Optional[torch.Tensor] = None + self.full_to_c128_indexs: Optional[torch.Tensor] = None if self.n_c4 > 0: - self.c4_pool = _SubKvPool( + self.c4_pool = PackedPagePool( size=self.c4_size, page_size=DSV4_C4_PAGE_SIZE, layer_num=self.n_c4, - with_indexer=True, - shared_name=f"{server}_dsv4_c4_can_use_token_num_{rank_in_node}", + data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN, + scale_bytes=self.mla_scale_bytes, + align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES, ) + self.c4_indexer_pool = PackedPagePool( + size=self.c4_size, + page_size=DSV4_C4_PAGE_SIZE, + layer_num=self.n_c4, + data_bytes=self.indexer_head_dim, + scale_bytes=DSV4_INDEXER_SCALE_BYTES, + ) + self.c4_allocator = KvCacheAllocator( + self.c4_size, shared_name=f"{server}_dsv4_c4_can_use_token_num_{rank_in_node}" + ) + self.full_to_c4_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda") + self.full_to_c4_indexs[size] = self.c4_pool.HOLD_TOKEN_MEMINDEX + # c4 compressor 在途状态(attention + indexer): swa 页派生寻址(翻译③),随 swa 页 + # 生灭 -> radix 命中零拷贝续算。行数 = 页数*ring + ring(HOLD 页) + 1(哨兵), + # 取整到 ratio;末行哨兵 kv=0/score=-inf(KVAndScore.clear 语义),其余行由内核在 + # 组起点覆写,无需按页清零。last_dim = 2*coff*head_dim(overlap coff=2)。 + state_rows = self.swa_num_pages * DSV4_C4_STATE_RING + DSV4_C4_STATE_RING + 1 + state_rows = _ceil_div(state_rows, 4) * 4 + self.c4_state_buffer = torch.zeros( + (self.n_c4, state_rows, 4 * self.head_dim), dtype=torch.float32, device="cuda" + ) + self.c4_indexer_state_buffer = torch.zeros( + (self.n_c4, state_rows, 4 * self.indexer_head_dim), dtype=torch.float32, device="cuda" + ) + for buf in (self.c4_state_buffer, self.c4_indexer_state_buffer): + half = buf.shape[-1] // 2 + buf[:, -1, half:].fill_(float("-inf")) if self.n_c128 > 0: - self.c128_pool = _SubKvPool( + self.c128_pool = 
PackedPagePool( size=self.c128_size, page_size=DSV4_C128_PAGE_SIZE, layer_num=self.n_c128, - with_indexer=False, - shared_name=f"{server}_dsv4_c128_can_use_token_num_{rank_in_node}", + data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN, + scale_bytes=self.mla_scale_bytes, + align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES, ) + self.c128_allocator = KvCacheAllocator( + self.c128_size, shared_name=f"{server}_dsv4_c128_can_use_token_num_{rank_in_node}" + ) + self.full_to_c128_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda") + self.full_to_c128_indexs[size] = self.c128_pool.HOLD_TOKEN_MEMINDEX logger.info( - f"DeepseekV4MemoryManager pools: full_tokens={size} swa={self.swa_size} " + f"DeepseekV4MemoryManager pools: full_tokens={size} swa={self.swa_size}({self.swa_num_pages}p) " f"c4={self.c4_size}(L={self.n_c4}) c128={self.c128_size}(L={self.n_c128}) " f"packed_kv_bytes={self.mla_bytes_per_token} indexer_bytes={self.indexer_bytes_per_token}" ) + # ------------------------------------------------------------------ buffer accessors def get_att_input_params(self, layer_index: int): return self.swa_pool.get_layer_buffer(layer_index) + def _pool_and_local_layer(self, layer_index: int): + r = self.compress_rates[layer_index] + if r == 4: + return self.c4_pool, self.layer_to_c4_idx[layer_index] + if r == 128: + return self.c128_pool, self.layer_to_c128_idx[layer_index] + raise AssertionError(f"layer {layer_index} (rate {r}) 不是压缩层,没有压缩池") + + def get_compressed_kv_buffer(self, layer_index: int) -> torch.Tensor: + pool, local_layer = self._pool_and_local_layer(layer_index) + return pool.get_layer_buffer(local_layer) + + def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor: + assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 indexer-K" + return self.c4_indexer_pool.get_layer_buffer(self.layer_to_c4_idx[layer_index]) + + def get_c4_state_buffer(self, layer_index: int) -> torch.Tensor: + assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 paged 
compressor state" + return self.c4_state_buffer[self.layer_to_c4_idx[layer_index]] + + def get_c4_indexer_state_buffer(self, layer_index: int) -> torch.Tensor: + assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 paged indexer state" + return self.c4_indexer_state_buffer[self.layer_to_c4_idx[layer_index]] + + # ------------------------------------------------------------------ swa slot lifecycle + def set_swa_pressure_valve(self, valve) -> None: + """valve(need_pages): 在页 allocator 不足时尝试腾页(radix 对 ref==0 节点回收 swa)。""" + self._swa_pressure_valve = valve + return + + def _alloc_swa_pages(self, need_pages: int) -> torch.Tensor: + if need_pages > self.swa_page_allocator.can_use_mem_size and self._swa_pressure_valve is not None: + self._swa_pressure_valve(need_pages - self.swa_page_allocator.can_use_mem_size) + return self.swa_page_allocator.alloc(need_pages) + + def _count_swa_pages(self, swa_slots: torch.Tensor, delta: int) -> torch.Tensor: + """按 slot 所在页更新存活计数,返回触达的页(去重)。""" + pages = torch.div(swa_slots.long(), DSV4_SWA_PAGE_SIZE, rounding_mode="floor") + ones = torch.full(pages.shape, delta, dtype=torch.int32, device=pages.device) + self.swa_page_live_count.index_add_(0, pages, ones) + return torch.unique(pages) + + def alloc_swa_prefill( + self, + b_req_idx: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_seq_len: torch.Tensor, + req_to_token_indexs: torch.Tensor, + ) -> None: + """prefill prep: 为各请求位置 [ready, seq) 的新 token 分配位置对齐的 swa 槽。 + + 槽位不变式: slot(p) = page_base(p 所在页) + p%128,page_base % 128 == 0。 + 续页(start 非整页,只可能是首页)的 base 从上一 token 的映射派生 + (full_to_swa[req_to_token[req, start-1]],该 token 必在保留窗内);其余页全新分配。 + radix 命中(ready 必 128 对齐)的借用方从全新页开始,与节点持有页天然不相交。 + 必须在 init_req_to_token_indexes 之后调用(scatter 目标经 req_to_token 行)。 + """ + page = DSV4_SWA_PAGE_SIZE + hold_req_id = self.max_request_num # padding 行的请求 id(req_manager.HOLD_REQUEST_ID) + req_list = b_req_idx.detach().cpu().tolist() + ready_list = b_ready_cache_len.detach().cpu().tolist() 
+ seq_list = b_seq_len.detach().cpu().tolist() + + segs = [] # (req_idx, start, end, n_new_pages, has_cont_page) + total_new_pages = 0 + for req_idx, start, end in zip(req_list, ready_list, seq_list): + req_idx, start, end = int(req_idx), int(start), int(end) + if req_idx == hold_req_id or end <= start: + continue + first_new_page = _ceil_div(start, page) + n_new = max(0, (end - 1) // page - first_new_page + 1) + segs.append((req_idx, start, end, n_new, start % page != 0)) + total_new_pages += n_new + if not segs: + return + + new_pages = self._alloc_swa_pages(total_new_pages).cuda(non_blocking=True).long() if total_new_pages else None + page_cursor = 0 + for req_idx, start, end, n_new, has_cont in segs: + positions = torch.arange(start, end, dtype=torch.long, device="cuda") + page_local = torch.div(positions, page, rounding_mode="floor") - start // page + bases = torch.empty(((end - 1) // page - start // page + 1,), dtype=torch.long, device="cuda") + if has_cont: + prev_slot = int(self.full_to_swa_indexs[req_to_token_indexs[req_idx, start - 1].long()].item()) + # 续页不变式: 上一 token 必驻留(retain >= 2)且位置对齐(未来 resume/MTP 改动的哨兵)。 + assert prev_slot >= 0 and prev_slot % page == (start - 1) % page + bases[0] = prev_slot - (start - 1) % page + if n_new: + bases[1 if has_cont else 0 :] = new_pages[page_cursor : page_cursor + n_new] * page + page_cursor += n_new + slots = (bases[page_local] + positions % page).to(torch.int32) + self.full_to_swa_indexs[req_to_token_indexs[req_idx, start:end].long()] = slots + self._count_swa_pages(slots, 1) + return + + def alloc_swa_decode( + self, + b_req_idx: torch.Tensor, + b_seq_len: torch.Tensor, + mem_indexes: torch.Tensor, + req_to_token_indexs: torch.Tensor, + ) -> None: + """decode prep: 本步 token(位置 seq-1)的 swa 槽。整页起点开新页,否则上一 token 槽 +1 + (位置对齐不变式保证同页连续)。scatter 目标用 mem_indexes(此刻 req_to_token 尚未写本步)。 + + 注意: 续槽从上一位置的映射派生,故同一请求的多行(MTP 多 token/步)在同一批内不支持 + (DSV4 启动参数已拒绝 MTP;支持需按步内顺序分段派生)。""" + page = DSV4_SWA_PAGE_SIZE + hold_req_id = 
self.max_request_num + req_list = b_req_idx.detach().cpu().tolist() + seq_list = b_seq_len.detach().cpu().tolist() + cont_rows, cont_prev_pos, new_rows = [], [], [] + for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)): + req_idx, seq_len = int(req_idx), int(seq_len) + if req_idx == hold_req_id or seq_len <= 0: + continue + if (seq_len - 1) % page == 0: + new_rows.append(i) + else: + cont_rows.append(i) + cont_prev_pos.append(seq_len - 2) + mem_indexes = mem_indexes.cuda().long().reshape(-1) + if cont_rows: + req_rows = b_req_idx[cont_rows].long() + prev_full = req_to_token_indexs[req_rows, torch.tensor(cont_prev_pos, device="cuda")].long() + prev_slots = self.full_to_swa_indexs[prev_full] + # 续槽不变式哨兵: 上一位置必驻留(retain 覆盖)。prep 阶段本就有同步,代价可忽略。 + assert bool((prev_slots >= 0).all()) + slots = prev_slots + 1 + self.full_to_swa_indexs[mem_indexes[cont_rows]] = slots + self._count_swa_pages(slots, 1) + if new_rows: + pages = self._alloc_swa_pages(len(new_rows)).cuda(non_blocking=True).long() + slots = (pages * page).to(torch.int32) + self.full_to_swa_indexs[mem_indexes[new_rows]] = slots + self._count_swa_pages(slots, 1) + return + + def evict_swa(self, full_slots: torch.Tensor) -> None: + """回收 full 槽位对应的 swa 槽(出窗惰性回收 / free 级联 / 压力阀共用)。 + 未映射(-1)的槽位跳过;页计数减到 0 时整页归还 allocator。""" + if full_slots.numel() == 0: + return + full_slots = full_slots.cuda().long().reshape(-1) + full_slots = torch.unique(full_slots[full_slots != self.HOLD_TOKEN_MEMINDEX]) + if full_slots.numel() == 0: + return + swa_slots = self.full_to_swa_indexs[full_slots] + valid = swa_slots >= 0 + valid_slots = swa_slots[valid] + if valid_slots.numel() == 0: + return + self.full_to_swa_indexs[full_slots[valid]] = -1 + touched = self._count_swa_pages(valid_slots, -1) + empty = touched[self.swa_page_live_count[touched] == 0] + if empty.numel() > 0: + self.swa_page_allocator.free(empty.to(torch.int32)) + return + + def _evict_compress(self, full_slots: torch.Tensor, mapping: torch.Tensor, 
allocator: KvCacheAllocator) -> None: + full_slots = full_slots.cuda().long().reshape(-1) + # 去重: 同批重复槽会 gather 出重复的压缩槽 -> allocator 双重释放(free 已去重,直呼叫方防御)。 + full_slots = torch.unique(full_slots[full_slots != self.HOLD_TOKEN_MEMINDEX]) + if full_slots.numel() == 0: + return + slots = mapping[full_slots] + valid = slots >= 0 + valid_slots = slots[valid] + if valid_slots.numel() == 0: + return + allocator.free(valid_slots) + mapping[full_slots[valid]] = -1 + return + + def evict_c4(self, full_slots: torch.Tensor) -> None: + """回收 full 槽位(组末 token)映射的 c4 槽。非组末/未映射(-1)的槽位跳过。""" + if self.c4_allocator is None or full_slots.numel() == 0: + return + self._evict_compress(full_slots, self.full_to_c4_indexs, self.c4_allocator) + return + + def evict_c128(self, full_slots: torch.Tensor) -> None: + """回收 full 槽位(组末 token)映射的 c128 槽。非组末/未映射(-1)的槽位跳过。""" + if self.c128_allocator is None or full_slots.numel() == 0: + return + self._evict_compress(full_slots, self.full_to_c128_indexs, self.c128_allocator) + return + + # ------------------------------------------------------------------ alloc/free (cascade) + def free(self, free_index: Union[torch.Tensor, List[int]]) -> None: + """释放 full token 槽位,级联回收其 swa 槽与 c4/c128 压缩槽。radix 驱逐、请求释放/暂停都走这里。 + + 先对 full 槽去重: 同批重复槽位会让映射 gather 出重复的压缩/swa 槽,导致 allocator 双重释放。""" + if isinstance(free_index, list): + free_index = torch.tensor(free_index, dtype=torch.int64) + if free_index.numel() > 0: + free_index = torch.unique(free_index) + self.evict_swa(free_index) + self.evict_c4(free_index) + self.evict_c128(free_index) + super().free(free_index) + return + + def free_all(self): + super().free_all() + self.swa_page_allocator.free_all() + self.swa_page_live_count.zero_() + self.full_to_swa_indexs.fill_(-1) + self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX + if self.c4_allocator is not None: + self.c4_allocator.free_all() + self.full_to_c4_indexs.fill_(-1) + self.full_to_c4_indexs[self.HOLD_TOKEN_MEMINDEX] = 
self.c4_pool.HOLD_TOKEN_MEMINDEX + if self.c128_allocator is not None: + self.c128_allocator.free_all() + self.full_to_c128_indexs.fill_(-1) + self.full_to_c128_indexs[self.HOLD_TOKEN_MEMINDEX] = self.c128_pool.HOLD_TOKEN_MEMINDEX + return + + def alloc_c4(self, need_size) -> torch.Tensor: + return self.c4_allocator.alloc(need_size) + + def alloc_c128(self, need_size) -> torch.Tensor: + return self.c128_allocator.alloc(need_size) + + def free_c4(self, free_index) -> None: + self.c4_allocator.free(free_index) + + def free_c128(self, free_index) -> None: + self.c128_allocator.free(free_index) + + # ------------------------------------------------------------------ packed codecs (torch reference) + # 与 sglang/vllm 的 fp8_ds_mla 字节布局逐位对齐(ue8m0 幂次 scale)。这些 torch 实现是该 ABI 的 + # 可执行规格(单测 oracle,triton writer 与其逐字节对拍),不可删除。 def _pack_mla_kv(self, kv: torch.Tensor) -> torch.Tensor: kv = kv.reshape(-1, self.mla_head_dim) out = torch.empty((kv.shape[0], self.mla_bytes_per_token), dtype=torch.uint8, device=kv.device) @@ -500,7 +680,7 @@ def _pack_indexer_k(self, indexer_k: torch.Tensor) -> torch.Tensor: ) k_fp8 = torch.clamp(k_float / scale, -DSV4_FP8_E4M3_MAX, DSV4_FP8_E4M3_MAX).to(torch.float8_e4m3fn) out[:, : self.indexer_head_dim].copy_(k_fp8.view(dtype=torch.uint8)) - out[:, self.indexer_head_dim : self.indexer_bytes_per_token].copy_(scale.view(dtype=torch.uint8).reshape(-1, 4)) + out[:, self.indexer_head_dim :].copy_(scale.view(dtype=torch.uint8).reshape(-1, DSV4_INDEXER_SCALE_BYTES)) return out def _unpack_indexer_k(self, packed: torch.Tensor) -> torch.Tensor: @@ -508,483 +688,99 @@ def _unpack_indexer_k(self, packed: torch.Tensor) -> torch.Tensor: if packed.shape[0] == 0: return torch.empty((0, self.indexer_head_dim), dtype=self.dtype, device=packed.device) k_fp8 = packed[:, : self.indexer_head_dim].view(dtype=torch.float8_e4m3fn).float() - scale = packed[:, self.indexer_head_dim : self.indexer_bytes_per_token].view(dtype=torch.float32) + scale = packed[:, 
self.indexer_head_dim :].view(dtype=torch.float32) return (k_fp8 * scale).to(self.dtype) - def _identity_swa_slots(self, full_slots: torch.Tensor) -> torch.Tensor: - full_slots = full_slots.long() - valid = full_slots != self.HOLD_TOKEN_MEMINDEX - if valid.any() and int(full_slots[valid].max().item()) >= self.swa_size: - raise RuntimeError( - "DeepSeek-V4 SWA cache needs req_idx/positions for full token slots outside the SWA pool" - ) - swa_slots = torch.where( - valid, - full_slots, - torch.full_like(full_slots, self.swa_pool.HOLD_TOKEN_MEMINDEX), - ) - if valid.any(): - self.full_to_swa_indexs[full_slots[valid]] = swa_slots[valid].to(torch.int32) - return swa_slots - - def ensure_swa_slots(self, req_idx: int, positions: torch.Tensor, full_slots: torch.Tensor) -> torch.Tensor: - full_slots = full_slots.long().reshape(-1) - if full_slots.numel() == 0: - return full_slots - if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None: - return self._identity_swa_slots(full_slots) - - positions = positions.long().reshape(-1) - assert positions.numel() == full_slots.numel() - req_idx = int(req_idx) - out = torch.empty_like(full_slots, dtype=torch.long) - for i, (pos, full) in enumerate(zip(positions.tolist(), full_slots.tolist())): - if full == self.HOLD_TOKEN_MEMINDEX: - out[i] = self.swa_pool.HOLD_TOKEN_MEMINDEX - continue - - ring_pos = pos % self.sliding_window - old_swa = int(self.req_to_swa_indexs[req_idx, ring_pos].item()) - old_full = int(self.req_to_swa_full_indexs[req_idx, ring_pos].item()) - if old_full == full and old_swa != self.swa_pool.HOLD_TOKEN_MEMINDEX: - swa = old_swa - elif old_swa != self.swa_pool.HOLD_TOKEN_MEMINDEX: - if old_full >= 0: - self.full_to_swa_indexs[old_full] = -1 - swa = old_swa - else: - swa = int(self.swa_allocator.alloc(1)[0].item()) - - self.req_to_swa_indexs[req_idx, ring_pos] = swa - self.req_to_swa_full_indexs[req_idx, ring_pos] = full - self.full_to_swa_indexs[full] = swa - out[i] = swa - return out - - def 
prepare_decode_swa_slots( - self, - b_req_idx: torch.Tensor, - b_seq_len: torch.Tensor, - mem_index: torch.Tensor, - ) -> None: - if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None: - return - - reqs = b_req_idx.detach().cpu().tolist() - seqs = b_seq_len.detach().cpu().tolist() - fulls = mem_index.detach().cpu().tolist() - hold = self.swa_pool.HOLD_TOKEN_MEMINDEX - for req_idx, seq_len, full in zip(reqs, seqs, fulls): - req_idx = int(req_idx) - full = int(full) - if req_idx == self.max_request_num or full == self.HOLD_TOKEN_MEMINDEX: - continue - ring_pos = (int(seq_len) - 1) % int(self.sliding_window) - old_swa = int(self.req_to_swa_indexs[req_idx, ring_pos].item()) - old_full = int(self.req_to_swa_full_indexs[req_idx, ring_pos].item()) - if old_swa == hold: - old_swa = int(self.swa_allocator.alloc(1)[0].item()) - if old_full >= 0 and old_full != full: - self.full_to_swa_indexs[old_full] = -1 - self.req_to_swa_indexs[req_idx, ring_pos] = old_swa - self.req_to_swa_full_indexs[req_idx, ring_pos] = full - self.full_to_swa_indexs[full] = old_swa - self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = hold - return - - def _reserve_prefill_swa_slots( - self, - req_idx: int, - positions: torch.Tensor, - full_slots: torch.Tensor, - ) -> Dict[str, torch.Tensor]: - full_slots = full_slots.long().reshape(-1) - positions = positions.long().reshape(-1) - assert positions.numel() == full_slots.numel() - - out = torch.empty_like(full_slots, dtype=torch.long) - ring_to_swa: Dict[int, int] = {} - ring_to_old_full: Dict[int, int] = {} - ring_to_final_full: Dict[int, int] = {} - hold = self.swa_pool.HOLD_TOKEN_MEMINDEX - - for i, (pos, full) in enumerate(zip(positions.tolist(), full_slots.tolist())): - if full == self.HOLD_TOKEN_MEMINDEX: - out[i] = hold - continue - - ring_pos = int(pos) % int(self.sliding_window) - swa = ring_to_swa.get(ring_pos) - if swa is None: - old_swa = int(self.req_to_swa_indexs[req_idx, ring_pos].item()) - old_full = 
int(self.req_to_swa_full_indexs[req_idx, ring_pos].item()) - if old_swa == hold: - old_swa = int(self.swa_allocator.alloc(1)[0].item()) - swa = old_swa - ring_to_swa[ring_pos] = swa - ring_to_old_full[ring_pos] = old_full - - ring_to_final_full[ring_pos] = int(full) - out[i] = swa - - rings = sorted(ring_to_final_full) - return { - "positions": positions.detach().clone(), - "full_slots": full_slots.detach().clone(), - "swa_slots": out.detach().clone(), - "commit_rings": torch.tensor(rings, dtype=torch.long, device=full_slots.device), - "commit_full_slots": torch.tensor( - [ring_to_final_full[r] for r in rings], - dtype=torch.long, - device=full_slots.device, - ), - "commit_swa_slots": torch.tensor( - [ring_to_swa[r] for r in rings], - dtype=torch.long, - device=full_slots.device, - ), - "commit_old_full_slots": torch.tensor( - [ring_to_old_full[r] for r in rings], - dtype=torch.long, - device=full_slots.device, - ), - } - - def prepare_prefill_swa_slots( - self, - b_req_idx: torch.Tensor, - b_seq_len: torch.Tensor, - b_ready_cache_len: torch.Tensor, - b_start_loc: torch.Tensor, - mem_index: torch.Tensor, - ) -> None: - if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None: - return - - self._pending_prefill_swa = {} - req_list = b_req_idx.detach().cpu().tolist() - seq_list = b_seq_len.detach().cpu().tolist() - ready_list = b_ready_cache_len.detach().cpu().tolist() - start_list = b_start_loc.detach().cpu().tolist() - for req_idx, seq_len, ready_len, start_loc in zip(req_list, seq_list, ready_list, start_list): - token_num = int(seq_len) - int(ready_len) - if token_num <= 0: - continue - pos = torch.arange(int(ready_len), int(seq_len), dtype=torch.long, device=mem_index.device) - slots = mem_index[int(start_loc) : int(start_loc) + token_num] - self._pending_prefill_swa[int(req_idx)] = self._reserve_prefill_swa_slots(int(req_idx), pos, slots) - return - - def _get_pending_prefill_swa_slots( - self, - req_idx: int, - positions: torch.Tensor, - 
full_slots: torch.Tensor, - ) -> Optional[torch.Tensor]: - pending = self._pending_prefill_swa.get(int(req_idx)) - if pending is None: - return None - if pending["positions"].numel() != positions.numel(): - return None - if not torch.equal(pending["positions"].to(positions.device), positions.long().reshape(-1)): - return None - if not torch.equal(pending["full_slots"].to(full_slots.device), full_slots.long().reshape(-1)): - return None - return pending["swa_slots"].to(full_slots.device) - - def commit_prefill_swa_slots(self) -> None: - if not self._pending_prefill_swa: - return - for req_idx, pending in self._pending_prefill_swa.items(): - rings = pending["commit_rings"].to(self.req_to_swa_indexs.device) - if rings.numel() == 0: - continue - old_full = pending["commit_old_full_slots"].to(self.full_to_swa_indexs.device) - valid_old = old_full >= 0 - if valid_old.any(): - self.full_to_swa_indexs[old_full[valid_old].long()] = -1 - - full_slots = pending["commit_full_slots"].to(self.full_to_swa_indexs.device) - swa_slots = pending["commit_swa_slots"].to(self.full_to_swa_indexs.device) - self.req_to_swa_indexs[int(req_idx), rings] = swa_slots.to(torch.int32) - self.req_to_swa_full_indexs[int(req_idx), rings] = full_slots.to(torch.int32) - self.full_to_swa_indexs[full_slots.long()] = swa_slots.to(torch.int32) - self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX - self._pending_prefill_swa = {} - return - - def _swa_slots_from_full(self, full_slots: torch.Tensor) -> torch.Tensor: - full_slots = full_slots.long().reshape(-1) - if full_slots.numel() == 0: - return full_slots - mapped = self.full_to_swa_indexs[full_slots].long() - missing = mapped < 0 - if missing.any(): - if self.req_to_swa_indexs is not None: - bad = int(full_slots[missing][0].item()) - raise RuntimeError(f"DeepSeek-V4 dense KV for full token slot {bad} has been evicted from SWA cache") - fallback = full_slots[missing] - fallback_valid = fallback < self.swa_size - if 
fallback_valid.all(): - mapped[missing] = fallback - self.full_to_swa_indexs[fallback] = fallback.to(torch.int32) - else: - bad = int(fallback[~fallback_valid][0].item()) - raise RuntimeError(f"DeepSeek-V4 dense KV for full token slot {bad} has been evicted from SWA cache") - return mapped - - def free_swa_for_req(self, req_idx: int) -> None: - if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None: - return - req_idx = int(req_idx) - slots = self.req_to_swa_indexs[req_idx] - full_slots = self.req_to_swa_full_indexs[req_idx] - valid_swa = slots != self.swa_pool.HOLD_TOKEN_MEMINDEX - if valid_swa.any(): - free_slots = torch.unique(slots[valid_swa]).detach().cpu() - self.swa_allocator.free(free_slots) - valid_full = full_slots >= 0 - if valid_full.any(): - self.full_to_swa_indexs[full_slots[valid_full].long()] = -1 - self.req_to_swa_indexs[req_idx].fill_(self.swa_pool.HOLD_TOKEN_MEMINDEX) - self.req_to_swa_full_indexs[req_idx].fill_(-1) - self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX - - def snapshot_swa_for_prompt_cache(self, req_idx: int, cache_len: int, full_slots: torch.Tensor): - if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None or cache_len <= 0: - return None - tail_start = max(0, int(cache_len) - int(self.sliding_window)) - full_slots = full_slots[tail_start:cache_len].long().to(self.kv_buffer.device) - if full_slots.numel() == 0: - return None - swa_slots = self.full_to_swa_indexs[full_slots].long() - if (swa_slots < 0).any(): - bad = int(full_slots[swa_slots < 0][0].item()) - raise RuntimeError(f"DeepSeek-V4 prompt cache cannot snapshot evicted SWA full slot {bad}") - return { - "positions": torch.arange(tail_start, cache_len, dtype=torch.int64, device="cpu"), - "full_slots": full_slots.detach().cpu(), - "swa_slots": swa_slots.detach().cpu(), - } - - def clone_swa_for_prompt_cache(self, req_idx: int, cache_len: int, full_slots: torch.Tensor): - payload = 
self.snapshot_swa_for_prompt_cache(req_idx, cache_len, full_slots) - if payload is None: - return None - - src_slots = payload["swa_slots"].long().to(self.kv_buffer.device) - dst_slots = self.swa_allocator.alloc(src_slots.numel()).long().to(self.kv_buffer.device) - for layer_idx in range(self.layer_num): - self.swa_pool.write(layer_idx, dst_slots, self.swa_pool.read(layer_idx, src_slots)) - payload["swa_slots"] = dst_slots.detach().cpu() - return payload - - def detach_swa_for_prompt_cache(self, req_idx: int, swa_payload) -> None: - if ( - swa_payload is None - or self.req_to_swa_indexs is None - or self.req_to_swa_full_indexs is None - or len(swa_payload["positions"]) == 0 - ): - return - req_idx = int(req_idx) - positions = swa_payload["positions"].tolist() - full_slots = swa_payload["full_slots"].tolist() - swa_slots = swa_payload["swa_slots"].tolist() - for pos, full, swa in zip(positions, full_slots, swa_slots): - ring_pos = int(pos) % int(self.sliding_window) - if int(self.req_to_swa_indexs[req_idx, ring_pos].item()) == int(swa) and int( - self.req_to_swa_full_indexs[req_idx, ring_pos].item() - ) == int(full): - self.req_to_swa_indexs[req_idx, ring_pos] = self.swa_pool.HOLD_TOKEN_MEMINDEX - self.req_to_swa_full_indexs[req_idx, ring_pos] = -1 - return - - def restore_swa_from_prompt_cache(self, swa_payload) -> None: - if swa_payload is None or len(swa_payload["full_slots"]) == 0: - return - full_slots = swa_payload["full_slots"].long().to(self.kv_buffer.device) - swa_slots = swa_payload["swa_slots"].long().to(self.kv_buffer.device) - self.full_to_swa_indexs[full_slots] = swa_slots.to(torch.int32) - self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX - return - - def free_swa_prompt_cache(self, swa_payload) -> None: - if swa_payload is None or len(swa_payload["swa_slots"]) == 0: - return - swa_slots = torch.unique(swa_payload["swa_slots"].long()).detach().cpu() - self.swa_allocator.free(swa_slots) - full_slots = 
swa_payload["full_slots"].long().to(self.kv_buffer.device) - mapped = self.full_to_swa_indexs[full_slots].long() - expected = swa_payload["swa_slots"].long().to(self.kv_buffer.device) - same = mapped == expected - if same.any(): - self.full_to_swa_indexs[full_slots[same]] = -1 - self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX - return - - def _keep_last_swa_writes(self, swa_slots: torch.Tensor, packed: torch.Tensor): - """Drop duplicate SWA writes generated by long prefill ring reuse.""" - if swa_slots.numel() <= 1: - return swa_slots, packed - - slots_cpu = swa_slots.detach().cpu().tolist() - seen = set() - keep = [] - hold = self.swa_pool.HOLD_TOKEN_MEMINDEX - for i in range(len(slots_cpu) - 1, -1, -1): - slot = int(slots_cpu[i]) - if slot == hold or slot in seen: - continue - seen.add(slot) - keep.append(i) - keep.reverse() - if len(keep) == len(slots_cpu): - return swa_slots, packed - if not keep: - return swa_slots[:0], packed[:0] - keep_index = torch.tensor(keep, dtype=torch.long, device=swa_slots.device) - return swa_slots.index_select(0, keep_index), packed.index_select(0, keep_index) - - def pack_mla_kv_to_cache( - self, - layer_index: int, - mem_index: torch.Tensor, - kv: torch.Tensor, - req_idx: Optional[int] = None, - positions: Optional[torch.Tensor] = None, - ): + # ------------------------------------------------------------------ cache write paths + def pack_mla_kv_to_cache(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """标准 operator 写入路径。要求本步已对 mem_index 调过 ``alloc_swa``(prep 阶段); + HOLD/padding 槽位映射到 swa HOLD 槽,写入无害。""" if kv.shape[0] == 0: return - packed = self._pack_mla_kv(kv) - if req_idx is None or positions is None: - swa_slots = self._identity_swa_slots(mem_index).to(kv.device) - else: - pending_slots = self._get_pending_prefill_swa_slots(req_idx, positions, mem_index) - if pending_slots is None: - swa_slots = self.ensure_swa_slots(req_idx, positions, mem_index).to(kv.device) - 
else: - swa_slots = pending_slots.to(kv.device) - swa_slots, packed = self._keep_last_swa_writes(swa_slots, packed) - if swa_slots.numel() == 0: - return - self.swa_pool.write(layer_index, swa_slots, packed) - - def pack_decode_mla_kv_to_cache( - self, - layer_index: int, - b_req_idx: torch.Tensor, - b_seq_len: torch.Tensor, - mem_index: torch.Tensor, - kv: torch.Tensor, - ): - if kv.shape[0] == 0: - return - packed = self._pack_mla_kv(kv) - if self.req_to_swa_indexs is None or self.req_to_swa_full_indexs is None: - swa_slots = self._identity_swa_slots(mem_index).to(kv.device) - else: - req = b_req_idx.long() - ring = ((b_seq_len.long() - 1) % int(self.sliding_window)).long() - swa_slots = self.req_to_swa_indexs[req, ring].long() - - old_full = self.req_to_swa_full_indexs[req, ring].long() - full_slots = mem_index.long() - old_full = torch.where(old_full >= 0, old_full, full_slots) - self.full_to_swa_indexs[old_full] = torch.full( - old_full.shape, - -1, - dtype=self.full_to_swa_indexs.dtype, - device=old_full.device, - ) - - self.req_to_swa_full_indexs[req, ring] = full_slots.to(torch.int32) - self.full_to_swa_indexs[full_slots] = swa_slots.to(torch.int32) - self.swa_pool.write(layer_index, swa_slots.to(kv.device), packed) + from lightllm.models.deepseek_v4.triton_kernel.destindex_copy_kv_flashmla_dsv4 import ( + destindex_copy_kv_flashmla_dsv4, + ) - def gather_mla_kv_from_swa_slots(self, layer_index: int, swa_slots: torch.Tensor) -> torch.Tensor: - return self._unpack_mla_kv(self.swa_pool.read(layer_index, swa_slots.to(self.kv_buffer.device))) + swa_slots = self.full_to_swa_indexs[mem_index.cuda().long().reshape(-1)] + destindex_copy_kv_flashmla_dsv4( + kv.reshape(-1, self.mla_head_dim), + swa_slots, + self.swa_pool.get_layer_buffer(layer_index), + self.swa_pool.page_size, + ) + return def pack_compressed_kv_to_cache(self, layer_index: int, slots: torch.Tensor, comp: torch.Tensor): if comp.shape[0] == 0: return + from 
lightllm.models.deepseek_v4.triton_kernel.destindex_copy_kv_flashmla_dsv4 import ( + destindex_copy_kv_flashmla_dsv4, + ) + pool, local_layer = self._pool_and_local_layer(layer_index) - pool.write_kv(local_layer, slots.to(comp.device), self._pack_mla_kv(comp)) + destindex_copy_kv_flashmla_dsv4( + comp.reshape(-1, self.mla_head_dim), + slots.to(comp.device), + pool.get_layer_buffer(local_layer), + pool.page_size, + ) - def pack_c4_indexer_k_to_cache(self, layer_index: int, slots: torch.Tensor, indexer_k: torch.Tensor): + def pack_indexer_k_to_cache(self, layer_index: int, slots: torch.Tensor, indexer_k: torch.Tensor): if indexer_k.shape[0] == 0: return - pool, local_layer = self._pool_and_local_layer(layer_index) - pool.write_indexer_k(local_layer, slots.to(indexer_k.device), self._pack_indexer_k(indexer_k)) - - def gather_mla_kv(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor: - if slots.numel() == 0: - return torch.empty((0, self.mla_head_dim), dtype=self.dtype, device=self.kv_buffer.device) - swa_slots = self._swa_slots_from_full(slots).to(self.kv_buffer.device) - return self._unpack_mla_kv(self.swa_pool.read(layer_index, swa_slots)) - - def gather_compressed_kv(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor: - if slots.numel() == 0: - return torch.empty((0, self.mla_head_dim), dtype=self.dtype, device=self.kv_buffer.device) - pool, local_layer = self._pool_and_local_layer(layer_index) - return self._unpack_mla_kv(pool.read_kv(local_layer, slots.to(self.kv_buffer.device))) - - def gather_c4_indexer_k(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor: - if slots.numel() == 0: - return torch.empty( - (0, self.indexer_head_dim), - dtype=self.dtype, - device=self.kv_buffer.device, - ) - pool, local_layer = self._pool_and_local_layer(layer_index) - return self._unpack_indexer_k(pool.read_indexer_k(local_layer, slots.to(self.kv_buffer.device))) - - def _pool_and_local_layer(self, layer_index: int): - r = 
self.compress_rates[layer_index] - if r == 4: - return self.c4_pool, self.layer_to_c4_idx[layer_index] - if r == 128: - return self.c128_pool, self.layer_to_c128_idx[layer_index] - raise AssertionError(f"layer {layer_index} (rate {r}) 不是压缩层,没有压缩池") + assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 indexer-K" + from lightllm.models.deepseek_v4.triton_kernel.destindex_copy_indexer_k_dsv4 import ( + destindex_copy_indexer_k_dsv4, + ) - def get_compressed_kv_buffer(self, layer_index: int) -> torch.Tensor: - pool, local_layer = self._pool_and_local_layer(layer_index) - return pool.get_kv_buffer(local_layer) + destindex_copy_indexer_k_dsv4( + indexer_k.reshape(-1, self.indexer_head_dim), + slots.to(indexer_k.device), + self.c4_indexer_pool.get_layer_buffer(self.layer_to_c4_idx[layer_index]), + self.c4_indexer_pool.page_size, + ) - def get_compressed_indexer_k_buffer(self, layer_index: int) -> torch.Tensor: + def gather_indexer_k(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor: + """反量化 gather c4 indexer-K: slots [N](c4 槽位,HOLD 合法) -> [N, indexer_head_dim] bf16。 + indexer top-k 打分用(纯张量操作,cuda-graph 安全)。""" assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 indexer-K" - return self.c4_pool.get_index_k_buffer(self.layer_to_c4_idx[layer_index]) + pool = self.c4_indexer_pool + flat = pool.get_layer_buffer(self.layer_to_c4_idx[layer_index]).view(-1) + data_offsets, scale_offsets = pool._loc_offsets(slots.reshape(-1)) + data_range = torch.arange(pool.data_bytes_per_token, device=flat.device) + scale_range = torch.arange(pool.scale_bytes_per_token, device=flat.device) + k_fp8 = flat[data_offsets.unsqueeze(1) + data_range.unsqueeze(0)].view(torch.float8_e4m3fn) + scale = flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)].contiguous().view(torch.float32) + return (k_fp8.float() * scale).to(torch.bfloat16) + + # ------------------------------------------------------------------ fenced inherited APIs + # kv_buffer 是 page 索引的 uint8 
slab,基类按 token 索引读写的接口会静默写坏数据,显式拦截。 + def get_index_kv_buffer(self, index): + raise NotImplementedError("DeepSeek-V4 packed page-slab cache does not support token-indexed kv_buffer io") + + def load_index_kv_buffer(self, index, load_tensor_dict): + raise NotImplementedError("DeepSeek-V4 packed page-slab cache does not support token-indexed kv_buffer io") - def alloc_c4(self, need_size) -> torch.Tensor: - return self.c4_pool.alloc(need_size) + def alloc_kv_move_buffer(self, max_req_total_len): + raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented") - def alloc_c128(self, need_size) -> torch.Tensor: - return self.c128_pool.alloc(need_size) + def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: + raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented") - def free_c4(self, free_index) -> None: - self.c4_pool.free(free_index) + def write_mem_to_page_kv_move_buffer(self, *args, **kwargs): + raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented") - def free_c128(self, free_index) -> None: - self.c128_pool.free(free_index) + def read_page_kv_move_buffer_to_mem(self, *args, **kwargs): + raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented") - def free_all(self): - super().free_all() - if hasattr(self, "swa_allocator"): - self.swa_allocator.free_all() - if hasattr(self, "full_to_swa_indexs"): - self.full_to_swa_indexs.fill_(-1) - self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX - if getattr(self, "req_to_swa_indexs", None) is not None: - self.req_to_swa_indexs.fill_(self.swa_pool.HOLD_TOKEN_MEMINDEX) - self.req_to_swa_full_indexs.fill_(-1) - self._pending_prefill_swa = {} - if self.c4_pool is not None: - self.c4_pool.free_all() - if self.c128_pool is not None: - self.c128_pool.free_all() + def send_to_decode_node(self, *args, **kwargs): + raise NotImplementedError("DeepSeek-V4 
packed/composite KV transfer is not implemented") - def alloc_kv_move_buffer(self, max_req_total_len): + def receive_from_prefill_node(self, *args, **kwargs): raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented") - def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: - raise NotImplementedError("DeepSeek-V4 packed/composite paged KV transfer is not implemented") + def send_to_decode_node_p2p(self, *args, **kwargs): + raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented") + + def receive_from_prefill_node_p2p(self, *args, **kwargs): + raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented") diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 1c5a9c09d3..4db55e1555 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -14,6 +14,7 @@ EXPERT_DTYPE_TO_QUANT_TYPE = { "fp8": "deepgemm-fp8w8a8-b128", "fp4": "deepgemm-fp4fp8-b32", + "mxfp4": "marlin-mxfp4w4a16-b32", } SUPPORTED_EXPERT_DTYPES = tuple(EXPERT_DTYPE_TO_QUANT_TYPE) @@ -64,13 +65,13 @@ def _mapping_quant_method(self): logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}") # fp8 量化下,部分 MoE 模型(如 DeepSeek-V4),可以单独声明 expert 权重精度, - # 按其值给 fused_moe 选用对应的 deepgemm 量化方法。 + # 按其值给 fused_moe 选用对应的量化方法。 expert_dtype = self.expert_dtype or self.network_config_.get("expert_dtype", None) if expert_dtype is None: return - if expert_dtype == "fp4" and self.network_config_.get("model_type") == "deepseek_v4" and not is_sm100_gpu(): - logger.info("skip generic fused_moe quant mapping for DeepSeek-V4 fp4 experts on non-SM100 GPUs") - return + # DeepSeek-V4 的 fp4 发布版自带预打包 MXFP4 专家。 + if expert_dtype == "fp4" and self.network_config_.get("model_type") == "deepseek_v4": + expert_dtype = "mxfp4" target = self._get_expert_quant_type(expert_dtype) for layer_num in range(self.layer_num): if self.expert_dtype 
is not None: diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index ec1ee90fd4..677d3b7dd7 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -198,6 +198,84 @@ def _create_weight( return mm_param, mm_param_list +@QUANTMETHODS.register(["marlin-mxfp4w4a16-b32"], platform="cuda") +class MXFP4MoEQuantizationMethod(QuantizationMethod): + def __init__(self): + super().__init__() + self.block_size = 32 + self.weight_suffix = "weight" + self.weight_zero_point_suffix = None + self.weight_scale_suffix = "scale" + self.has_weight_scale = True + self.has_weight_zero_point = False + + @property + def method_name(self): + return "marlin-mxfp4w4a16-b32" + + def quantize(self, weight: torch.Tensor, output: WeightPack): + raise NotImplementedError("marlin-mxfp4w4a16-b32 only loads pre-packed MXFP4 expert weights") + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "WeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("marlin-mxfp4w4a16-b32 is only implemented for fused MoE expert weights") + + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims + assert in_dim % 2 == 0, "MXFP4 packed weight requires even input dimension" + assert in_dim % self.block_size == 0, "MXFP4 scale dimension must be divisible by block_size" + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim // 2), dtype=torch.int8, device="cpu") + weight_scale = torch.empty( + expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.float8_e8m0fnu, device="cpu" + ) + mm_param 
= WeightPack(weight=weight, weight_scale=weight_scale) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + weight_scale_out_dims=out_dims, + weight_scale_split_dim=-2, + ) + return mm_param, mm_param_list + + def finalize_moe_weight(self, moe_weight): + try: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_mxfp4_layer_for_marlin, + ) + except Exception as e: + raise RuntimeError(f"marlin-mxfp4w4a16-b32 requires vLLM MXFP4 packing utilities, error={repr(e)}") from e + + class _MXFP4Layer: + pass + + device = torch.device("cuda", moe_weight.device_id_) + layer = _MXFP4Layer() + layer.params_dtype = moe_weight.data_type_ + w13 = moe_weight.w13.weight.view(torch.uint8).to(device=device, non_blocking=True).contiguous() + w2 = moe_weight.w2.weight.view(torch.uint8).to(device=device, non_blocking=True).contiguous() + w13_scale = moe_weight.w13.weight_scale.to(device=device, non_blocking=True).contiguous() + w2_scale = moe_weight.w2.weight_scale.to(device=device, non_blocking=True).contiguous() + ( + moe_weight.w13.weight, + moe_weight.w2.weight, + moe_weight.w13.weight_scale, + moe_weight.w2.weight_scale, + _, + _, + ) = prepare_moe_mxfp4_layer_for_marlin(layer, w13, w2, w13_scale, w2_scale, None, None) + + def _deepgemm_fp8_nt(a_tuple, b_tuple, out): if HAS_DEEPGEMM: if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 7b56129c3f..24b39b71ce 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -26,14 +26,18 @@ @dataclass class DeepseekV4PromptCachePayload: + """prompt cache 载荷: 只剩 swa 按页有效性 bitmap。 + + 槽位与 compressor 状态都不进载荷: full_to_swa/full_to_c4/full_to_c128 以 full token 槽位 + 为键(radix 持有 full 槽 ⇒ 映射行存活,free 级联回收);c4 compressor 状态以 swa 页派生 + 寻址(随 swa 页生灭,命中零拷贝续算);c128 状态在 128 边界自然归零,无需恢复。 + + * ``swa_page_valid``: cpu bool [cache_len // page],插入时按当下 full_to_swa 
映射写定 + (页内 128 个映射全有效才为 True)。匹配层据此把命中裁剪到"结尾页有效"的 128 边界, + swa 压力阀回收节点页时清零。""" + cache_len: int - c4_slots: Optional[torch.Tensor] = None - c128_slots: Optional[torch.Tensor] = None - c4_state: Optional[torch.Tensor] = None - c4_state_pool: Optional[torch.Tensor] = None - c4_indexer_state: Optional[torch.Tensor] = None - c4_indexer_state_pool: Optional[torch.Tensor] = None - swa: Optional[dict] = None + swa_page_valid: Optional[torch.Tensor] = None class DeepseekV4PromptCacheValueOps: @@ -47,9 +51,29 @@ def concat(self, payloads: List[DeepseekV4PromptCachePayload]): return self.req_manager.concat_prompt_cache_payloads(payloads) def free(self, payload: DeepseekV4PromptCachePayload): - self.req_manager.free_prompt_cache_payload(payload) + # 槽位资源全部由 mem_manager.free(full_slots) 级联回收,载荷本身没有需要释放的资源。 return + def invalidate_swa_pages(self, payload: DeepseekV4PromptCachePayload) -> None: + """swa 压力阀回收了该节点的 swa 页后清 bitmap: 后续命中按缩短语义裁剪,不会复活。""" + if payload is not None and payload.swa_page_valid is not None: + payload.swa_page_valid.fill_(False) + return + + def valid_match_length(self, payload: Optional[DeepseekV4PromptCachePayload], natural_len: int) -> int: + """radix 匹配裁剪: 返回 <= natural_len 的最大 128 边界 L',使结尾页(bitmap[L'/128-1])有效。 + + 有效性可能非单调(owner 生前从左驱逐、后续阀从尾回收),按候选边界回查 bitmap; + 中段 invalid 页不挡更靠后的有效命中(注意力只回看最后一个窗口)。""" + page = self.req_manager.get_prompt_cache_page_size() + if payload is None or payload.swa_page_valid is None: + return 0 + n_pages = min(natural_len // page, int(payload.swa_page_valid.numel())) + valid_idx = torch.nonzero(payload.swa_page_valid[:n_pages]) + if valid_idx.numel() == 0: + return 0 + return (int(valid_idx[-1]) + 1) * page + class _ReqNode: def __init__(self, index): @@ -334,21 +358,23 @@ def copy_small_page_buffer_to_linear_att_state( class DeepseekV4ReqManager(ReqManager): - """DeepSeek-V4 的请求级管理(锁定决策: SWA 全历史 + 不分页)。 - - 在基类 ReqManager 之上补三类 V4 专有的 per-request 结构。该对象在 mem manager profile 前创建, - 所以初始化只依赖 config 派生出的 
compress_rates/head_dim/indexer_head_dim;真实 mem_manager - 会在 `_init_mem_manager()` 后通过 `bind_mem_manager()` 接入。 - - * ``req_to_c4_indexs`` / ``req_to_c128_indexs`` —— (req, 窗口下标) -> 压缩池槽位。 - 窗口下标 = position // compress_rate;窗口关闭时由 layer-infer 写入,attention 读取前 - n_windows 列即该 req 的全部压缩条目槽。未填充列为 0(不会被读到,语义同 req_to_token_indexs)。 - * ``req_to_c4_state`` / ``req_to_c128_state`` / ``req_to_c4_indexer_state`` —— compressor 的 - “在途窗口”累加状态(per req、per 压缩层),fp32。形状为 - ``(kv_or_score, coff * ratio, coff * dim)``; c4 因 Ca/Cb overlap 取 ``coff=2``, - c128 取 ``coff=1``。score 初始化为 ``-inf``,与官方 reference compressor 的 - ``kv_state``/``score_state`` 对齐。 - * entry_count 不另存:= position // compress_rate,可由序列长度推出。 + """DeepSeek-V4 的请求级管理。 + + 在基类 ReqManager 之上补 V4 专有的 per-request 结构。该对象在 mem manager profile 前创建, + 所以初始化只依赖 config 派生出的 compress_rates/head_dim/indexer_head_dim/sliding_window; + 真实 mem_manager 会在 `_init_mem_manager()` 后通过 `bind_mem_manager()` 接入。 + + * 压缩槽位不在本类: ``full_to_c4/c128_indexs``(mem manager)以组末 token 的 full 槽位为键。 + 本类只负责 prep 阶段的分配与 scatter(``prepare_prefill_compress_slots`` / + ``prepare_decode_compress_slots``)——必须先于 attention metadata 构建/图捕获; + 条目内容由 layer-infer 的 compressor 前向写入。 + * ``req_to_c128_state_pool`` —— c128 compressor 的在途状态(per req、per c128 层)。 + c128 在线聚合在 128 边界自然归零(命中边界必 128 对齐),无缓存常驻需求,保持 req 键控。 + c4 状态(跨边界 overlap)在 mem manager 的 swa 页派生池,随页生灭,命中零拷贝续算。 + * SWA 槽位分配/出窗回收(``prepare_prefill_swa`` / ``prepare_decode_swa``): 每步 prep 阶段 + 为新 token 调 mem_manager.alloc_swa,并按 per-req 水位线(``_swa_evict_marks``)惰性回收 + 已出窗位置的 swa 槽。水位线首次置为该请求首个 chunk 的 ready_cache_len(radix 共享前缀 + 的边界),因此共享前缀的 swa 槽永远不会被本请求回收(归 radix 经 mem_manager.free 级联释放)。 """ def __init__( @@ -359,10 +385,24 @@ def __init__( compress_rates: Optional[List[int]] = None, head_dim: Optional[int] = None, indexer_head_dim: Optional[int] = None, + sliding_window: Optional[int] = None, ): super().__init__(max_request_num, max_sequence_length, mem_manager) self.mem_manager = mem_manager + if 
mem_manager is not None: + if compress_rates is None: + compress_rates = mem_manager.compress_rates + if head_dim is None: + head_dim = mem_manager.head_dim + if indexer_head_dim is None: + indexer_head_dim = mem_manager.indexer_head_dim + if sliding_window is None: + sliding_window = mem_manager.sliding_window + self.sliding_window = sliding_window + # 出窗回收水位线: -1 表示该 req 尚未见过 prefill chunk(首个 chunk 的 ready_cache_len + # 即共享前缀边界,作为永不下探的回收下界)。 + self._swa_evict_marks = [-1 for _ in range(max_request_num + 1)] self.compress_rates = list(compress_rates) self.n_c4 = sum(1 for r in self.compress_rates if r == 4) self.n_c128 = sum(1 for r in self.compress_rates if r == 128) @@ -379,60 +419,16 @@ def __init__( self.layer_to_c128_idx[lid] = c128 c128 += 1 - # (req, 窗口) -> 压缩槽。列数取 ceil(max_seq / ratio) 留足余量。 - c4_windows = (max_sequence_length + 4 - 1) // 4 - c128_windows = (max_sequence_length + 128 - 1) // 128 - self.req_to_c4_indexs = torch.zeros((max_request_num + 1, c4_windows), dtype=torch.int32, device="cuda") - self.req_to_c128_indexs = torch.zeros((max_request_num + 1, c128_windows), dtype=torch.int32, device="cuda") - self._c4_entry_counts = [0 for _ in range(max_request_num + 1)] - self._c128_entry_counts = [0 for _ in range(max_request_num + 1)] - - # compressor 在途窗口累加状态(fp32): [kv_or_score, coff * ratio, coff * dim]. 
- state_dtype = torch.float32 - self.req_to_c4_state = LayerCache( - size=max_request_num + 1, - dtype=state_dtype, - shape=(2, 8, 2 * head_dim), - layer_num=self.n_c4, - device="cuda", - ) - self.req_to_c128_state = LayerCache( - size=max_request_num + 1, - dtype=state_dtype, - shape=(2, 128, head_dim), - layer_num=self.n_c128, - device="cuda", - ) - self.req_to_c4_indexer_state = LayerCache( - size=max_request_num + 1, - dtype=state_dtype, - shape=(2, 8, 2 * indexer_head_dim), - layer_num=self.n_c4, - device="cuda", - ) - self.req_to_c4_state_pool = LayerCache( - size=max_request_num + 1, - dtype=state_dtype, - shape=(1, 8, 4 * head_dim), - layer_num=self.n_c4, - device="cuda", - ) + # c128 compressor 在途状态(fp32): 在线聚合在 128 边界自然归零(命中边界必 128 对齐), + # 无缓存常驻需求,保持 req 键控。c4 状态(有跨边界 overlap)在 mem manager 的 + # swa 页派生池(c4_state_buffer / c4_indexer_state_buffer)。 self.req_to_c128_state_pool = LayerCache( size=max_request_num + 1, - dtype=state_dtype, + dtype=torch.float32, shape=(1, 128, 2 * head_dim), layer_num=self.n_c128, device="cuda", ) - self.req_to_c4_indexer_state_pool = LayerCache( - size=max_request_num + 1, - dtype=state_dtype, - shape=(1, 8, 4 * indexer_head_dim), - layer_num=self.n_c4, - device="cuda", - ) - self._runtime_states = [{} for _ in range(max_request_num + 1)] - self._init_all_score_state() return def bind_mem_manager(self, mem_manager: DeepseekV4MemoryManager): @@ -440,22 +436,92 @@ def bind_mem_manager(self, mem_manager: DeepseekV4MemoryManager): assert self.compress_rates == mem_manager.compress_rates assert self.head_dim == mem_manager.head_dim assert self.indexer_head_dim == mem_manager.indexer_head_dim + if self.sliding_window is None: + self.sliding_window = mem_manager.sliding_window + else: + assert mem_manager.sliding_window is None or self.sliding_window == mem_manager.sliding_window self.mem_manager = mem_manager return - def _init_all_score_state(self): - if self.n_c4 > 0: - self.req_to_c4_state.buffer[:, :, 1, 
...].fill_(float("-inf")) - self.req_to_c4_indexer_state.buffer[:, :, 1, ...].fill_(float("-inf")) - if self.n_c128 > 0: - self.req_to_c128_state.buffer[:, :, 1, ...].fill_(float("-inf")) + # ------------------------------------------------------------------ swa slot prep (per step) + def _swa_retain_len(self) -> int: + """出窗回收的保留长度 = window + 一个 radix 页。 + + 多留一页使「最近一个完成的 128 边界」的结尾页恒驻留: prompt cache 只能在 floor(cur/128) + 边界入树(radix page=128),若回收只留 window,则任何非对齐时刻该边界的结尾页都已被 + 部分回收,插入门会把所有插入裁到 0(prompt cache 形同虚设)。预算即 v5 §2 的每请求 + 「活跃窗口跨页 ≤2」。驻留证明要求 window >= page-1(DSV4 实际 window == page == 128)。""" + return int(self.sliding_window) + self.get_prompt_cache_page_size() + + def prepare_prefill_swa( + self, + b_req_idx: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_seq_len: torch.Tensor, + ) -> None: + """prefill prep: 为本 chunk 全部新 token(位置 [ready, seq))分配位置对齐的 swa 槽, + 并回收已出窗位置的槽。 + + 本 chunk 起点 L = ready_cache_len,首个新 token(位置 L)的窗口是 [L-W+1, L];回收 + 边界再额外保留一个 radix 页(_swa_retain_len),即位置 < L-retain+1。先回收再分配。 + 必须在 init_req_to_token_indexes 之后调用(位置对齐分配经 req_to_token 行派生/scatter)。""" + assert self.mem_manager is not None + if self.sliding_window is not None: + retain = self._swa_retain_len() + evict_slots = [] + req_list = b_req_idx.detach().cpu().tolist() + ready_list = b_ready_cache_len.detach().cpu().tolist() + for req_idx, ready_len in zip(req_list, ready_list): + req_idx = int(req_idx) + if req_idx == self.HOLD_REQUEST_ID: + continue + ready_len = int(ready_len) + mark = self._swa_evict_marks[req_idx] + if mark < 0: + # 首个 chunk: [0, ready_len) 是 radix 共享前缀,其 swa 槽归 radix 所有,不可回收。 + self._swa_evict_marks[req_idx] = ready_len + continue + evict_end = ready_len - retain + 1 + if evict_end > mark: + evict_slots.append(self.req_to_token_indexs[req_idx, mark:evict_end]) + self._swa_evict_marks[req_idx] = evict_end + if evict_slots: + self.mem_manager.evict_swa(torch.cat(evict_slots)) + self.mem_manager.alloc_swa_prefill(b_req_idx, b_ready_cache_len, b_seq_len, 
self.req_to_token_indexs) return - def _reset_compress_cache_req(self, cache: LayerCache, req_idx: int): - if cache.layer_num == 0: - return - cache.buffer[:, req_idx, 0, ...].fill_(0) - cache.buffer[:, req_idx, 1, ...].fill_(float("-inf")) + def prepare_decode_swa( + self, + b_req_idx: torch.Tensor, + b_seq_len: torch.Tensor, + mem_indexes: torch.Tensor, + ) -> None: + """decode prep: 回收出窗槽并为本步新 token 分配位置对齐的 swa 槽。当前 query 位置 + seq_len-1 的窗口是 [seq_len-W, seq_len-1];回收边界额外保留一个 radix 页 + (_swa_retain_len),即位置 < seq_len-retain。先回收再分配。""" + assert self.mem_manager is not None + if self.sliding_window is not None: + retain = self._swa_retain_len() + evict_slots = [] + req_list = b_req_idx.detach().cpu().tolist() + seq_list = b_seq_len.detach().cpu().tolist() + for req_idx, seq_len in zip(req_list, seq_list): + req_idx = int(req_idx) + if req_idx == self.HOLD_REQUEST_ID: + continue + seq_len = int(seq_len) + mark = self._swa_evict_marks[req_idx] + if mark < 0: + # 未经过 prefill prep 的保守路径: 不回收旧位置,仅推进水位线。 + self._swa_evict_marks[req_idx] = max(0, seq_len - retain) + continue + evict_end = seq_len - retain + if evict_end > mark: + evict_slots.append(self.req_to_token_indexs[req_idx, mark:evict_end]) + self._swa_evict_marks[req_idx] = evict_end + if evict_slots: + self.mem_manager.evict_swa(torch.cat(evict_slots)) + self.mem_manager.alloc_swa_decode(b_req_idx, b_seq_len, mem_indexes, self.req_to_token_indexs) return def _reset_state_pool_req(self, cache: LayerCache, req_idx: int): @@ -465,105 +531,90 @@ def _reset_state_pool_req(self, cache: LayerCache, req_idx: int): return def init_compress_state(self, req_idx: int): - """新请求开始时重置其 compressor 在途状态(对应 mamba 的 init_linear_att_state)。""" + """新请求开始时重置其 compressor 在途状态(对应 mamba 的 init_linear_att_state)。 + + 只有 c128 状态是 req 键控的(c4 状态随 swa 页生灭,内核组起点覆写,无需重置; + 压缩槽位以 full 槽位为键,随请求 full 槽的释放级联回收)。""" self.clear_runtime_state(req_idx) - c4, c128 = self.pop_compress_indices_for_req(req_idx) - 
self.free_compress_indices(free_c4_index=c4, free_c128_index=c128) - if self.n_c4 > 0: - self._reset_compress_cache_req(self.req_to_c4_state, req_idx) - self._reset_compress_cache_req(self.req_to_c4_indexer_state, req_idx) - self._reset_state_pool_req(self.req_to_c4_state_pool, req_idx) - self._reset_state_pool_req(self.req_to_c4_indexer_state_pool, req_idx) if self.n_c128 > 0: - self._reset_compress_cache_req(self.req_to_c128_state, req_idx) self._reset_state_pool_req(self.req_to_c128_state_pool, req_idx) return - def _ensure_compress_slots(self, req_idx: int, ratio: int, entry_start: int, entry_count: int) -> torch.Tensor: - if entry_count == 0: - return torch.empty((0,), dtype=torch.int32, device="cuda") - assert entry_start >= 0 and entry_count >= 0 + # ------------------------------------------------------------------ compress slot prep (per step) + def _compress_mapping_alloc(self, ratio: int): assert self.mem_manager is not None, "DeepSeek-V4 mem manager is not bound yet" if ratio == 4: - table = self.req_to_c4_indexs - counts = self._c4_entry_counts - alloc = self.mem_manager.alloc_c4 - elif ratio == 128: - table = self.req_to_c128_indexs - counts = self._c128_entry_counts - alloc = self.mem_manager.alloc_c128 - else: - raise AssertionError(f"invalid DeepSeek-V4 compress ratio {ratio}") - - required_count = entry_start + entry_count - assert required_count <= table.shape[1], ( - f"DeepSeek-V4 compressed slot table overflow: req={req_idx} " - f"ratio={ratio} required={required_count} capacity={table.shape[1]}" - ) - old_count = counts[req_idx] - if required_count > old_count: - new_slots_cpu = alloc(required_count - old_count) - table[req_idx, old_count:required_count] = new_slots_cpu.cuda(non_blocking=True) - counts[req_idx] = required_count - return table[req_idx, entry_start:required_count] - - def ensure_c4_slots(self, req_idx: int, entry_start: int, entry_count: int) -> torch.Tensor: - return self._ensure_compress_slots(req_idx, 4, entry_start, 
entry_count) - - def ensure_c128_slots(self, req_idx: int, entry_start: int, entry_count: int) -> torch.Tensor: - return self._ensure_compress_slots(req_idx, 128, entry_start, entry_count) - - def ensure_compress_slots(self, layer_index: int, req_idx: int, entry_start: int, entry_count: int) -> torch.Tensor: - ratio = self.compress_rates[layer_index] - if ratio == 4: - return self.ensure_c4_slots(req_idx, entry_start, entry_count) + return self.mem_manager.full_to_c4_indexs, self.mem_manager.alloc_c4 if ratio == 128: - return self.ensure_c128_slots(req_idx, entry_start, entry_count) - raise AssertionError(f"layer {layer_index} is not a compressed attention layer") + return self.mem_manager.full_to_c128_indexs, self.mem_manager.alloc_c128 + raise AssertionError(f"invalid DeepSeek-V4 compress ratio {ratio}") - def prepare_decode_compress_slots(self, b_req_idx: torch.Tensor, b_seq_len: torch.Tensor) -> None: + def _scatter_compress_slots(self, ratio: int, full_slots: torch.Tensor) -> None: + """为组末 full 槽位分配压缩槽并写入映射。已映射(>=0)的行跳过——重复 prep 幂等。""" + if full_slots.numel() == 0: + return + mapping, alloc = self._compress_mapping_alloc(ratio) + full_slots = full_slots.cuda().long().reshape(-1) + # 去重: 同批重复键会让后写覆盖先写,先分配的压缩槽成为孤儿(allocator 泄漏)。 + need = torch.unique(full_slots[mapping[full_slots] < 0]) + if need.numel() == 0: + return + new_slots = alloc(need.numel()).cuda(non_blocking=True).to(torch.int32) + mapping[need] = new_slots + return + + def prepare_prefill_compress_slots( + self, + b_req_idx: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_seq_len: torch.Tensor, + ) -> None: + """prefill prep: 为本 chunk 内的组末 token(位置 (g+1)*ratio-1 ∈ [ready, seq))分配压缩槽, + scatter 进 full_to_c4/c128_indexs。必须在 init_req_to_token_indexes 之后(组末 full 槽 + 从 req_to_token_indexs 取)、attention metadata 构建之前调用。""" + if self.n_c4 == 0 and self.n_c128 == 0: + return req_list = b_req_idx.detach().cpu().tolist() + ready_list = b_ready_cache_len.detach().cpu().tolist() seq_list = 
b_seq_len.detach().cpu().tolist() - for req_idx, seq_len in zip(req_list, seq_list): - req_idx = int(req_idx) - if req_idx == self.HOLD_REQUEST_ID: + for ratio, n_layers in ((4, self.n_c4), (128, self.n_c128)): + if n_layers == 0: continue - seq_len = int(seq_len) - if self.n_c4 > 0: - required_c4 = seq_len // 4 - old_c4 = self._c4_entry_counts[req_idx] - if required_c4 > old_c4: - self.ensure_c4_slots(req_idx, old_c4, required_c4 - old_c4) - if self.n_c128 > 0: - required_c128 = seq_len // 128 - old_c128 = self._c128_entry_counts[req_idx] - if required_c128 > old_c128: - self.ensure_c128_slots(req_idx, old_c128, required_c128 - old_c128) + end_slots = [] + for req_idx, ready_len, seq_len in zip(req_list, ready_list, seq_list): + req_idx = int(req_idx) + if req_idx == self.HOLD_REQUEST_ID: + continue + first, last = int(ready_len) // ratio, int(seq_len) // ratio + if last > first: + ends = self.req_to_token_indexs[req_idx, ratio - 1 : last * ratio : ratio] + end_slots.append(ends[first:]) + if end_slots: + self._scatter_compress_slots(ratio, torch.cat(end_slots)) return - def pop_compress_indices_for_req(self, req_idx: int): - c4_count = self._c4_entry_counts[req_idx] - if c4_count > 0: - c4 = self.req_to_c4_indexs[req_idx, :c4_count].clone() - self.req_to_c4_indexs[req_idx, :c4_count].fill_(0) - self._c4_entry_counts[req_idx] = 0 - else: - c4 = None - - c128_count = self._c128_entry_counts[req_idx] - if c128_count > 0: - c128 = self.req_to_c128_indexs[req_idx, :c128_count].clone() - self.req_to_c128_indexs[req_idx, :c128_count].fill_(0) - self._c128_entry_counts[req_idx] = 0 - else: - c128 = None - return c4, c128 - - def free_compress_indices(self, free_c4_index=None, free_c128_index=None): - if free_c4_index is not None and len(free_c4_index) > 0: - self.mem_manager.free_c4(free_c4_index) - if free_c128_index is not None and len(free_c128_index) > 0: - self.mem_manager.free_c128(free_c128_index) + def prepare_decode_compress_slots( + self, + b_req_idx: 
torch.Tensor, + b_seq_len: torch.Tensor, + mem_indexes: torch.Tensor, + ) -> None: + """decode prep: 本步 token 关闭一个组(seq_len % ratio == 0)时为其分配压缩槽并 scatter。 + 组末 full 槽即本步的 mem_index(此刻 req_to_token_indexs 尚未写入本步槽位)。""" + if self.n_c4 == 0 and self.n_c128 == 0: + return + req_list = b_req_idx.detach().cpu().tolist() + seq_list = b_seq_len.detach().cpu().tolist() + for ratio, n_layers in ((4, self.n_c4), (128, self.n_c128)): + if n_layers == 0: + continue + rows = [ + i + for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)) + if int(req_idx) != self.HOLD_REQUEST_ID and int(seq_len) > 0 and int(seq_len) % ratio == 0 + ] + if rows: + self._scatter_compress_slots(ratio, mem_indexes.reshape(-1)[rows]) return def alloc(self): @@ -573,53 +624,17 @@ def alloc(self): return req_idx def clear_runtime_state(self, req_idx: int): - self._runtime_states[req_idx].clear() - if self.mem_manager is not None and hasattr(self.mem_manager, "free_swa_for_req"): - self.mem_manager.free_swa_for_req(req_idx) - return - - def set_runtime_state(self, req_idx: int, layer_index: int, state: dict): - self._runtime_states[req_idx][layer_index] = state + # swa 槽位本身由 mem_manager.free 级联回收(随 full 槽位),这里只复位出窗水位线。 + self._swa_evict_marks[req_idx] = -1 return - def get_runtime_state(self, req_idx: int, layer_index: int): - return self._runtime_states[req_idx][layer_index] - - def get_compress_state_for_req(self, layer_index: int, req_idx: int): - if self.compress_rates[layer_index] == 4: - state = self.get_c4_compress_state(layer_index) - elif self.compress_rates[layer_index] == 128: - state = self.get_c128_compress_state(layer_index) - else: - raise AssertionError(f"layer {layer_index} is not a compressed attention layer") - return state[req_idx, 0], state[req_idx, 1] - def get_compress_state_pool_for_req(self, layer_index: int, req_idx: int): - if self.compress_rates[layer_index] == 4: - cache = self.req_to_c4_state_pool - local = self.layer_to_c4_idx[layer_index] - elif 
self.compress_rates[layer_index] == 128: - cache = self.req_to_c128_state_pool - local = self.layer_to_c128_idx[layer_index] - else: - raise AssertionError(f"layer {layer_index} is not a compressed attention layer") - return cache.buffer[local, req_idx] - - def get_c4_compress_state(self, layer_index: int) -> torch.Tensor: - local = self.layer_to_c4_idx[layer_index] - return self.req_to_c4_state.buffer[local] - - def get_c128_compress_state(self, layer_index: int) -> torch.Tensor: - local = self.layer_to_c128_idx[layer_index] - return self.req_to_c128_state.buffer[local] + assert self.compress_rates[layer_index] == 128, "c4 state 在 mem manager 的 swa 页派生池" + return self.req_to_c128_state_pool.buffer[self.layer_to_c128_idx[layer_index], req_idx] - def get_c4_indexer_compress_state(self, layer_index: int) -> torch.Tensor: - local = self.layer_to_c4_idx[layer_index] - return self.req_to_c4_indexer_state.buffer[local] - - def get_c4_indexer_state_pool_for_req(self, layer_index: int, req_idx: int) -> torch.Tensor: - local = self.layer_to_c4_idx[layer_index] - return self.req_to_c4_indexer_state_pool.buffer[local, req_idx] + def get_compress_state_pool(self, layer_index: int): + assert self.compress_rates[layer_index] == 128, "c4 state 在 mem manager 的 swa 页派生池" + return self.req_to_c128_state_pool.buffer[self.layer_to_c128_idx[layer_index]] def get_prompt_cache_value_ops(self): return DeepseekV4PromptCacheValueOps(self) @@ -627,262 +642,76 @@ def get_prompt_cache_value_ops(self): def get_prompt_cache_page_size(self): return 128 - def _slice_cpu_slots(self, slots: Optional[torch.Tensor], start: int, end: int, ratio: int): - if slots is None: - return None - return slots[start // ratio : end // ratio].clone() - - def _slice_swa_payload(self, swa_payload, start: int, end: int): - if swa_payload is None: - return None - positions = swa_payload["positions"] - mask = (positions >= start) & (positions < end) - if not bool(mask.any()): - return None - return { - "positions": 
positions[mask].clone(), - "full_slots": swa_payload["full_slots"][mask].clone(), - "swa_slots": swa_payload["swa_slots"][mask].clone(), - } + def compute_swa_page_valid(self, full_slots: torch.Tensor) -> torch.Tensor: + """按当下 full_to_swa 映射给出按页有效性: full_slots [L](L 为 page 整数倍) -> + cpu bool [L/page],页内全部映射有效才为 True。GPU gather + 同步,测试/校验用; + 插入热路径用 swa_page_valid_from_watermark(纯 CPU,免同步)。""" + page = self.get_prompt_cache_page_size() + assert full_slots.numel() % page == 0 + if full_slots.numel() == 0: + return torch.zeros((0,), dtype=torch.bool) + swa = self.mem_manager.full_to_swa_indexs[full_slots.cuda().long().reshape(-1)] + return (swa.view(-1, page) >= 0).all(dim=1).cpu() + + def swa_page_valid_from_watermark(self, req_idx: int, cache_len: int) -> torch.Tensor: + """插入时的按页有效性,纯 CPU: 请求自有 token 的 swa 映射只被出窗水位线回收 + (阀不触活跃请求,级联只在 free 时),页 p 全驻留 ⟺ 页起点 128p >= 水位线。 + + 与 compute_swa_page_valid 在插入时刻对自有 token 等价,但不做 GPU gather/同步—— + router 关键路径上每次插入省一次对全部在途 kernel 的等待。bitmap 中借入前缀 + ([0, ready) 的页)的行在 radix insert 切片时被丢弃(既有节点保留自己的 bitmap), + 其取值无影响。""" + page = self.get_prompt_cache_page_size() + mark = max(0, self._swa_evict_marks[req_idx]) + n_pages = int(cache_len) // page + return torch.arange(n_pages, dtype=torch.long) * page >= mark def slice_prompt_cache_payload(self, payload: DeepseekV4PromptCachePayload, start: int, end: int): start = int(start) end = int(end) - # c4/c128/indexer-K slots are true historical KV and can be sliced by ratio. - # compressor running state only describes the payload end boundary; it is valid - # for a slice only when that slice keeps the original end boundary. 
- keep_end_state = end == payload.cache_len + page = self.get_prompt_cache_page_size() + # radix page=128 保证分裂点页对齐,bitmap 可整页切分。 return DeepseekV4PromptCachePayload( cache_len=end - start, - c4_slots=self._slice_cpu_slots(payload.c4_slots, start, end, 4), - c128_slots=self._slice_cpu_slots(payload.c128_slots, start, end, 128), - c4_state=payload.c4_state.clone() if keep_end_state and payload.c4_state is not None else None, - c4_state_pool=payload.c4_state_pool.clone() - if keep_end_state and payload.c4_state_pool is not None - else None, - c4_indexer_state=payload.c4_indexer_state.clone() - if keep_end_state and payload.c4_indexer_state is not None - else None, - c4_indexer_state_pool=payload.c4_indexer_state_pool.clone() - if keep_end_state and payload.c4_indexer_state_pool is not None + swa_page_valid=payload.swa_page_valid[start // page : end // page].clone() + if payload.swa_page_valid is not None else None, - swa=self._slice_swa_payload(payload.swa, start, end), ) def concat_prompt_cache_payloads(self, payloads: List[DeepseekV4PromptCachePayload]): if len(payloads) == 0: return None - c4_slots = [p.c4_slots for p in payloads if p.c4_slots is not None and len(p.c4_slots) > 0] - c128_slots = [p.c128_slots for p in payloads if p.c128_slots is not None and len(p.c128_slots) > 0] - last = payloads[-1] + bitmaps = [p.swa_page_valid for p in payloads] return DeepseekV4PromptCachePayload( cache_len=sum(p.cache_len for p in payloads), - c4_slots=torch.cat(c4_slots, dim=0) if c4_slots else None, - c128_slots=torch.cat(c128_slots, dim=0) if c128_slots else None, - c4_state=last.c4_state, - c4_state_pool=last.c4_state_pool, - c4_indexer_state=last.c4_indexer_state, - c4_indexer_state_pool=last.c4_indexer_state_pool, - swa=last.swa, + swa_page_valid=torch.cat(bitmaps, dim=0) if all(b is not None for b in bitmaps) else None, ) def build_prompt_cache_payload( self, req_idx: int, cache_len: int, - clone_swa: bool = False, ) -> DeepseekV4PromptCachePayload: + 
"""构造插入载荷。compressor 状态不进载荷(c4 随 swa 页生灭、c128 边界自然归零), + cache_len 不再受序列末端约束——任意 128 对齐前缀皆可插入。 + swa_page_valid 不在此填: 它必须用插入时刻的映射(infer batch 在 insert 前补)。""" assert self.mem_manager is not None - cache_len = int(cache_len) - full_slots = self.req_to_token_indexs[req_idx, :cache_len].detach().cpu() - c4_count = cache_len // 4 - c128_count = cache_len // 128 - c4_slots = self.req_to_c4_indexs[req_idx, :c4_count].detach().cpu().clone() if c4_count > 0 else None - c128_slots = self.req_to_c128_indexs[req_idx, :c128_count].detach().cpu().clone() if c128_count > 0 else None - if clone_swa: - swa_payload = self.mem_manager.clone_swa_for_prompt_cache(req_idx, cache_len, full_slots) - else: - swa_payload = self.mem_manager.snapshot_swa_for_prompt_cache(req_idx, cache_len, full_slots) - return DeepseekV4PromptCachePayload( - cache_len=cache_len, - c4_slots=c4_slots, - c128_slots=c128_slots, - c4_state=self.req_to_c4_state.buffer[:, req_idx].detach().clone() if self.n_c4 > 0 else None, - c4_state_pool=self.req_to_c4_state_pool.buffer[:, req_idx].detach().clone() if self.n_c4 > 0 else None, - c4_indexer_state=self.req_to_c4_indexer_state.buffer[:, req_idx].detach().clone() - if self.n_c4 > 0 - else None, - c4_indexer_state_pool=self.req_to_c4_indexer_state_pool.buffer[:, req_idx].detach().clone() - if self.n_c4 > 0 - else None, - swa=swa_payload, - ) - - def detach_prompt_cache_payload_from_req(self, req_idx: int, payload: DeepseekV4PromptCachePayload): - if payload is not None and self.mem_manager is not None: - self.mem_manager.detach_swa_for_prompt_cache(req_idx, payload.swa) - return - - def free_prompt_cache_payload(self, payload: DeepseekV4PromptCachePayload): - if payload is None or self.mem_manager is None: - return - if payload.c4_slots is not None and len(payload.c4_slots) > 0: - self.mem_manager.free_c4(payload.c4_slots) - if payload.c128_slots is not None and len(payload.c128_slots) > 0: - self.mem_manager.free_c128(payload.c128_slots) - 
self.mem_manager.free_swa_prompt_cache(payload.swa) - return - - def release_prompt_cache_detached_swa( - self, - payload: DeepseekV4PromptCachePayload, - keep_payload: Optional[DeepseekV4PromptCachePayload] = None, - ): - if payload is None or payload.swa is None or self.mem_manager is None: - return - old_swa = payload.swa - if keep_payload is None or keep_payload.swa is None: - self.mem_manager.free_swa_prompt_cache(old_swa) - return - - old_slots = old_swa["swa_slots"].long() - keep_slots = keep_payload.swa["swa_slots"].long() - if old_slots.numel() == 0: - return - if keep_slots.numel() == 0: - self.mem_manager.free_swa_prompt_cache(old_swa) - return - - release_mask = ~torch.isin(old_slots, keep_slots) - if not release_mask.any(): - return - release_payload = { - "full_slots": old_swa["full_slots"][release_mask].clone(), - "swa_slots": old_swa["swa_slots"][release_mask].clone(), - } - self.mem_manager.free_swa_prompt_cache(release_payload) - return - - def _reset_c128_for_prompt_cache(self, req_idx: int): - if self.n_c128 > 0: - self._reset_compress_cache_req(self.req_to_c128_state, req_idx) - self._reset_state_pool_req(self.req_to_c128_state_pool, req_idx) - return - - def rebuild_runtime_state_for_req(self, req_idx: int): - state_map = self._runtime_states[req_idx] - state_map.clear() - for layer_index, ratio in enumerate(self.compress_rates): - if ratio == 4: - cstate_kv, cstate_score = self.get_compress_state_for_req(layer_index, req_idx) - idx_state = self.get_c4_indexer_compress_state(layer_index) - state_map[layer_index] = { - "cstate_kv": cstate_kv, - "cstate_score": cstate_score, - "idx_cstate_kv": idx_state[req_idx, 0], - "idx_cstate_score": idx_state[req_idx, 1], - } - elif ratio == 128: - cstate_kv, cstate_score = self.get_compress_state_for_req(layer_index, req_idx) - state_map[layer_index] = { - "cstate_kv": cstate_kv, - "cstate_score": cstate_score, - } - return - - def restore_prompt_cache_payload(self, req_idx: int, payload: 
DeepseekV4PromptCachePayload): - assert self.mem_manager is not None - cache_len = int(payload.cache_len) - c4_count = cache_len // 4 - c128_count = cache_len // 128 - if c4_count > 0: - assert payload.c4_slots is not None and len(payload.c4_slots) == c4_count - self.req_to_c4_indexs[req_idx, :c4_count] = payload.c4_slots.cuda(non_blocking=True) - if c128_count > 0: - assert payload.c128_slots is not None and len(payload.c128_slots) == c128_count - self.req_to_c128_indexs[req_idx, :c128_count] = payload.c128_slots.cuda(non_blocking=True) - self._c4_entry_counts[req_idx] = c4_count - self._c128_entry_counts[req_idx] = c128_count - - if self.n_c4 > 0: - if payload.c4_state is None or payload.c4_indexer_state is None: - raise RuntimeError("DeepSeek-V4 prompt cache hit is missing c4 running state") - self.req_to_c4_state.buffer[:, req_idx].copy_(payload.c4_state) - self.req_to_c4_indexer_state.buffer[:, req_idx].copy_(payload.c4_indexer_state) - if payload.c4_state_pool is not None: - self.req_to_c4_state_pool.buffer[:, req_idx].copy_(payload.c4_state_pool) - if payload.c4_indexer_state_pool is not None: - self.req_to_c4_indexer_state_pool.buffer[:, req_idx].copy_(payload.c4_indexer_state_pool) - self._reset_c128_for_prompt_cache(req_idx) - self.mem_manager.restore_swa_from_prompt_cache(payload.swa) - self.rebuild_runtime_state_for_req(req_idx) - return + return DeepseekV4PromptCachePayload(cache_len=int(cache_len)) - def pop_prompt_cache_free_compress_indices( - self, - req_idx: int, - keep_len: int, - duplicate_start_len: Optional[int] = None, - duplicate_end_len: Optional[int] = None, - ): - def collect(table, cur_count, ratio): - ranges = [] - if duplicate_start_len is not None and duplicate_end_len is not None: - dup_start = duplicate_start_len // ratio - dup_end = duplicate_end_len // ratio - if dup_end > dup_start: - ranges.append((dup_start, dup_end)) - keep_count = keep_len // ratio - if cur_count > keep_count: - ranges.append((keep_count, cur_count)) - parts 
= [table[req_idx, s:e].clone() for s, e in ranges if e > s] - return torch.cat(parts, dim=0) if parts else None - - c4 = collect(self.req_to_c4_indexs, self._c4_entry_counts[req_idx], 4) - c128 = collect(self.req_to_c128_indexs, self._c128_entry_counts[req_idx], 128) - if self._c4_entry_counts[req_idx] > 0: - self.req_to_c4_indexs[req_idx, : self._c4_entry_counts[req_idx]].fill_(0) - if self._c128_entry_counts[req_idx] > 0: - self.req_to_c128_indexs[req_idx, : self._c128_entry_counts[req_idx]].fill_(0) - self._c4_entry_counts[req_idx] = 0 - self._c128_entry_counts[req_idx] = 0 - return c4, c128 - - def free( - self, - free_req_indexes, - free_token_index, - free_c4_index=None, - free_c128_index=None, - ): - """释放 dense 槽(基类)+ 压缩槽。压缩槽由调用方(infer batch)从 req_to_c*_indexs 收集后传入, - 与基类用 free_token_index 传 dense 槽的方式一致。""" + def free(self, free_req_indexes, free_token_index): + """dense/swa/压缩槽全部经 mem_manager.free(free_token_index) 级联回收。""" for req_index in free_req_indexes: self.clear_runtime_state(req_index) super().free(free_req_indexes, free_token_index) - self.free_compress_indices(free_c4_index=free_c4_index, free_c128_index=free_c128_index) return def free_req(self, free_req_index: int): self.clear_runtime_state(free_req_index) - c4, c128 = self.pop_compress_indices_for_req(free_req_index) - self.free_compress_indices(free_c4_index=c4, free_c128_index=c128) return super().free_req(free_req_index) def free_all(self): super().free_all() - self._runtime_states = [{} for _ in range(self.max_request_num + 1)] - self._c4_entry_counts = [0 for _ in range(self.max_request_num + 1)] - self._c128_entry_counts = [0 for _ in range(self.max_request_num + 1)] - if self.n_c4 > 0: - self.req_to_c4_indexs.fill_(0) - self.req_to_c4_state.buffer.fill_(0) - self.req_to_c4_indexer_state.buffer.fill_(0) - self.req_to_c4_state_pool.buffer.fill_(0) - self.req_to_c4_indexer_state_pool.buffer.fill_(0) + self._swa_evict_marks = [-1 for _ in range(self.max_request_num + 1)] if self.n_c128 > 
0: - self.req_to_c128_indexs.fill_(0) - self.req_to_c128_state.buffer.fill_(0) self.req_to_c128_state_pool.buffer.fill_(0) - self._init_all_score_state() return diff --git a/lightllm/models/deepseek_v4/infer_struct.py b/lightllm/models/deepseek_v4/infer_struct.py index d0c2745161..39a6889d72 100644 --- a/lightllm/models/deepseek_v4/infer_struct.py +++ b/lightllm/models/deepseek_v4/infer_struct.py @@ -9,8 +9,8 @@ class DeepseekV4InferStateInfo(InferStateInfo): mem_manager: DeepseekV4MemoryManager """Per-token interleaved-rope cos/sin for the two rope variants (sliding / compressed), following - the gemma4 two-variant convention (_cos_cached_* -> position_cos_*). Also exposes the full compressed - cos/sin tables, which the KV compressor indexes at window positions (not per-token).""" + the gemma4 two-variant convention (_cos_cached_* -> position_cos_*). The full rope tables are + model constants and live on the model / layer infers, not here.""" def __init__(self): super().__init__() @@ -18,8 +18,6 @@ def __init__(self): self.position_sin_sliding = None self.position_cos_compress = None self.position_sin_compress = None - self.cos_compress_table = None - self.sin_compress_table = None def init_some_extra_state(self, model): super().init_some_extra_state(model) # sets position_ids, b_q_seq_len, b_q_start_loc (prefill) @@ -28,5 +26,3 @@ def init_some_extra_state(self, model): self.position_sin_sliding = torch.index_select(model._sin_cached_sliding, 0, pos) self.position_cos_compress = torch.index_select(model._cos_cached_compress, 0, pos) self.position_sin_compress = torch.index_select(model._sin_cached_compress, 0, pos) - self.cos_compress_table = model._cos_cached_compress - self.sin_compress_table = model._sin_cached_compress diff --git a/lightllm/models/deepseek_v4/layer_infer/attention.py b/lightllm/models/deepseek_v4/layer_infer/attention.py deleted file mode 100644 index 8a7428f0dd..0000000000 --- a/lightllm/models/deepseek_v4/layer_infer/attention.py +++ 
/dev/null @@ -1,149 +0,0 @@ -import os - -import torch - - -FLASHMLA_MIN_HEADS = 64 -FLASHMLA_TOPK_MULTIPLE = 128 -DSV4_DEBUG_TORCH_SPARSE_ATTN = os.getenv("DSV4_DEBUG_TORCH_SPARSE_ATTN", "0") == "1" - - -def _pad_topk_for_flashmla(topk_idxs): - K = topk_idxs.shape[-1] - padded_K = ((K + FLASHMLA_TOPK_MULTIPLE - 1) // FLASHMLA_TOPK_MULTIPLE) * FLASHMLA_TOPK_MULTIPLE - if padded_K == K: - return topk_idxs.contiguous() - padded = torch.full((*topk_idxs.shape[:-1], padded_K), -1, device=topk_idxs.device, dtype=topk_idxs.dtype) - padded[..., :K] = topk_idxs - return padded.contiguous() - - -def _compact_topk_indices(topk_idxs, kv_len): - valid = (topk_idxs >= 0) & (topk_idxs < kv_len) - topk_lens = valid.sum(dim=-1).to(torch.int32) - if valid.all(): - return topk_idxs.contiguous(), topk_lens.contiguous() - - compact = torch.full_like(topk_idxs, -1) - ranks = valid.to(torch.int32).cumsum(dim=-1) - 1 - rows = torch.arange(topk_idxs.shape[0], device=topk_idxs.device).unsqueeze(1).expand_as(topk_idxs) - compact[rows[valid], ranks[valid].long()] = topk_idxs[valid] - return compact.contiguous(), topk_lens.contiguous() - - -def _pad_heads_for_flashmla(q, attn_sink): - h = q.shape[1] - if h == FLASHMLA_MIN_HEADS: - return q.contiguous(), attn_sink.to(torch.float32).contiguous(), h - if h > FLASHMLA_MIN_HEADS: - raise RuntimeError(f"DeepSeek-V4 FlashMLA sparse attention only supports up to 64 local heads, got {h}") - - q_pad = q.new_zeros(q.shape[0], FLASHMLA_MIN_HEADS, q.shape[2]) - q_pad[:, :h] = q - sink_pad = torch.full((FLASHMLA_MIN_HEADS,), -float("inf"), device=q.device, dtype=torch.float32) - sink_pad[:h] = attn_sink.to(torch.float32) - return q_pad.contiguous(), sink_pad.contiguous(), h - - -def _torch_sparse_attn(q, kv, attn_sink, topk_idxs, scale): - return _torch_sparse_attn_flat(q[0], kv[0], attn_sink, topk_idxs[0], scale).unsqueeze(0) - - -def _torch_sparse_attn_flat(q, kv, attn_sink, topk_idxs, scale): - q0 = q.float() - kv0 = kv.float() - indices = 
topk_idxs.long() - valid = (indices >= 0) & (indices < kv0.shape[0]) - safe_indices = torch.where(valid, indices, torch.zeros_like(indices)) - kv_sel = kv0[safe_indices] - scores = torch.einsum("mhd,mkd->mhk", q0, kv_sel) * scale - scores = scores.masked_fill(~valid.unsqueeze(1), float("-inf")) - sink = attn_sink.float().view(1, -1) - max_scores = torch.maximum(scores.max(dim=-1).values, sink) - exp_scores = torch.exp(scores - max_scores.unsqueeze(-1)).masked_fill(~valid.unsqueeze(1), 0.0) - exp_sink = torch.exp(sink - max_scores) - denom = exp_scores.sum(dim=-1) + exp_sink - out = torch.einsum("mhk,mkd->mhd", exp_scores / denom.unsqueeze(-1), kv_sel) - return out.to(q.dtype) - - -def vllm_sparse_attn(q, kv, attn_sink, topk_idxs, scale): - """DeepSeek-V4 sparse MLA through vLLM FlashMLA. - - q:[1,m,h,d], kv:[1,n,d] (single KV head shared over h), attn_sink:[h], - topk_idxs:[1,m,K] int (-1 = invalid/skip). Returns o:[1,m,h,d]. - """ - b, m, h, d = q.shape - if b != 1 or kv.shape[0] != 1 or topk_idxs.shape[0] != 1: - raise RuntimeError("DeepSeek-V4 FlashMLA sparse attention wrapper expects one request per call") - if d != 512: - raise RuntimeError(f"DeepSeek-V4 FlashMLA sparse attention requires head_dim=512, got {d}") - if q.dtype != torch.bfloat16 or kv.dtype != torch.bfloat16: - raise RuntimeError(f"DeepSeek-V4 FlashMLA sparse attention requires bf16 q/kv, got {q.dtype}/{kv.dtype}") - - return vllm_sparse_attn_flat(q[0], kv[0], attn_sink, topk_idxs[0], scale).unsqueeze(0) - - -def vllm_sparse_attn_flat(q, kv, attn_sink, topk_idxs, scale, already_compact=False): - """FlashMLA sparse attention over a flat KV arena. - - q:[m,h,d], kv:[n,d], topk_idxs:[m,K] int. Indices are global offsets into - the flat kv tensor, so callers can concatenate per-request KV candidates and - run one FlashMLA call for the whole batch. When already_compact=True, each - row must place all valid indices before invalid (-1) entries. 
- """ - m, h, d = q.shape - if d != 512: - raise RuntimeError(f"DeepSeek-V4 FlashMLA sparse attention requires head_dim=512, got {d}") - if q.dtype != torch.bfloat16 or kv.dtype != torch.bfloat16: - raise RuntimeError(f"DeepSeek-V4 FlashMLA sparse attention requires bf16 q/kv, got {q.dtype}/{kv.dtype}") - if q.shape[0] == 0: - return q.new_empty((0, h, d)) - - if DSV4_DEBUG_TORCH_SPARSE_ATTN: - return _torch_sparse_attn_flat(q, kv, attn_sink, topk_idxs, scale) - - from vllm.third_party.flashmla.flash_mla_interface import flash_mla_sparse_fwd - - q_pad, sink_pad, real_heads = _pad_heads_for_flashmla(q, attn_sink) - topk_idxs = topk_idxs.to(torch.int32) - if already_compact: - valid = (topk_idxs >= 0) & (topk_idxs < kv.shape[0]) - indices = topk_idxs.contiguous() - topk_lens = valid.sum(dim=-1).to(torch.int32).contiguous() - else: - indices, topk_lens = _compact_topk_indices(topk_idxs, kv.shape[0]) - indices = _pad_topk_for_flashmla(indices).unsqueeze(1) - kv_flat = kv.unsqueeze(1).contiguous() - out, _, _ = flash_mla_sparse_fwd( - q=q_pad, - kv=kv_flat, - indices=indices, - sm_scale=scale, - attn_sink=sink_pad, - topk_length=topk_lens, - out=None, - ) - return out[:, :real_heads].to(q.dtype) - - -def build_prefill_topk_idxs(seqlen, window, ratio, n_window, device): - """Per-query candidate indices into [window_kv (n_window tokens) ++ compressed_kv (ncomp entries)]. - - Returns int32 [seqlen, window + ncomp] with -1 for invalid. Window part indexes the per-token KV - (here stored as tokens 0..seqlen-1, so n_window == seqlen); compressed part is offset by n_window. - For prompts where ncomp <= index_topk the indexer is a no-op, so all causally-valid compressed - entries are attended (matches the reference for short context). 
- """ - t = torch.arange(seqlen, device=device) - offsets = torch.arange(window, device=device) - win = t.unsqueeze(1) - (window - 1 - offsets).unsqueeze(0) - win = torch.where(win >= 0, win, torch.full_like(win, -1)) - if ratio: - ncomp = seqlen // ratio - c = torch.arange(ncomp, device=device) - comp_valid = c.unsqueeze(0) < ((t.unsqueeze(1) + 1) // ratio) # [s, ncomp] - comp_idx = (c.unsqueeze(0) + n_window).expand(seqlen, ncomp) - comp = torch.where(comp_valid, comp_idx, torch.full((seqlen, ncomp), -1, device=device, dtype=torch.long)) - return torch.cat([win, comp], dim=1).int() - return win.int() diff --git a/lightllm/models/deepseek_v4/layer_infer/compressor.py b/lightllm/models/deepseek_v4/layer_infer/compressor.py index f51f73829c..2256ecd1a9 100644 --- a/lightllm/models/deepseek_v4/layer_infer/compressor.py +++ b/lightllm/models/deepseek_v4/layer_infer/compressor.py @@ -1,121 +1,37 @@ -import importlib.util -import logging -import sys -import types -from pathlib import Path - import torch -import torch.nn.functional as F -from ..triton_kernel.rotary_emb import apply_rotary_emb -logger = logging.getLogger(__name__) -_SGLANG_COMPRESS_MOD = None _SGLANG_COMPRESS_ERR = None -_SGLANG_COMPRESS_WARNED = False +_SGLANG_COMPRESS_MOD = None +_SGLANG_LINEAR_BF16_FP32 = None _FREQ_CIS_CACHE = {} -# KV compressor: pools every `ratio` consecutive tokens into one compressed KV entry via gated -# (softmax) pooling + a learned absolute-position bias (ape), RMSNorm, and rope on the trailing -# rope_dim. ratio==4 uses overlapping windows (two-series Ca/Cb scheme). Pure-torch transcription of -# the bundled reference inference/model.py Compressor.forward for the prefill (start_pos==0) path. -# NOTE: the reference also applies an FP8/FP4 QAT activation sim to the compressed entry; omitted here -# for the correctness-first prefill path (negligible vs argmax; revisit if e2e diverges). 
- - -def _overlap_transform(tensor, ratio, d, value): - # tensor: [nwin, ratio, 2*d] -> [nwin, 2*ratio, d]; slots [ratio:]=Cb(current), [:ratio]=Ca(previous window) - nwin = tensor.shape[0] - out = tensor.new_full((nwin, 2 * ratio, d), value) - out[:, ratio:] = tensor[:, :, d:] - out[1:, :ratio] = tensor[:-1, :, :d] - return out - - -def _rmsnorm(x, weight, eps): - xf = x.float() - xf = xf * torch.rsqrt(xf.square().mean(-1, keepdim=True) + eps) - return (xf * weight.float()).to(x.dtype) - - -def _load_file_module(name, path): - spec = importlib.util.spec_from_file_location(name, path) - mod = importlib.util.module_from_spec(spec) - sys.modules[name] = mod - spec.loader.exec_module(mod) - return mod - def _load_sglang_compressor(): - global _SGLANG_COMPRESS_MOD, _SGLANG_COMPRESS_ERR + global _SGLANG_COMPRESS_ERR, _SGLANG_COMPRESS_MOD, _SGLANG_LINEAR_BF16_FP32 if _SGLANG_COMPRESS_MOD is not None: - return _SGLANG_COMPRESS_MOD + return _SGLANG_COMPRESS_MOD, _SGLANG_LINEAR_BF16_FP32 if _SGLANG_COMPRESS_ERR is not None: raise _SGLANG_COMPRESS_ERR try: - from sglang.jit_kernel.dsv4 import compress_old as mod - - _SGLANG_COMPRESS_MOD = mod - return mod - except Exception as first_exc: - root = Path("/data/wanzihao/sglang/python/sglang") - try: - if not root.exists(): - raise first_exc - if "sglang" not in sys.modules: - sglang_mod = types.ModuleType("sglang") - sglang_mod.__path__ = [str(root)] - sys.modules["sglang"] = sglang_mod - if "sglang.utils" not in sys.modules: - utils_mod = types.ModuleType("sglang.utils") - utils_mod.is_in_ci = lambda: False - sys.modules["sglang.utils"] = utils_mod - if "sglang.jit_kernel" not in sys.modules: - jit_mod = types.ModuleType("sglang.jit_kernel") - jit_mod.__path__ = [str(root / "jit_kernel")] - sys.modules["sglang.jit_kernel"] = jit_mod - if "sglang.jit_kernel.dsv4" not in sys.modules: - dsv4_mod = types.ModuleType("sglang.jit_kernel.dsv4") - dsv4_mod.__path__ = [str(root / "jit_kernel" / "dsv4")] - 
sys.modules["sglang.jit_kernel.dsv4"] = dsv4_mod - if "sglang.srt" not in sys.modules: - srt_mod = types.ModuleType("sglang.srt") - srt_mod.__path__ = [str(root / "srt")] - sys.modules["sglang.srt"] = srt_mod - if "sglang.srt.environ" not in sys.modules: - env_mod = types.ModuleType("sglang.srt.environ") - - class _FalseEnv: - def get(self): - return False - - class _Envs: - SGLANG_OPT_USE_ONLINE_COMPRESS = _FalseEnv() - - env_mod.envs = _Envs() - sys.modules["sglang.srt.environ"] = env_mod - if "sglang.jit_kernel.utils" not in sys.modules: - _load_file_module("sglang.jit_kernel.utils", root / "jit_kernel" / "utils.py") - if "sglang.jit_kernel.dsv4.utils" not in sys.modules: - _load_file_module( - "sglang.jit_kernel.dsv4.utils", - root / "jit_kernel" / "dsv4" / "utils.py", - ) - _SGLANG_COMPRESS_MOD = _load_file_module( - "sglang.jit_kernel.dsv4.compress_old", - root / "jit_kernel" / "dsv4" / "compress_old.py", - ) - return _SGLANG_COMPRESS_MOD - except Exception as exc: - _SGLANG_COMPRESS_ERR = exc - raise exc - - -def _warn_sglang_fallback(exc): - global _SGLANG_COMPRESS_WARNED - if not _SGLANG_COMPRESS_WARNED: - logger.warning("DeepSeek-V4 SGLang compressor JIT unavailable, fallback to torch: %s", exc) - _SGLANG_COMPRESS_WARNED = True + from sglang.jit_kernel.dsv4 import linear_bf16_fp32 + from sglang.jit_kernel.dsv4 import compress_old as compress_mod + except Exception as exc: + _SGLANG_COMPRESS_ERR = RuntimeError( + "DeepSeek-V4 fused compressor requires sglang.jit_kernel.dsv4 " + "(linear_bf16_fp32 + compress_old). Install/export the SGLang package " + "or vendor the DSv4 compressor JIT into LightLLM." 
+ ) + raise _SGLANG_COMPRESS_ERR from exc + _SGLANG_COMPRESS_MOD = compress_mod + _SGLANG_LINEAR_BF16_FP32 = linear_bf16_fp32 + return compress_mod, linear_bf16_fp32 + + +def _load_paged_compress_data_fn(): + from sglang.jit_kernel.dsv4 import triton_create_paged_compress_data + + return triton_create_paged_compress_data def _freq_cis(cos_table, sin_table): @@ -139,39 +55,27 @@ def _sglang_ape(ape, ratio, head_dim): return ape.contiguous() -def _pack_kv_score(kv, score, ratio, head_dim): - if ratio == 4: - return torch.cat( - [ - kv[:, :head_dim], - kv[:, head_dim:], - score[:, :head_dim], - score[:, head_dim:], - ], - dim=1, - ).contiguous() - return torch.cat([kv, score], dim=1).contiguous() - - -def _build_state_from_kv_score(kv, score, ape, ratio, head_dim): - overlap = ratio == 4 - kv_state, score_state = new_compressor_state(ratio, head_dim, kv.device) - s = kv.shape[0] - remainder = s % ratio - cutoff = s - remainder - offset = ratio if overlap else 0 - if overlap and cutoff >= ratio: - kv_state[:ratio] = kv[cutoff - ratio : cutoff] - score_state[:ratio] = score[cutoff - ratio : cutoff] + ape.float() - if remainder > 0: - kv_state[offset : offset + remainder] = kv[cutoff:] - score_state[offset : offset + remainder] = score[cutoff:] + ape.float()[:remainder] - return kv_state, score_state - - -def _sglang_prefill_from_kv_score( - kv, - score, +def _compressor_weight(wkv_w, wgate_w): + return torch.cat([wkv_w, wgate_w], dim=0).contiguous() + + +def _project_kv_score(x, wkv_w, wgate_w): + _, linear_bf16_fp32 = _load_sglang_compressor() + return linear_bf16_fp32(x, _compressor_weight(wkv_w, wgate_w)) + + +def _state_pool_view(state_pool): + if state_pool is None: + raise RuntimeError("DeepSeek-V4 fused compressor requires a persistent state_pool") + if state_pool.dim() == 4 and state_pool.shape[1] == 1: + return state_pool.squeeze(1) + return state_pool + + +def compressor_prefill_state( + x, + wkv_w, + wgate_w, norm_w, ape, ratio, @@ -179,48 +83,42 @@ def 
_sglang_prefill_from_kv_score( cos_table, sin_table, eps, - dtype, - state_pool=None, + state_pool, ): - if not kv.is_cuda or head_dim % 128 != 0 or ratio not in (4, 128): - return None, None - mod = _load_sglang_compressor() - kv_score = _pack_kv_score(kv, score, ratio, head_dim) - ape_sglang = _sglang_ape(ape.float(), ratio, head_dim) - slots = 8 if ratio == 4 else ratio - if state_pool is None: - state_pool = torch.zeros((1, slots, kv_score.shape[1]), device=kv.device, dtype=kv_score.dtype) - else: - state_pool.zero_() - seq_len = kv.shape[0] + """start_pos==0 prefill for ONE request: x [s, dim] -> compressed entries [s//ratio, head_dim] + (rope applied). state_pool is the request's persistent jit state slice [1, slots, coff*2*head_dim]; + it is rebuilt in place so the decode path can continue from the trailing partial window.""" + mod, _ = _load_sglang_compressor() + kv_score = _project_kv_score(x, wkv_w, wgate_w) + pool = _state_pool_view(state_pool) + pool.zero_() + seq_len = x.shape[0] plan = mod.CompressorPrefillPlan.generate( ratio, seq_len, torch.tensor([seq_len], dtype=torch.int64), torch.tensor([seq_len], dtype=torch.int64), - kv.device, + x.device, ) - indices = torch.zeros((1,), device=kv.device, dtype=torch.int32) + indices = torch.zeros((1,), device=x.device, dtype=torch.int32) out = mod.compress_forward( - state_pool, + pool, kv_score, - ape_sglang, + _sglang_ape(ape.float(), ratio, head_dim), indices, plan, head_dim=head_dim, compress_ratio=ratio, ) ncomp = seq_len // ratio - if ncomp: - mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) - ragged_ids = plan.compress_plan.view(torch.int32)[:ncomp, 0].long() - out = out.index_select(0, ragged_ids).to(dtype) - else: - out = kv.new_zeros(0, head_dim).to(dtype) - return out, state_pool + if ncomp == 0: + return x.new_zeros(0, head_dim) + mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) + ragged_ids = 
plan.compress_plan.view(torch.int32)[:ncomp, 0].long() + return out.index_select(0, ragged_ids).to(x.dtype) -def _sglang_decode_step_from_state_pool( +def compressor_decode_step_single( x_new, wkv_w, wgate_w, @@ -231,17 +129,14 @@ def _sglang_decode_step_from_state_pool( cos_table, sin_table, eps, - start_pos, state_pool, + start_pos, ): - if state_pool is None or not x_new.is_cuda or head_dim % 128 != 0 or ratio not in (4, 128): - return None, False - mod = _load_sglang_compressor() - xf = x_new.float().view(1, -1) - kv = F.linear(xf, wkv_w.float()) - score = F.linear(xf, wgate_w.float()) - kv_score = _pack_kv_score(kv, score, ratio, head_dim) - ape_sglang = _sglang_ape(ape.float(), ratio, head_dim) + """One token for ONE request (chunked-prefill extend path). Returns the finished compressed + entry [head_dim] when (start_pos+1) % ratio == 0, else None. Mutates state_pool in place.""" + mod, _ = _load_sglang_compressor() + kv_score = _project_kv_score(x_new.view(1, -1), wkv_w, wgate_w) + pool = _state_pool_view(state_pool) seq_len = start_pos + 1 plan = mod.CompressorDecodePlan( ratio, @@ -249,66 +144,22 @@ def _sglang_decode_step_from_state_pool( ) indices = torch.zeros((1,), device=x_new.device, dtype=torch.int32) out = mod.compress_forward( - state_pool, + pool, kv_score, - ape_sglang, + _sglang_ape(ape.float(), ratio, head_dim), indices, plan, head_dim=head_dim, compress_ratio=ratio, ) if seq_len % ratio != 0: - return None, True + return None mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) - return out[0].to(x_new.dtype), True - - -def compress_prefill(x, wkv_w, wgate_w, norm_w, ape, ratio, head_dim, rope_dim, cos_table, sin_table, eps): - """x:[s,dim] (one request, start_pos=0) -> compressed kv [nwin, head_dim] (rope applied to last rope_dim). - - nwin = s // ratio (remainder tokens are decode-state, handled in the decode path). 
wkv_w/wgate_w: - [coff*head_dim, dim]; norm_w:[head_dim]; ape:[ratio, coff*head_dim]; cos_table/sin_table: compress rope tables. - """ - overlap = ratio == 4 - coff = 2 if overlap else 1 - d = head_dim - s = x.shape[0] - nwin = s // ratio - if nwin == 0: - # fewer than `ratio` tokens -> no completed window -> no compressed entry (matches reference) - return x.new_zeros(0, head_dim) - cutoff = nwin * ratio - xf = x.float() - kv = F.linear(xf, wkv_w.float())[:cutoff].view(nwin, ratio, coff * d) - score = F.linear(xf, wgate_w.float())[:cutoff].view(nwin, ratio, coff * d) + ape.float() - if overlap: - kv = _overlap_transform(kv, ratio, d, 0.0) - score = _overlap_transform(score, ratio, d, float("-inf")) - kv = (kv * torch.softmax(score, dim=1)).sum(dim=1) # [nwin, d] fp32 - kv = _rmsnorm(kv.to(x.dtype), norm_w, eps) # [nwin, d] - pos = torch.arange(nwin, device=x.device) * ratio - kv_rope = apply_rotary_emb(kv[:, -rope_dim:], cos_table[pos], sin_table[pos]) # cos/sin: [nwin, rope_dim//2] - return torch.cat([kv[:, :-rope_dim], kv_rope], dim=1) - - -def new_compressor_state(ratio, head_dim, device, dtype=torch.float32): - """Per-request compressor running state (matches reference Compressor.kv_state/score_state).""" - coff = 2 if ratio == 4 else 1 - kv_state = torch.zeros(coff * ratio, coff * head_dim, device=device, dtype=dtype) - score_state = torch.full((coff * ratio, coff * head_dim), float("-inf"), device=device, dtype=dtype) - return kv_state, score_state - - -def _finish_entry(kv, norm_w, ape_unused, rope_dim, cos_table, sin_table, position, eps, dtype): - kv = _rmsnorm(kv.to(dtype), norm_w, eps) # [d] - cos = cos_table[position : position + 1] # [1, rope_dim//2] - sin = sin_table[position : position + 1] - kv_rope = apply_rotary_emb(kv[-rope_dim:].unsqueeze(0), cos, sin)[0] - return torch.cat([kv[:-rope_dim], kv_rope], dim=0) + return out[0].to(x_new.dtype) -def compressor_prefill_state( - x, +def compressor_decode_step_batch( + x_new, wkv_w, wgate_w, norm_w, @@ 
-319,207 +170,193 @@ def compressor_prefill_state( cos_table, sin_table, eps, - return_state_pool=False, - state_pool=None, + state_pool, + b_req_idx, + start_pos, ): - """Faithful reference start_pos==0 path (incl. remainder). Returns (entries[ncomp,d], kv_state, score_state). - - entries have rope applied; kv_state/score_state carry the partial window for the decode path. - """ - overlap = ratio == 4 - coff = 2 if overlap else 1 - d = head_dim - s = x.shape[0] - dtype = x.dtype - xf = x.float() - kv = F.linear(xf, wkv_w.float()) # [s, coff*d] - score = F.linear(xf, wgate_w.float()) # [s, coff*d] - ape = ape.float() - kv_state, score_state = _build_state_from_kv_score(kv, score, ape, ratio, head_dim) - sglang_state_pool = state_pool - try: - comp, sglang_state_pool = _sglang_prefill_from_kv_score( - kv, - score, - norm_w, - ape, - ratio, - head_dim, - cos_table, - sin_table, - eps, - dtype, - state_pool=sglang_state_pool, - ) - if comp is not None: - if return_state_pool: - return comp, kv_state, score_state, sglang_state_pool - return comp, kv_state, score_state - except Exception as exc: - _warn_sglang_fallback(exc) - - should_compress = s >= ratio - remainder = s % ratio - cutoff = s - remainder - if remainder > 0: - kv = kv[:cutoff] - score = score[:cutoff] - if not should_compress: - comp = x.new_zeros(0, head_dim) - if return_state_pool: - return comp, kv_state, score_state, sglang_state_pool - return comp, kv_state, score_state - nwin = cutoff // ratio - kvw = kv.view(nwin, ratio, coff * d) - scw = score.view(nwin, ratio, coff * d) + ape - if overlap: - kvw = _overlap_transform(kvw, ratio, d, 0.0) - scw = _overlap_transform(scw, ratio, d, float("-inf")) - comp = (kvw * torch.softmax(scw, dim=1)).sum(dim=1) # [nwin, d] fp32 - comp = _rmsnorm(comp.to(dtype), norm_w, eps) - pos = torch.arange(nwin, device=x.device) * ratio - comp_rope = apply_rotary_emb(comp[:, -rope_dim:], cos_table[pos], sin_table[pos]) - comp = torch.cat([comp[:, :-rope_dim], comp_rope], 
dim=1) - if return_state_pool: - return comp, kv_state, score_state, sglang_state_pool - return comp, kv_state, score_state - - -def compressor_decode_step( - x_new, + mod, _ = _load_sglang_compressor() + kv_score = _project_kv_score(x_new, wkv_w, wgate_w) + pool = _state_pool_view(state_pool) + seq_lens = (start_pos + 1).to(torch.int32).contiguous() + plan = mod.CompressorDecodePlan(ratio, seq_lens) + out = mod.compress_forward( + pool, + kv_score, + _sglang_ape(ape.float(), ratio, head_dim), + b_req_idx.to(torch.int32).contiguous(), + plan, + head_dim=head_dim, + compress_ratio=ratio, + ) + should_compress = (seq_lens % ratio) == 0 + mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) + return out.to(x_new.dtype), should_compress + + +# ---------------------------------------------------------------------------- paged state (c4) +# 与 sglang srt compressor 的 paged 路径同构(compress_old 内核 + 分组槽 indices + overlap +# extra_data): state 槽位由 swa 槽位算术派生(翻译③ state_loc = page*ring + swa_loc%ring, +# 分组槽 = state_loc//ratio),state 随 swa 页生灭,radix 命中零拷贝续算。 + + +def paged_state_rows(num_swa_pages: int, ring: int, ratio: int) -> int: + """state 池行数 = 页数*ring + ring(HOLD 页) + 1(哨兵行),向上取整到 ratio 整除 + (分组视图 [-1, ratio, last_dim] 需要)。与 sglang CompressStatePool 的 _size 公式一致。""" + rows = num_swa_pages * ring + ring + 1 + return (rows + ratio - 1) // ratio * ratio + + +def init_paged_state_pool(buffer: torch.Tensor) -> None: + """末行为哨兵: kv 半边置 0、score 半边置 -inf(KVAndScore.clear 语义)。其余行无需初始化 + (内核在组起点覆写)。buffer: [rows, 2*coff*head_dim] fp32。""" + half = buffer.shape[-1] // 2 + buffer[-1, :half].zero_() + buffer[-1, half:].fill_(float("-inf")) + return + + +def _paged_state_group_slot(req_to_token, full_to_swa, b_req_idx, positions, page_size, ring, ratio): + """位置 -> state 分组槽(= sglang create_paged_compressor_data.get_raw_loc): + state_loc = (swa_loc//page)*ring + swa_loc%ring; 分组槽 = state_loc//ratio。 + 负位置按 sglang 语义 mask 到 
0;已出窗(swa_loc<0)的位置落到 -1(哨兵行,score=-inf)。""" + positions = positions.masked_fill(positions < 0, 0) + full = req_to_token[b_req_idx.long(), positions] + swa_loc = full_to_swa[full.long()].long() + state_loc = torch.div(swa_loc, page_size, rounding_mode="floor") * ring + swa_loc % ring + state_loc = torch.where(swa_loc < 0, torch.full_like(state_loc, -1), state_loc) + return torch.div(state_loc, ratio, rounding_mode="floor").to(torch.int32) + + +def paged_decode_state_slots( + req_to_token, + full_to_swa, + b_req_idx, + b_seq_len, + page_size: int, + ring: int, + ratio: int, + hold_req_id: int, + num_swa_pages: int, +): + """decode 步的 state 分组槽(写槽 = 当前组 clip_down(seq-1) 的槽,overlap 伙伴 = 前一组)。 + 纯张量算术(prep 已写本步 req_to_token),图安全。padding(HOLD)行重定向到 HOLD 页的 + state 槽,隔离其垃圾累加。""" + seq = b_seq_len.long() + write_positions = torch.div(seq - 1, ratio, rounding_mode="floor") * ratio + write_slot = _paged_state_group_slot(req_to_token, full_to_swa, b_req_idx, write_positions, page_size, ring, ratio) + overlap_slot = _paged_state_group_slot( + req_to_token, full_to_swa, b_req_idx, write_positions - ratio, page_size, ring, ratio + ) + hold_slot = num_swa_pages * ring // ratio # HOLD 页区域([pages*ring, pages*ring+ring))的首个分组槽 + is_hold = b_req_idx.long() == hold_req_id + write_slot = torch.where(is_hold, torch.full_like(write_slot, hold_slot), write_slot) + overlap_slot = torch.where(is_hold, torch.full_like(overlap_slot, hold_slot), overlap_slot) + return write_slot, overlap_slot + + +def paged_prefill_compress_data(req_to_token, full_to_swa, req_idx: int, ready_len: int, seq_len: int, ring: int): + """单请求 prefill chunk 的 (indices, extra_data, plan): 与 sglang 同走 + triton_create_paged_compress_data(按请求产出,内核经 plan 逐 token 步进)。仅 c4(overlap)。 + 三者都与层无关,同一 forward 内可跨全部 c4 层复用。""" + mod, _ = _load_sglang_compressor() + fn = _load_paged_compress_data_fn() + device = req_to_token.device + n_new = seq_len - ready_len + write_loc, extra_data = fn( + compress_ratio=4, + is_overlap=True, + 
swa_page_size=128, + ring_size=ring, + req_pool_indices=torch.tensor([req_idx], device=device, dtype=torch.int64), + seq_lens=torch.tensor([seq_len], device=device, dtype=torch.int64), + extend_seq_lens=torch.tensor([n_new], device=device, dtype=torch.int64), + req_to_token=req_to_token, + full_to_swa_index_mapping=full_to_swa, + ) + plan = mod.CompressorPrefillPlan.generate( + 4, + n_new, + torch.tensor([seq_len], dtype=torch.int64), + torch.tensor([n_new], dtype=torch.int64), + device, + ) + return write_loc, extra_data, plan + + +def compressor_paged_prefill( + x, wkv_w, wgate_w, norm_w, ape, - ratio, head_dim, - rope_dim, cos_table, sin_table, eps, - kv_state, - score_state, - start_pos, - state_pool=None, + state_buffer, + compress_data, + ready_len, + seq_len, ): - """Faithful reference start_pos>0 path for one new token. Mutates kv_state/score_state in place. - Returns the new compressed entry [d] (rope applied) when a window completes, else None. - """ - overlap = ratio == 4 - d = head_dim - dtype = x_new.dtype - try: - entry, handled = _sglang_decode_step_from_state_pool( - x_new, - wkv_w, - wgate_w, - norm_w, - ape, - ratio, - head_dim, - cos_table, - sin_table, - eps, - start_pos, - state_pool, - ) - if handled: - return entry - except Exception as exc: - _warn_sglang_fallback(exc) - - xf = x_new.float().view(-1) # [dim] - kv = F.linear(xf, wkv_w.float()) # [coff*d] - score = F.linear(xf, wgate_w.float()) + ape.float()[start_pos % ratio] # [coff*d] - should_compress = (start_pos + 1) % ratio == 0 - if overlap: - kv_state[ratio + start_pos % ratio] = kv - score_state[ratio + start_pos % ratio] = score - if should_compress: - kv_cat = torch.cat([kv_state[:ratio, :d], kv_state[ratio:, d:]], dim=0) # [2*ratio, d] - sc_cat = torch.cat([score_state[:ratio, :d], score_state[ratio:, d:]], dim=0) - entry = (kv_cat * torch.softmax(sc_cat, dim=0)).sum(dim=0) # [d] - kv_state[:ratio] = kv_state[ratio:] - score_state[:ratio] = score_state[ratio:] - else: - 
kv_state[start_pos % ratio] = kv - score_state[start_pos % ratio] = score - if should_compress: - entry = (kv_state * torch.softmax(score_state, dim=0)).sum(dim=0) # [d] - if not should_compress: - return None - return _finish_entry( - entry, - norm_w, - ape, - rope_dim, - cos_table, - sin_table, - start_pos + 1 - ratio, - eps, - dtype, + """单请求 prefill/extend chunk(c4 paged): x [n_new, dim] 为位置 [ready, seq) 的 hidden, + state 写到 swa 派生的分组槽(compress_data 来自 paged_prefill_compress_data,跨层复用)。 + 返回本 chunk 完结组的压缩条目 [seq//4 - ready//4, head_dim](rope 已施加)。""" + mod, _ = _load_sglang_compressor() + ratio = 4 + kv_score = _project_kv_score(x, wkv_w, wgate_w) + pool = state_buffer.view(-1, ratio, state_buffer.shape[-1]) + write_loc, extra_data, plan = compress_data + out = mod.compress_forward( + pool, + kv_score, + _sglang_ape(ape.float(), ratio, head_dim), + write_loc, + plan, + head_dim=head_dim, + compress_ratio=ratio, + extra_data=extra_data, ) + ncomp = seq_len // ratio - ready_len // ratio + if ncomp == 0: + return x.new_zeros(0, head_dim) + mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) + ragged_ids = plan.compress_plan.view(torch.int32)[:ncomp, 0].long() + return out.index_select(0, ragged_ids).to(x.dtype) -def compressor_decode_step_batch( +def compressor_paged_decode_batch( x_new, wkv_w, wgate_w, norm_w, ape, - ratio, head_dim, - rope_dim, cos_table, sin_table, eps, - state_all, - b_req_idx, - start_pos, + state_buffer, + write_slot, + overlap_slot, + b_seq_len, ): - """Graph-safe batch decode compressor step. - - Mutates ``state_all`` for the selected request rows and returns one candidate - entry per batch row plus a boolean mask telling which rows closed a - compression window. 
- """ - overlap = ratio == 4 - d = head_dim - dtype = x_new.dtype - req = b_req_idx.long() - pos = start_pos.long() - pos_mod = pos % ratio - - xf = x_new.float() - kv = F.linear(xf, wkv_w.float()) - score = F.linear(xf, wgate_w.float()) + ape.float().index_select(0, pos_mod) - - kv_state = state_all[req, 0].clone() - score_state = state_all[req, 1].clone() - row = pos_mod + (ratio if overlap else 0) - batch_ids = torch.arange(x_new.shape[0], device=x_new.device) - kv_state[batch_ids, row] = kv - score_state[batch_ids, row] = score - - should_compress = ((pos + 1) % ratio) == 0 - if overlap: - kv_cat = torch.cat([kv_state[:, :ratio, :d], kv_state[:, ratio:, d:]], dim=1) - score_cat = torch.cat([score_state[:, :ratio, :d], score_state[:, ratio:, d:]], dim=1) - entry = (kv_cat * torch.softmax(score_cat, dim=1)).sum(dim=1) - shifted_kv_state = kv_state.clone() - shifted_score_state = score_state.clone() - shifted_kv_state[:, :ratio] = kv_state[:, ratio:] - shifted_score_state[:, :ratio] = score_state[:, ratio:] - kv_state = torch.where(should_compress.view(-1, 1, 1), shifted_kv_state, kv_state) - score_state = torch.where(should_compress.view(-1, 1, 1), shifted_score_state, score_state) - else: - entry = (kv_state * torch.softmax(score_state, dim=1)).sum(dim=1) - - state_all[req, 0] = kv_state - state_all[req, 1] = score_state - - entry = _rmsnorm(entry.to(dtype), norm_w, eps) - comp_pos = torch.clamp(pos + 1 - ratio, min=0) - entry_rope = apply_rotary_emb(entry[:, -rope_dim:], cos_table[comp_pos], sin_table[comp_pos]) - entry = torch.cat([entry[:, :-rope_dim], entry_rope], dim=1) - return entry, should_compress + """批量 decode 一步(c4 paged): state 槽位为 swa 派生分组槽(paged_decode_state_slots, + 可跨层复用)。返回 (entries [bs, head_dim], should_compress [bs])。""" + mod, _ = _load_sglang_compressor() + ratio = 4 + kv_score = _project_kv_score(x_new, wkv_w, wgate_w) + pool = state_buffer.view(-1, ratio, state_buffer.shape[-1]) + seq_lens = b_seq_len.to(torch.int32).contiguous() + plan 
= mod.CompressorDecodePlan(ratio, seq_lens) + out = mod.compress_forward( + pool, + kv_score, + _sglang_ape(ape.float(), ratio, head_dim), + write_slot, + plan, + head_dim=head_dim, + compress_ratio=ratio, + extra_data=overlap_slot.view(-1, 1), + ) + should_compress = (seq_lens % ratio) == 0 + mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) + return out.to(x_new.dtype), should_compress diff --git a/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py index 78cdb3a3f8..b125e9ed06 100644 --- a/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py +++ b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py @@ -1,50 +1,75 @@ import torch +try: + import vllm.model_executor.layers.mhc # noqa: F401 +except Exception as e: + raise RuntimeError("DeepSeek-V4 requires vLLM mHC custom ops; failed to import vllm MHC kernels") from e -def _ensure_vllm_mhc_ops(): - try: - import vllm.model_executor.layers.mhc # noqa: F401 - except Exception as e: - raise RuntimeError("DeepSeek-V4 requires vLLM mHC custom ops; failed to import vllm MHC kernels") from e +# vllm DeepseekV4DecoderLayer.hc_post_alpha +HC_POST_ALPHA = 2.0 -def hc_pre(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, eps, sinkhorn_iters): - """streams:[N, hc*dim] -> (collapsed[N,dim], post[N,hc,1], comb[N,hc,hc]).""" - _ensure_vllm_mhc_ops() - post, comb, collapsed = torch.ops.vllm.mhc_pre( - residual=streams.view(-1, hc_mult, dim).contiguous(), + +def hc_pre(residual, hc_fn, hc_scale, hc_base, rms_eps, hc_eps, sinkhorn_iters, norm_weight, norm_eps): + """Standalone hc_pre for the first layer. 
residual:[T, hc, dim] -> + (x[T,dim], residual, post_mix[T,hc,1], res_mix[T,hc,hc]); the sub-layer RMSNorm is fused via norm_weight.""" + post_mix, res_mix, x = torch.ops.vllm.mhc_pre_tilelang( + residual=residual, + fn=hc_fn, + hc_scale=hc_scale, + hc_base=hc_base, + rms_eps=rms_eps, + hc_pre_eps=hc_eps, + hc_sinkhorn_eps=hc_eps, + hc_post_mult_value=HC_POST_ALPHA, + sinkhorn_repeat=sinkhorn_iters, + norm_weight=norm_weight, + norm_eps=norm_eps, + ) + return x, residual, post_mix, res_mix + + +def hc_fused_post_pre( + x, residual, post_mix, res_mix, hc_fn, hc_scale, hc_base, rms_eps, hc_eps, sinkhorn_iters, norm_weight, norm_eps +): + """hc_post of the previous sub-layer fused with hc_pre of the next one (norm fused too). + Returns (x[T,dim], residual[T,hc,dim], post_mix, res_mix).""" + residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre_tilelang( + x=x, + residual=residual, + post_layer_mix=post_mix, + comb_res_mix=res_mix, fn=hc_fn, hc_scale=hc_scale, hc_base=hc_base, - rms_eps=eps, - hc_pre_eps=eps, - hc_sinkhorn_eps=eps, - hc_post_mult_value=2.0, + rms_eps=rms_eps, + hc_pre_eps=hc_eps, + hc_sinkhorn_eps=hc_eps, + hc_post_mult_value=HC_POST_ALPHA, sinkhorn_repeat=sinkhorn_iters, + norm_weight=norm_weight, + norm_eps=norm_eps, ) - return collapsed, post, comb + return x, residual, post_mix, res_mix -def hc_post(x, residual, post, comb, hc_mult, dim): - """x:[N,dim] sub-layer output, residual:[N, hc*dim] -> [N, hc*dim].""" - _ensure_vllm_mhc_ops() - out = torch.ops.vllm.mhc_post(x, residual.view(-1, hc_mult, dim).contiguous(), post, comb) - return out.reshape(-1, hc_mult * dim) +def hc_post(x, residual, post_mix, res_mix): + """Complete the hc_post left pending by the last layer. 
-> streams [T, hc, dim].""" + return torch.ops.vllm.mhc_post_tilelang(x, residual, post_mix, res_mix) -def hc_head(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, eps): +def hc_head(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, rms_eps, hc_eps): """Final stream collapse before the lm_head. streams:[N, hc*dim] -> [N, dim].""" - _ensure_vllm_mhc_ops() out = torch.empty(streams.shape[0], dim, device=streams.device, dtype=streams.dtype) - torch.ops.vllm.hc_head_fused_kernel( + torch.ops.vllm.hc_head_fused_kernel_tilelang( streams.view(-1, hc_mult, dim).contiguous(), hc_fn, hc_scale, hc_base, out, dim, - eps, - eps, + rms_eps, + hc_eps, hc_mult, ) return out diff --git a/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py index c23d03afb7..bc95c249f7 100644 --- a/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py @@ -1,5 +1,5 @@ from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from .hyper_connection import hc_head +from .hyper_connection import hc_head, hc_post from ..infer_struct import DeepseekV4InferStateInfo @@ -8,6 +8,11 @@ class DeepseekV4PostLayerInfer(LlamaPostLayerInfer): def token_forward(self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight): cfg = layer_weight.network_config_ + if isinstance(input_embdings, tuple): + # truncated-layer runs (autotune warmup) end before the last layer's _hc_ffn_out + # collapse; finish the pending hc_post here. 
+ streams = hc_post(*input_embdings) + input_embdings = streams.reshape(streams.shape[0], -1) collapsed = hc_head( input_embdings, layer_weight.hc_head_fn_.weight, @@ -15,6 +20,7 @@ def token_forward(self, input_embdings, infer_state: DeepseekV4InferStateInfo, l layer_weight.hc_head_base_.weight, cfg["hc_mult"], cfg["hidden_size"], + cfg["rms_norm_eps"], cfg.get("hc_eps", 1e-6), ) return super().token_forward(collapsed, infer_state, layer_weight) diff --git a/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py index d83e3082b8..b95f5a14a8 100644 --- a/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py @@ -8,16 +8,17 @@ class DeepseekV4PreLayerInfer(LlamaPreLayerInfer): """Token embedding, then expand to the hc_mult parallel residual streams [T, hc_mult*hidden].""" - def _embed_and_expand(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight): - emb = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) # [T, hidden] - if self.tp_world_size_ > 1: - all_reduce(emb, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - hc_mult = layer_weight.network_config_["hc_mult"] - t, hidden = emb.shape - return emb.unsqueeze(1).expand(t, hc_mult, hidden).reshape(t, hc_mult * hidden).contiguous() + def __init__(self, network_config): + super().__init__(network_config) + self.hc_mult = network_config["hc_mult"] + return def context_forward(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight): - return self._embed_and_expand(input_ids, infer_state, layer_weight) + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + t, hidden = input_embdings.shape + return input_embdings.unsqueeze(1).expand(t, self.hc_mult, hidden).reshape(t, self.hc_mult * hidden) def token_forward(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight): - return 
self._embed_and_expand(input_ids, infer_state, layer_weight) + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + t, hidden = input_embdings.shape + return input_embdings.unsqueeze(1).expand(t, self.hc_mult, hidden).reshape(t, self.hc_mult * hidden) diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 11209d39fd..6f7c0de3fc 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -2,28 +2,25 @@ import torch.nn.functional as F import torch.distributed as dist from lightllm.common.basemodel import TransformerLayerInferTpl +from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.distributed.communication_op import all_reduce from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor -from .hyper_connection import hc_pre, hc_post +from .hyper_connection import hc_pre, hc_fused_post_pre, hc_post +from .compressor import ( + compressor_prefill_state, + compressor_decode_step_single, + compressor_decode_step_batch, + compressor_paged_prefill, + compressor_paged_decode_batch, + paged_prefill_compress_data, + paged_decode_state_slots, +) from ..triton_kernel.rotary_emb import apply_rotary_emb from ..infer_struct import DeepseekV4InferStateInfo -from .compressor import compressor_prefill_state, compressor_decode_step, compressor_decode_step_batch -from .attention import vllm_sparse_attn_flat class DeepseekV4TransformerLayerInfer(TransformerLayerInferTpl): - """One V4 decoder layer: HC(attn) then HC(ffn). - - The residual is carried as ``hc_mult`` streams flattened to [T, hc_mult*hidden]; each sub-layer - collapses (hc_pre), computes, and re-expands (hc_post). 
Attention is MLA over a sliding window + - compressed KV with a per-head sink (vLLM FlashMLA sparse); the MoE reuses lightllm's deepgemm FP8 - grouped GEMM driven by V4's custom router (sqrtsoftplus + hash/topk + bias-for-selection). - - Per-request decode state (window KV history + compressed KV + compressor running state) is kept in - DeepseekV4ReqManager so request alloc/free owns its lifetime. - """ - def __init__(self, layer_num, network_config): super().__init__(layer_num, network_config) cfg = network_config @@ -43,6 +40,14 @@ def __init__(self, layer_num, network_config): self.window = cfg["sliding_window"] self.compress_ratio = cfg["compress_ratios"][layer_num] self.is_hash = layer_num < cfg["num_hash_layers"] + self.is_last_layer = layer_num == cfg["n_layer"] - 1 + # complex64 rope table for this layer's variant (sliding / compressed); set by + # DeepseekV4TpPartModel._init_to_get_rotary once the tables are built. The full compress + # cos/sin tables (compressor entry rope uses entry positions, not token positions) are + # wired there too. 
+ self.freqs_cis = None + self.cos_compress_table = None + self.sin_compress_table = None self.topk = cfg["num_experts_per_tok"] self.route_scale = cfg["routed_scaling_factor"] self.swiglu_limit = cfg["swiglu_limit"] @@ -60,47 +65,78 @@ def __init__(self, layer_num, network_config): self.indexer_score_scale = self.index_head_dim ** -0.5 self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5 - # ------------------------------------------------------------------ forward (HC-wrapped) - def _hc_forward(self, streams, infer_state: DeepseekV4InferStateInfo, lw, attn_forward): - residual = streams - collapsed, post, comb = hc_pre( - streams, - lw.hc_attn_fn_.weight, - lw.hc_attn_scale_.weight, - lw.hc_attn_base_.weight, - self.hc_mult, - self.hidden, + # ------------------------------------------------------------------ forward (HC-threaded) + def _hc_attn_in(self, input_embdings, layer_weight): + """Layer input -> attention input (attn_norm fused). First layer gets the raw streams + and runs a standalone hc_pre; later layers get (x, residual, post_mix, res_mix) and fuse + the previous layer's ffn hc_post with this layer's attn hc_pre.""" + if torch.is_tensor(input_embdings): + residual = input_embdings.view(-1, self.hc_mult, self.hidden) + return hc_pre( + residual, + layer_weight.hc_attn_fn_.weight, + layer_weight.hc_attn_scale_.weight, + layer_weight.hc_attn_base_.weight, + self.eps_, + self.hc_eps, + self.sinkhorn_iters, + layer_weight.attn_norm_.weight, + self.eps_, + ) + x, residual, post_mix, res_mix = input_embdings + return hc_fused_post_pre( + x, + residual, + post_mix, + res_mix, + layer_weight.hc_attn_fn_.weight, + layer_weight.hc_attn_scale_.weight, + layer_weight.hc_attn_base_.weight, + self.eps_, self.hc_eps, self.sinkhorn_iters, + layer_weight.attn_norm_.weight, + self.eps_, ) - o = attn_forward(self._att_norm(collapsed, infer_state, lw), infer_state, lw) - streams = hc_post(o, residual, post, comb, self.hc_mult, self.hidden) - - 
residual = streams - collapsed, post, comb = hc_pre( - streams, - lw.hc_ffn_fn_.weight, - lw.hc_ffn_scale_.weight, - lw.hc_ffn_base_.weight, - self.hc_mult, - self.hidden, + + def _hc_ffn_in(self, x, residual, post_mix, res_mix, layer_weight): + """Attention output -> ffn input (ffn_norm fused): fused attn hc_post + ffn hc_pre.""" + return hc_fused_post_pre( + x, + residual, + post_mix, + res_mix, + layer_weight.hc_ffn_fn_.weight, + layer_weight.hc_ffn_scale_.weight, + layer_weight.hc_ffn_base_.weight, + self.eps_, self.hc_eps, self.sinkhorn_iters, + layer_weight.ffn_norm_.weight, + self.eps_, ) - f = self._ffn(self._ffn_norm(collapsed, infer_state, lw), infer_state, lw) - return hc_post(f, residual, post, comb, self.hc_mult, self.hidden) - - def context_forward(self, streams, infer_state: DeepseekV4InferStateInfo, lw): - return self._hc_forward(streams, infer_state, lw, self.context_attention_forward) - def token_forward(self, streams, infer_state: DeepseekV4InferStateInfo, lw): - return self._hc_forward(streams, infer_state, lw, self.token_attention_forward) - - def _att_norm(self, x, infer_state: DeepseekV4InferStateInfo, lw): - return lw.attn_norm_(x, eps=self.eps_) - - def _ffn_norm(self, x, infer_state: DeepseekV4InferStateInfo, lw): - return lw.ffn_norm_(x, eps=self.eps_) + def _hc_ffn_out(self, x, residual, post_mix, res_mix): + """Mid layers leave the ffn hc_post pending for the next layer's fused post+pre; the last + layer completes it and hands the flat streams [T, hc_mult*hidden] back to the model loop.""" + if not self.is_last_layer: + return x, residual, post_mix, res_mix + streams = hc_post(x, residual, post_mix, res_mix) + return streams.reshape(streams.shape[0], -1) + + def context_forward(self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight): + x, residual, post_mix, res_mix = self._hc_attn_in(input_embdings, layer_weight) + x = self.context_attention_forward(x, infer_state, layer_weight) + x, residual, post_mix, res_mix = 
self._hc_ffn_in(x, residual, post_mix, res_mix, layer_weight) + x = self._ffn(x, infer_state, layer_weight) + return self._hc_ffn_out(x, residual, post_mix, res_mix) + + def token_forward(self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight): + x, residual, post_mix, res_mix = self._hc_attn_in(input_embdings, layer_weight) + x = self.token_attention_forward(x, infer_state, layer_weight) + x, residual, post_mix, res_mix = self._hc_ffn_in(x, residual, post_mix, res_mix, layer_weight) + x = self._ffn(x, infer_state, layer_weight) + return self._hc_ffn_out(x, residual, post_mix, res_mix) # ------------------------------------------------------------------ shared projections / cache def _select_rope(self, infer_state: DeepseekV4InferStateInfo): @@ -108,20 +144,18 @@ def _select_rope(self, infer_state: DeepseekV4InferStateInfo): return infer_state.position_cos_compress, infer_state.position_sin_compress return infer_state.position_cos_sliding, infer_state.position_sin_sliding - def _get_qkv(self, x, infer_state: DeepseekV4InferStateInfo, lw): + def _get_qkv(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + from sglang.jit_kernel.dsv4 import fused_q_norm_rope + cos_tok, sin_tok = self._select_rope(infer_state) T = x.shape[0] - qa = lw.q_norm_(lw.wq_a_.mm(x), eps=self.eps_) - q = lw.wq_b_.mm(qa).view(T, self.tp_q_heads, self.head_dim).float() - q = (q * torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps_)).to(x.dtype) - q = torch.cat( - [ - q[..., : -self.rope_dim], - apply_rotary_emb(q[..., -self.rope_dim :], cos_tok.unsqueeze(1), sin_tok.unsqueeze(1)), - ], - dim=-1, - ) - kv = lw.kv_norm_(lw.wkv_.mm(x), eps=self.eps_) + qa = layer_weight.q_norm_(layer_weight.wq_a_.mm(x), eps=self.eps_) + q_in = layer_weight.wq_b_.mm(qa).view(T, self.tp_q_heads, self.head_dim) + # per-(token, head) weightless self-RMSNorm + interleaved rope on the last rope_dim dims, + # fused in one sglang dsv4 jit kernel (fp32 norm/rotation, bf16 in between -- 
same as eager). + q = torch.empty_like(q_in) + fused_q_norm_rope(q_in, q, self.eps_, self.freqs_cis, infer_state.position_ids) + kv = layer_weight.kv_norm_(layer_weight.wkv_.mm(x), eps=self.eps_) kv = torch.cat( [ kv[:, : -self.rope_dim], @@ -131,12 +165,12 @@ def _get_qkv(self, x, infer_state: DeepseekV4InferStateInfo, lw): ) return q, kv, qa, cos_tok, sin_tok - def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, lw): + def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, layer_weight): # o: [T, tp_q_heads, head_dim] after inverse rope -> grouped low-rank O -> [T, hidden] T = o.shape[0] o = o.reshape(T, self.tp_groups, -1).transpose(0, 1).contiguous() # [groups, T, per_group_in] - o = lw.wo_a_.bmm(o).transpose(0, 1).reshape(T, -1) # [T, groups*o_lora] - o = lw.wo_b_.mm(o) + o = layer_weight.wo_a_.bmm(o).transpose(0, 1).reshape(T, -1) # [T, groups*o_lora] + o = layer_weight.wo_b_.mm(o) if self.tp_world_size_ > 1: all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return o @@ -155,813 +189,400 @@ def _inv_rope(self, o, cos_tok, sin_tok): dim=-1, ) - def _post_cache_kv( - self, cache_kv, infer_state: DeepseekV4InferStateInfo, lw, req_idx=None, start_pos=None, mem_index=None - ): - if req_idx is None or start_pos is None or mem_index is None: - raise RuntimeError("DeepSeek-V4 cache write requires req_idx, start_pos, and mem_index") - positions = torch.arange( - start_pos, - start_pos + cache_kv.shape[0], - device=mem_index.device, - dtype=torch.long, - ) - infer_state.mem_manager.pack_mla_kv_to_cache( - layer_index=self.layer_num_, - mem_index=mem_index, - kv=cache_kv.reshape(cache_kv.shape[0], 1, cache_kv.shape[-1]), - req_idx=req_idx, - positions=positions, + # ------------------------------------------------------------------ compressor / indexer + def _indexer_q_weight(self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight): + if self.compress_ratio != 4: + return None, None + cos_tok = 
infer_state.position_cos_compress + sin_tok = infer_state.position_sin_compress + idx_q = layer_weight.idx_wq_b_.mm(q_lora).view(x.shape[0], self.tp_index_heads, self.index_head_dim) + idx_q = torch.cat( + [ + idx_q[..., : -self.rope_dim], + apply_rotary_emb(idx_q[..., -self.rope_dim :], cos_tok.unsqueeze(1), sin_tok.unsqueeze(1)), + ], + dim=-1, ) - return + idx_weight = layer_weight.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale + return idx_q, idx_weight - def _get_compressor_state(self, infer_state: DeepseekV4InferStateInfo, req): - cstate_kv, cstate_score = infer_state.req_manager.get_compress_state_for_req(self.layer_num_, req) - state = { - "cstate_kv": cstate_kv, - "cstate_score": cstate_score, - } - if self.compress_ratio == 4: - idx_state = infer_state.req_manager.get_c4_indexer_compress_state(self.layer_num_) - state["idx_cstate_kv"] = idx_state[req, 0] - state["idx_cstate_score"] = idx_state[req, 1] - return state + def _gather_compress_slots(self, infer_state: DeepseekV4InferStateInfo, req, entry_start, entry_count): + """组末 token 的 full 槽位 -> 压缩槽(条目 [entry_start, entry_start+entry_count))。 + 槽位已由 prep 阶段(prepare_*_compress_slots)分配并 scatter 进 full_to_c4/c128_indexs。""" + ratio = self.compress_ratio + mem = infer_state.mem_manager + mapping = mem.full_to_c4_indexs if ratio == 4 else mem.full_to_c128_indexs + last = entry_start + entry_count + ends = infer_state.req_manager.req_to_token_indexs[req, ratio - 1 : last * ratio : ratio][entry_start:] + return mapping[ends.long()] def _write_compressed_kv(self, infer_state: DeepseekV4InferStateInfo, req, entry_start, comp): - slots = infer_state.req_manager.ensure_compress_slots(self.layer_num_, req, entry_start, comp.shape[0]) - if comp.shape[0] == 0: - return slots - infer_state.mem_manager.pack_compressed_kv_to_cache(self.layer_num_, slots, comp) + slots = self._gather_compress_slots(infer_state, req, entry_start, comp.shape[0]) + if comp.shape[0]: + 
infer_state.mem_manager.pack_compressed_kv_to_cache(self.layer_num_, slots, comp) return slots - def _write_c4_indexer_k(self, infer_state: DeepseekV4InferStateInfo, slots, idx_comp): - if idx_comp is None or idx_comp.shape[0] == 0: - return - infer_state.mem_manager.pack_c4_indexer_k_to_cache(self.layer_num_, slots, idx_comp) - return + def _compressor_weights(self, layer_weight, for_indexer: bool): + if for_indexer: + return ( + layer_weight.idx_cmp_wkv_.mm_param.weight, + layer_weight.idx_cmp_wgate_.mm_param.weight, + layer_weight.idx_cmp_norm_.weight, + layer_weight.idx_cmp_ape_.weight, + self.index_head_dim, + ) + return ( + layer_weight.compressor_wkv_.mm_param.weight, + layer_weight.compressor_wgate_.mm_param.weight, + layer_weight.compressor_norm_.weight, + layer_weight.compressor_ape_.weight, + self.head_dim, + ) - def _dense_kv_from_cache(self, infer_state: DeepseekV4InferStateInfo, req, start_pos, end_pos): - if end_pos <= start_pos: - return torch.empty((0, self.head_dim), dtype=infer_state.mem_manager.dtype, device="cuda") - slots = infer_state.req_manager.req_to_token_indexs[req, start_pos:end_pos].long() - return infer_state.mem_manager.gather_mla_kv(self.layer_num_, slots) + def _run_compressor_prefill(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + """Per-request compressor for the prefill chunk. Runs as part of the deferred attention + func, before the attention metadata gathers the slot mappings. - def _compressed_kv_from_cache(self, infer_state: DeepseekV4InferStateInfo, req, ncomp): - if ncomp == 0: - return torch.empty((0, self.head_dim), dtype=infer_state.mem_manager.dtype, device="cuda") + c4: paged state (swa-page-derived group slots, translation #3) — one fused extend-aware + call per request; the (write_loc, extra_data, plan) tuple is layer-independent and cached + on infer_state across all c4 layers. 
c128: req-keyed state (zero at every 128 boundary by + construction, nothing cache-resident), original jit paths.""" + if not self.compress_ratio: + return if self.compress_ratio == 4: - slots = infer_state.req_manager.req_to_c4_indexs[req, :ncomp].long() + self._run_c4_compressor_prefill(x, infer_state, layer_weight) else: - slots = infer_state.req_manager.req_to_c128_indexs[req, :ncomp].long() - return infer_state.mem_manager.gather_compressed_kv(self.layer_num_, slots) - - def _c4_indexer_k_from_cache(self, infer_state: DeepseekV4InferStateInfo, req, ncomp): - if self.compress_ratio != 4 or ncomp == 0: - return None - slots = infer_state.req_manager.req_to_c4_indexs[req, :ncomp].long() - return infer_state.mem_manager.gather_c4_indexer_k(self.layer_num_, slots) - - def _run_sparse_attention_batch(self, q_chunks, kv_chunks, index_chunks, sink): - q_flat = torch.cat(q_chunks, dim=0) - kv_flat = torch.cat(kv_chunks, dim=0) - max_topk = max(t.shape[-1] for t in index_chunks) - topk = torch.full( - (q_flat.shape[0], max_topk), - -1, - dtype=torch.int32, - device=q_flat.device, - ) - offset = 0 - for idx in index_chunks: - rows = idx.shape[0] - topk[offset : offset + rows, : idx.shape[1]] = idx.to(torch.int32) - offset += rows - return vllm_sparse_attn_flat(q_flat, kv_flat, sink, topk, self.softmax_scale) - - # ------------------------------------------------------------------ attention (prefill) - def context_attention_forward(self, x, infer_state: DeepseekV4InferStateInfo, lw): - q, cache_kv, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, lw) - o = self._context_attention_wrapper_run(q, cache_kv, q_lora, x, infer_state, lw) - return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw) - - def _context_attention_wrapper_run(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, lw): - if torch.cuda.is_current_stream_capturing(): - q = q.contiguous() - cache_kv = cache_kv.contiguous() - q_lora = q_lora.contiguous() - x = 
x.contiguous() - _q = tensor_to_no_ref_tensor(q) - _cache_kv = tensor_to_no_ref_tensor(cache_kv) - _q_lora = tensor_to_no_ref_tensor(q_lora) - _x = tensor_to_no_ref_tensor(x) - - pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph() - pre_capture_graph.__exit__(None, None, None) - - infer_state.prefill_cuda_graph_create_graph_obj() - infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__() - o = torch.empty((q.shape[0], self.tp_q_heads, self.head_dim), dtype=q.dtype, device=q.device) - _o = tensor_to_no_ref_tensor(o) - - def att_func(new_infer_state: DeepseekV4InferStateInfo): - tmp_o = self._context_attention_kernel(_q, _cache_kv, _q_lora, _x, new_infer_state, lw) - assert tmp_o.shape == _o.shape - _o.copy_(tmp_o) - return - - infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=att_func, after_graph=pre_capture_graph) - return o - - return self._context_attention_kernel(q, cache_kv, q_lora, x, infer_state, lw) + self._run_c128_compressor_prefill(x, infer_state, layer_weight) + return - def _context_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, lw): - T = x.shape[0] - sink = lw.attn_sink_.weight - o = x.new_empty(T, self.tp_q_heads, self.head_dim) + def _run_c4_compressor_prefill(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + rm = infer_state.req_manager + mem = infer_state.mem_manager + wkv, wgate, norm, ape, _ = self._compressor_weights(layer_weight, for_indexer=False) + iwkv, iwgate, inorm, iape, _ = self._compressor_weights(layer_weight, for_indexer=True) + state_buf = mem.get_c4_state_buffer(self.layer_num_) + idx_state_buf = mem.get_c4_indexer_state_buffer(self.layer_num_) + data_cache = getattr(infer_state, "_dsv4_c4_prefill_data", None) + if data_cache is None: + data_cache = {} + infer_state._dsv4_c4_prefill_data = data_cache b_req = infer_state.b_req_idx.tolist() starts = infer_state.b_q_start_loc.tolist() lens = infer_state.b_q_seq_len.tolist() 
ready_lens = infer_state.b_ready_cache_len.tolist() - idx_q, idx_weight = self._indexer_q_weight( - x, - q_lora, - infer_state.position_cos_compress, - infer_state.position_sin_compress, - lw, - ) - q_chunks = [] - kv_chunks = [] - index_chunks = [] - out_ranges = [] - kv_offset = 0 - hold_req = infer_state.req_manager.HOLD_REQUEST_ID for req, st, ln, ready_len in zip(b_req, starts, lens, ready_lens): - if req == hold_req: - o[st : st + ln].zero_() + if req == rm.HOLD_REQUEST_ID or ln == 0: continue - q_r = q[st : st + ln] - cache_kv_r = cache_kv[st : st + ln] + seq_len = ready_len + ln + data = data_cache.get(req) + if data is None: + data = paged_prefill_compress_data( + rm.req_to_token_indexs, mem.full_to_swa_indexs, req, ready_len, seq_len, ring=8 + ) + data_cache[req] = data x_r = x[st : st + ln] - idx_q_r = None if idx_q is None else idx_q[st : st + ln] - idx_weight_r = None if idx_weight is None else idx_weight[st : st + ln] - kv_all, dense_base, n_window, ncomp, idx_comp = self._gather_prefill( - x_r, cache_kv_r, req, ready_len, lw, infer_state - ) - ti = self._topk_idxs_prefill( - ln, - dense_base, - n_window, - ncomp, - x.device, - ready_len, - idx_q_r, - idx_comp, - idx_weight_r, - infer_state, - )[0] - ti = torch.where(ti >= 0, ti + kv_offset, ti).to(torch.int32) - q_chunks.append(q_r) - kv_chunks.append(kv_all) - index_chunks.append(ti) - out_ranges.append((st, ln)) - kv_offset += kv_all.shape[0] - self._post_cache_kv( - cache_kv_r, - infer_state, - lw, - req_idx=req, - start_pos=ready_len, - mem_index=infer_state.mem_index[st : st + ln], - ) - if q_chunks: - attn_out = self._run_sparse_attention_batch(q_chunks, kv_chunks, index_chunks, sink) - out_offset = 0 - for st, ln in out_ranges: - o[st : st + ln] = attn_out[out_offset : out_offset + ln] - out_offset += ln - return o - - def _gather_prefill(self, x_r, kv_r, req, ready_len, lw, infer_state: DeepseekV4InferStateInfo): - ln = kv_r.shape[0] - idx_comp = None - if ready_len > 0: - return 
self._gather_prefill_extend(x_r, kv_r, req, ready_len, lw, infer_state) - if self.compress_ratio: - cstate_pool = infer_state.req_manager.get_compress_state_pool_for_req(self.layer_num_, req) - comp, ks, ss, cstate_pool = compressor_prefill_state( + comp = compressor_paged_prefill( x_r, - lw.compressor_wkv_.mm_param.weight, - lw.compressor_wgate_.mm_param.weight, - lw.compressor_norm_.weight, - lw.compressor_ape_.weight, - self.compress_ratio, + wkv, + wgate, + norm, + ape, self.head_dim, - self.rope_dim, - infer_state.cos_compress_table, - infer_state.sin_compress_table, + self.cos_compress_table, + self.sin_compress_table, self.eps_, - return_state_pool=True, - state_pool=cstate_pool, + state_buf, + data, + ready_len, + seq_len, ) - comp_slots = self._write_compressed_kv(infer_state, req, 0, comp) - cstate_kv, cstate_score = infer_state.req_manager.get_compress_state_for_req(self.layer_num_, req) - cstate_kv.copy_(ks) - cstate_score.copy_(ss) - if self.compress_ratio == 4: - idx_cstate_pool = infer_state.req_manager.get_c4_indexer_state_pool_for_req(self.layer_num_, req) - idx_comp, idx_ks, idx_ss, idx_cstate_pool = compressor_prefill_state( - x_r, - lw.idx_cmp_wkv_.mm_param.weight, - lw.idx_cmp_wgate_.mm_param.weight, - lw.idx_cmp_norm_.weight, - lw.idx_cmp_ape_.weight, - 4, - self.index_head_dim, - self.rope_dim, - infer_state.cos_compress_table, - infer_state.sin_compress_table, - self.eps_, - return_state_pool=True, - state_pool=idx_cstate_pool, - ) - self._write_c4_indexer_k(infer_state, comp_slots, idx_comp) - idx_state = infer_state.req_manager.get_c4_indexer_compress_state(self.layer_num_) - idx_cstate_kv = idx_state[req, 0] - idx_cstate_score = idx_state[req, 1] - idx_cstate_kv.copy_(idx_ks) - idx_cstate_score.copy_(idx_ss) - ncomp = comp.shape[0] - comp = self._compressed_kv_from_cache(infer_state, req, ncomp) - idx_comp = self._c4_indexer_k_from_cache(infer_state, req, ncomp) - return torch.cat([kv_r, comp], dim=0), 0, ln, ncomp, idx_comp - return 
kv_r, 0, ln, 0, None - - def _gather_prefill_extend(self, x_r, kv_r, req, ready_len, lw, infer_state: DeepseekV4InferStateInfo): - if self.compress_ratio: - state = self._get_compressor_state(infer_state, req) - cstate_pool = infer_state.req_manager.get_compress_state_pool_for_req(self.layer_num_, req) - idx_cstate_pool = ( - infer_state.req_manager.get_c4_indexer_state_pool_for_req(self.layer_num_, req) - if self.compress_ratio == 4 - else None + slots = self._write_compressed_kv(infer_state, req, ready_len // 4, comp) + idx_comp = compressor_paged_prefill( + x_r, + iwkv, + iwgate, + inorm, + iape, + self.index_head_dim, + self.cos_compress_table, + self.sin_compress_table, + self.eps_, + idx_state_buf, + data, + ready_len, + seq_len, ) + if idx_comp.shape[0]: + infer_state.mem_manager.pack_indexer_k_to_cache(self.layer_num_, slots, idx_comp) + return - for j in range(x_r.shape[0]): - start_pos = ready_len + j - entry = compressor_decode_step( - x_r[j], - lw.compressor_wkv_.mm_param.weight, - lw.compressor_wgate_.mm_param.weight, - lw.compressor_norm_.weight, - lw.compressor_ape_.weight, + def _run_c128_compressor_prefill(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + rm = infer_state.req_manager + wkv, wgate, norm, ape, _ = self._compressor_weights(layer_weight, for_indexer=False) + b_req = infer_state.b_req_idx.tolist() + starts = infer_state.b_q_start_loc.tolist() + lens = infer_state.b_q_seq_len.tolist() + ready_lens = infer_state.b_ready_cache_len.tolist() + for req, st, ln, ready_len in zip(b_req, starts, lens, ready_lens): + if req == rm.HOLD_REQUEST_ID: + continue + x_r = x[st : st + ln] + state_pool = rm.get_compress_state_pool_for_req(self.layer_num_, req) + if ready_len == 0: + comp = compressor_prefill_state( + x_r, + wkv, + wgate, + norm, + ape, self.compress_ratio, self.head_dim, - self.rope_dim, - infer_state.cos_compress_table, - infer_state.sin_compress_table, + self.cos_compress_table, + self.sin_compress_table, self.eps_, - 
state["cstate_kv"], - state["cstate_score"], - start_pos, - state_pool=cstate_pool, + state_pool, ) - if entry is not None: - entry_start = (start_pos + 1) // self.compress_ratio - 1 - slots = self._write_compressed_kv(infer_state, req, entry_start, entry.unsqueeze(0)) - if self.compress_ratio == 4: - idx_entry = compressor_decode_step( + self._write_compressed_kv(infer_state, req, 0, comp) + else: + for j in range(ln): + start_pos = ready_len + j + entry = compressor_decode_step_single( x_r[j], - lw.idx_cmp_wkv_.mm_param.weight, - lw.idx_cmp_wgate_.mm_param.weight, - lw.idx_cmp_norm_.weight, - lw.idx_cmp_ape_.weight, - 4, - self.index_head_dim, - self.rope_dim, - infer_state.cos_compress_table, - infer_state.sin_compress_table, + wkv, + wgate, + norm, + ape, + self.compress_ratio, + self.head_dim, + self.cos_compress_table, + self.sin_compress_table, self.eps_, - state["idx_cstate_kv"], - state["idx_cstate_score"], + state_pool, start_pos, - state_pool=idx_cstate_pool, ) - if idx_entry is not None: - if entry is None: - entry_start = (start_pos + 1) // self.compress_ratio - 1 - slots = infer_state.req_manager.ensure_compress_slots(self.layer_num_, req, entry_start, 1) - self._write_c4_indexer_k(infer_state, slots, idx_entry.unsqueeze(0)) - dense_end = ready_len + x_r.shape[0] - ncomp = dense_end // self.compress_ratio - dense_base = max(0, ready_len - self.window + 1) - cached_dense = self._dense_kv_from_cache(infer_state, req, dense_base, ready_len) - dense = torch.cat([cached_dense, kv_r], dim=0) - comp = self._compressed_kv_from_cache(infer_state, req, ncomp) - idx_comp = self._c4_indexer_k_from_cache(infer_state, req, ncomp) - return ( - torch.cat([dense, comp], dim=0), - dense_base, - dense.shape[0], - ncomp, - idx_comp, - ) - dense_base = max(0, ready_len - self.window + 1) - cached_dense = self._dense_kv_from_cache(infer_state, req, dense_base, ready_len) - dense = torch.cat([cached_dense, kv_r], dim=0) - return ( - dense, - dense_base, - dense.shape[0], - 
0, - None, - ) + if entry is not None: + entry_start = (start_pos + 1) // self.compress_ratio - 1 + self._write_compressed_kv(infer_state, req, entry_start, entry.unsqueeze(0)) + return - def _topk_idxs_prefill( - self, - seqlen, - dense_base, - n_window, - ncomp, - device, - base_pos, - idx_q, - idx_comp, - idx_weight, - infer_state: DeepseekV4InferStateInfo, - ): - t = torch.arange(seqlen, device=device) - abs_pos = t + base_pos - offsets = torch.arange(self.window, device=device) - win_abs = abs_pos.unsqueeze(1) - (self.window - 1 - offsets).unsqueeze(0) - valid = (win_abs >= dense_base) & (win_abs < dense_base + n_window) - win = torch.where(valid, win_abs - dense_base, torch.full_like(win_abs, -1)) - if ncomp: - if self.compress_ratio == 4 and ncomp > self.index_topk: - comp = self._indexer_topk(idx_q, idx_comp, idx_weight, abs_pos + 1, n_window, infer_state) - else: - c = torch.arange(ncomp, device=device) - comp = torch.where( - c.unsqueeze(0) < ((abs_pos.unsqueeze(1) + 1) // self.compress_ratio), - (c.unsqueeze(0) + n_window).expand(seqlen, ncomp), - torch.full((seqlen, ncomp), -1, device=device, dtype=torch.long), - ) - return torch.cat([win, comp], dim=1).int().unsqueeze(0) - return win.int().unsqueeze(0) - - def _decode_dense_kv_graph(self, infer_state: DeepseekV4InferStateInfo): - req = infer_state.b_req_idx.long() - seq = infer_state.b_seq_len.long() - B = req.shape[0] - device = infer_state.b_seq_len.device - offsets = torch.arange(self.window, device=device, dtype=torch.long) - win_len = torch.minimum(seq, torch.full_like(seq, self.window)) - start = seq - win_len - pos = start.unsqueeze(1) + offsets.unsqueeze(0) - valid = offsets.unsqueeze(0) < win_len.unsqueeze(1) - hold = infer_state.mem_manager.swa_pool.HOLD_TOKEN_MEMINDEX - safe_pos = torch.where(valid, pos, torch.zeros_like(pos)).long() - full_slots = infer_state.req_manager.req_to_token_indexs[req.unsqueeze(1), safe_pos].long() - swa_slots = 
infer_state.mem_manager.full_to_swa_indexs[full_slots].long() - slot_valid = valid & (swa_slots >= 0) - swa_slots = torch.where(slot_valid, swa_slots, torch.full_like(swa_slots, hold)) - kv = infer_state.mem_manager.gather_mla_kv_from_swa_slots(self.layer_num_, swa_slots.reshape(-1)) - return kv.view(B, self.window, self.head_dim), valid - - def _decode_all_compressed_kv_graph(self, infer_state: DeepseekV4InferStateInfo, ratio): - req = infer_state.b_req_idx.long() - seq = infer_state.b_seq_len.long() - B = req.shape[0] - device = infer_state.b_seq_len.device - max_comp = max(1, infer_state.max_kv_seq_len // ratio) - offsets = torch.arange(max_comp, device=device, dtype=torch.long) - ncomp = torch.div(seq, ratio, rounding_mode="floor") - valid = offsets.unsqueeze(0) < ncomp.unsqueeze(1) - safe_offsets = torch.where(valid, offsets.unsqueeze(0), torch.zeros_like(offsets).unsqueeze(0)) - if ratio == 4: - table = infer_state.req_manager.req_to_c4_indexs - hold = infer_state.mem_manager.c4_pool.HOLD_TOKEN_MEMINDEX - else: - table = infer_state.req_manager.req_to_c128_indexs - hold = infer_state.mem_manager.c128_pool.HOLD_TOKEN_MEMINDEX - slots = table[req.unsqueeze(1), safe_offsets].long() - slots = torch.where(valid, slots, torch.full_like(slots, hold)) - kv = infer_state.mem_manager.gather_compressed_kv(self.layer_num_, slots.reshape(-1)) - kv = kv.view(B, max_comp, self.head_dim) - if ratio != 4: - return kv, None, valid, ncomp - idx_k = infer_state.mem_manager.gather_c4_indexer_k(self.layer_num_, slots.reshape(-1)) - idx_k = idx_k.view(B, max_comp, self.index_head_dim) - return kv, idx_k, valid, ncomp - - def _decode_c4_topk_graph( - self, idx_q, idx_weight, idx_comp, valid_comp, ncomp, infer_state: DeepseekV4InferStateInfo - ): - scores = torch.einsum("bhd,bnd->bhn", idx_q.float(), idx_comp.float()) - scores = F.relu(scores) * self.indexer_score_scale - index_scores = (scores * idx_weight.unsqueeze(-1)).sum(dim=1) - if self.tp_world_size_ > 1: - 
all_reduce(index_scores, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - index_scores = index_scores.masked_fill(~valid_comp, float("-inf")) - top = index_scores.topk(self.index_topk, dim=-1).indices - valid = top < ncomp.unsqueeze(1) - return torch.where(valid, top, torch.zeros_like(top)), valid + def _run_compressor_decode(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + """Batched decode compressor (cuda-graph safe): state update for every request, cache write + masked to the pool HOLD slot unless this token completes a window. Compressed-cache slots + were pre-allocated by prepare_decode_compress_slots in the prep phase. - def _decode_compressed_candidates_graph(self, idx_q, idx_weight, infer_state: DeepseekV4InferStateInfo): - if self.compress_ratio == 4: - _, idx_comp, valid_all, ncomp = self._decode_all_compressed_kv_graph(infer_state, 4) - top, valid = self._decode_c4_topk_graph(idx_q, idx_weight, idx_comp, valid_all, ncomp, infer_state) - req = infer_state.b_req_idx.long() - slots = infer_state.req_manager.req_to_c4_indexs[req.unsqueeze(1), top].long() - hold = infer_state.mem_manager.c4_pool.HOLD_TOKEN_MEMINDEX - slots = torch.where(valid, slots, torch.full_like(slots, hold)) - comp = infer_state.mem_manager.gather_compressed_kv(self.layer_num_, slots.reshape(-1)) - return comp.view(req.shape[0], self.index_topk, self.head_dim), valid - comp, _, valid, _ = self._decode_all_compressed_kv_graph(infer_state, 128) - return comp, valid - - def _write_decode_compressed_entry_graph(self, x, infer_state: DeepseekV4InferStateInfo, lw, ratio): + c4: paged state — group slots derived from full_to_swa (translation #3) via pure tensor + ops (graph-safe), shared across all c4 layers per step. 
c128: req-keyed state.""" + if not self.compress_ratio: + return + rm = infer_state.req_manager + mem = infer_state.mem_manager req = infer_state.b_req_idx - start_pos = infer_state.b_seq_len.long() - 1 + ratio = self.compress_ratio + wkv, wgate, norm, ape, _ = self._compressor_weights(layer_weight, for_indexer=False) + if ratio == 4: - state_all = infer_state.req_manager.get_c4_compress_state(self.layer_num_) - table = infer_state.req_manager.req_to_c4_indexs - hold = infer_state.mem_manager.c4_pool.HOLD_TOKEN_MEMINDEX + mapping, hold = mem.full_to_c4_indexs, mem.c4_pool.HOLD_TOKEN_MEMINDEX + slot_meta = getattr(infer_state, "_dsv4_c4_decode_slots", None) + if slot_meta is None: + slot_meta = paged_decode_state_slots( + rm.req_to_token_indexs, + mem.full_to_swa_indexs, + req, + infer_state.b_seq_len, + page_size=128, + ring=8, + ratio=4, + hold_req_id=rm.HOLD_REQUEST_ID, + num_swa_pages=mem.swa_num_pages, + ) + infer_state._dsv4_c4_decode_slots = slot_meta + write_slot, overlap_slot = slot_meta + entry, should = compressor_paged_decode_batch( + x, + wkv, + wgate, + norm, + ape, + self.head_dim, + self.cos_compress_table, + self.sin_compress_table, + self.eps_, + mem.get_c4_state_buffer(self.layer_num_), + write_slot, + overlap_slot, + infer_state.b_seq_len, + ) else: - state_all = infer_state.req_manager.get_c128_compress_state(self.layer_num_) - table = infer_state.req_manager.req_to_c128_indexs - hold = infer_state.mem_manager.c128_pool.HOLD_TOKEN_MEMINDEX + mapping, hold = mem.full_to_c128_indexs, mem.c128_pool.HOLD_TOKEN_MEMINDEX + entry, should = compressor_decode_step_batch( + x, + wkv, + wgate, + norm, + ape, + ratio, + self.head_dim, + self.rope_dim, + self.cos_compress_table, + self.sin_compress_table, + self.eps_, + rm.get_compress_state_pool(self.layer_num_), + req, + infer_state.b_seq_len.long() - 1, + ) - entry, should = compressor_decode_step_batch( - x, - lw.compressor_wkv_.mm_param.weight, - lw.compressor_wgate_.mm_param.weight, - 
lw.compressor_norm_.weight, - lw.compressor_ape_.weight, - ratio, - self.head_dim, - self.rope_dim, - infer_state.cos_compress_table, - infer_state.sin_compress_table, - self.eps_, - state_all, - req, - start_pos, - ) - entry_idx = torch.clamp(torch.div(infer_state.b_seq_len.long(), ratio, rounding_mode="floor") - 1, min=0) - slots = table[req.long(), entry_idx].long() + should = should & (req != rm.HOLD_REQUEST_ID) + # 本步 token 即组末 token(should 为真时),其 full 槽 = mem_index,映射在 prep 已 scatter。 + slots = mapping[infer_state.mem_index.long()].long() slots = torch.where(should, slots, torch.full_like(slots, hold)) - infer_state.mem_manager.pack_compressed_kv_to_cache(self.layer_num_, slots, entry) + mem.pack_compressed_kv_to_cache(self.layer_num_, slots, entry) if ratio == 4: - idx_state_all = infer_state.req_manager.get_c4_indexer_compress_state(self.layer_num_) - idx_entry, idx_should = compressor_decode_step_batch( + iwkv, iwgate, inorm, iape, _ = self._compressor_weights(layer_weight, for_indexer=True) + idx_entry, idx_should = compressor_paged_decode_batch( x, - lw.idx_cmp_wkv_.mm_param.weight, - lw.idx_cmp_wgate_.mm_param.weight, - lw.idx_cmp_norm_.weight, - lw.idx_cmp_ape_.weight, - 4, + iwkv, + iwgate, + inorm, + iape, self.index_head_dim, - self.rope_dim, - infer_state.cos_compress_table, - infer_state.sin_compress_table, + self.cos_compress_table, + self.sin_compress_table, self.eps_, - idx_state_all, - req, - start_pos, + mem.get_c4_indexer_state_buffer(self.layer_num_), + write_slot, + overlap_slot, + infer_state.b_seq_len, ) + idx_should = idx_should & (req != rm.HOLD_REQUEST_ID) idx_slots = torch.where(idx_should, slots, torch.full_like(slots, hold)) - infer_state.mem_manager.pack_c4_indexer_k_to_cache(self.layer_num_, idx_slots, idx_entry) + mem.pack_indexer_k_to_cache(self.layer_num_, idx_slots, idx_entry) return - # ------------------------------------------------------------------ attention (decode) - def token_attention_forward(self, x, infer_state: 
DeepseekV4InferStateInfo, lw): - q, cache_kv, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, lw) - if infer_state.is_cuda_graph: - o = self._token_attention_kernel_cuda_graph(q, cache_kv, q_lora, x, infer_state, lw) - else: - o = self._token_attention_kernel(q, cache_kv, q_lora, x, infer_state, lw) - return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, lw) - - def _token_attention_kernel_cuda_graph(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, lw): - sink = lw.attn_sink_.weight - infer_state.mem_manager.pack_decode_mla_kv_to_cache( - self.layer_num_, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.mem_index, - cache_kv.reshape(cache_kv.shape[0], 1, cache_kv.shape[-1]), - ) - idx_q, idx_weight = self._indexer_q_weight( - x, - q_lora, - infer_state.position_cos_compress, - infer_state.position_sin_compress, - lw, - ) - if self.compress_ratio: - self._write_decode_compressed_entry_graph(x, infer_state, lw, self.compress_ratio) + # ------------------------------------------------------------------ attention (prefill) + def context_attention_forward(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + q, cache_kv, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, layer_weight) + # template hook: write the chunk's packed latent into the swa pool before attention + # reads it back via full_to_swa indices (this custom forward bypasses the tpl path). 
+ self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._context_attention_wrapper_run(q, cache_kv, q_lora, x, infer_state, layer_weight) + return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, layer_weight) + + def _context_attention_wrapper_run( + self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight + ): + if torch.cuda.is_current_stream_capturing(): + q = q.contiguous() + cache_kv = cache_kv.contiguous() + q_lora = q_lora.contiguous() + x = x.contiguous() + _q = tensor_to_no_ref_tensor(q) + _cache_kv = tensor_to_no_ref_tensor(cache_kv) + _q_lora = tensor_to_no_ref_tensor(q_lora) + _x = tensor_to_no_ref_tensor(x) - dense_kv, dense_valid = self._decode_dense_kv_graph(infer_state) - B = q.shape[0] - device = q.device - if self.compress_ratio: - comp_kv, comp_valid = self._decode_compressed_candidates_graph(idx_q, idx_weight, infer_state) - kv_all = torch.cat([dense_kv, comp_kv], dim=1) - comp_offsets = torch.arange(comp_kv.shape[1], device=device, dtype=torch.int32) - else: - kv_all = dense_kv - comp_valid = None - comp_offsets = None - - total_k = kv_all.shape[1] - base = torch.arange(B, device=device, dtype=torch.int32).unsqueeze(1) * total_k - dense_offsets = torch.arange(self.window, device=device, dtype=torch.int32) - dense_topk = torch.where( - dense_valid, - base + dense_offsets.unsqueeze(0), - torch.full((B, self.window), -1, device=device, dtype=torch.int32), - ) - if self.compress_ratio: - comp_topk = torch.where( - comp_valid, - base + self.window + comp_offsets.unsqueeze(0), - torch.full((B, comp_kv.shape[1]), -1, device=device, dtype=torch.int32), - ) - topk = torch.cat([dense_topk, comp_topk], dim=1) - else: - topk = dense_topk - return vllm_sparse_attn_flat( - q, - kv_all.reshape(-1, self.head_dim), - sink, - topk, - self.softmax_scale, - already_compact=True, - ) + pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph() + pre_capture_graph.__exit__(None, None, None) - 
def _token_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, lw): - B = x.shape[0] # one new token per request - idx_q, idx_weight = self._indexer_q_weight( - x, - q_lora, - infer_state.position_cos_compress, - infer_state.position_sin_compress, - lw, - ) - sink = lw.attn_sink_.weight - b_req = infer_state.b_req_idx.tolist() - seqlens = infer_state.b_seq_len.tolist() - o = x.new_empty(B, self.tp_q_heads, self.head_dim) - hold_req = infer_state.req_manager.HOLD_REQUEST_ID - q_chunks = [] - kv_chunks = [] - index_chunks = [] - out_rows = [] - kv_offset = 0 - for i, (req, seq) in enumerate(zip(b_req, seqlens)): - if req == hold_req: - o[i].zero_() - continue - start_pos = seq - 1 - self._post_cache_kv( - cache_kv[i : i + 1], - infer_state, - lw, - req_idx=req, - start_pos=start_pos, - mem_index=infer_state.mem_index[i : i + 1], - ) - if self.compress_ratio: - stt = self._get_compressor_state(infer_state, req) - cstate_pool = infer_state.req_manager.get_compress_state_pool_for_req(self.layer_num_, req) - e = compressor_decode_step( - x[i], - lw.compressor_wkv_.mm_param.weight, - lw.compressor_wgate_.mm_param.weight, - lw.compressor_norm_.weight, - lw.compressor_ape_.weight, - self.compress_ratio, - self.head_dim, - self.rope_dim, - infer_state.cos_compress_table, - infer_state.sin_compress_table, - self.eps_, - stt["cstate_kv"], - stt["cstate_score"], - start_pos, - state_pool=cstate_pool, - ) - entry_slots = None - if e is not None: - entry_start = (start_pos + 1) // self.compress_ratio - 1 - entry_slots = self._write_compressed_kv(infer_state, req, entry_start, e.unsqueeze(0)) - if self.compress_ratio == 4: - idx_cstate_pool = infer_state.req_manager.get_c4_indexer_state_pool_for_req(self.layer_num_, req) - idx_e = compressor_decode_step( - x[i], - lw.idx_cmp_wkv_.mm_param.weight, - lw.idx_cmp_wgate_.mm_param.weight, - lw.idx_cmp_norm_.weight, - lw.idx_cmp_ape_.weight, - 4, - self.index_head_dim, - self.rope_dim, - 
infer_state.cos_compress_table, - infer_state.sin_compress_table, - self.eps_, - stt["idx_cstate_kv"], - stt["idx_cstate_score"], - start_pos, - state_pool=idx_cstate_pool, - ) - if idx_e is not None: - if entry_slots is None: - entry_start = (start_pos + 1) // self.compress_ratio - 1 - entry_slots = infer_state.req_manager.ensure_compress_slots( - self.layer_num_, req, entry_start, 1 - ) - self._write_c4_indexer_k(infer_state, entry_slots, idx_e.unsqueeze(0)) - win_start = max(0, seq - self.window) - win_kv = self._dense_kv_from_cache(infer_state, req, win_start, seq) - comp_kv = self._compressed_kv_from_cache(infer_state, req, seq // self.compress_ratio) - idx_comp = self._c4_indexer_k_from_cache(infer_state, req, comp_kv.shape[0]) - kv_all = torch.cat([win_kv, comp_kv], dim=0) - else: - win_start = max(0, seq - self.window) - win_kv = self._dense_kv_from_cache(infer_state, req, win_start, seq) - kv_all = win_kv - comp_kv = None - idx_comp = None - ti = self._topk_idxs_decode( - win_kv.shape[0], - comp_kv, - None if idx_q is None else idx_q[i : i + 1], - idx_comp, - None if idx_weight is None else idx_weight[i : i + 1], - seq, - x.device, - infer_state, - )[0, 0] - ti = torch.where(ti >= 0, ti + kv_offset, ti).view(1, -1).to(torch.int32) - q_chunks.append(q[i : i + 1]) - kv_chunks.append(kv_all) - index_chunks.append(ti) - out_rows.append(i) - kv_offset += kv_all.shape[0] - if q_chunks: - attn_out = self._run_sparse_attention_batch(q_chunks, kv_chunks, index_chunks, sink) - for row, row_out in zip(out_rows, attn_out): - o[row] = row_out - return o + infer_state.prefill_cuda_graph_create_graph_obj() + infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__() + o = torch.empty((q.shape[0], self.tp_q_heads, self.head_dim), dtype=q.dtype, device=q.device) + _o = tensor_to_no_ref_tensor(o) - def _indexer_q_weight(self, x, qa, cos_tok, sin_tok, lw): - if self.compress_ratio != 4: - return None, None - idx_q = lw.idx_wq_b_.mm(qa).view(x.shape[0], 
self.tp_index_heads, self.index_head_dim) - idx_q = torch.cat( - [ - idx_q[..., : -self.rope_dim], - apply_rotary_emb( - idx_q[..., -self.rope_dim :], - cos_tok.unsqueeze(1), - sin_tok.unsqueeze(1), - ), - ], - dim=-1, - ) - idx_weight = lw.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale - return idx_q, idx_weight + def att_func(new_infer_state: DeepseekV4InferStateInfo): + tmp_o = self._context_attention_kernel(_q, _cache_kv, _q_lora, _x, new_infer_state, layer_weight) + assert tmp_o.shape == _o.shape + _o.copy_(tmp_o) + return - def _indexer_topk( - self, idx_q, idx_comp, idx_weight, positions_1based, offset, infer_state: DeepseekV4InferStateInfo - ): - ncomp = idx_comp.shape[0] - k = min(self.index_topk, ncomp) - if k == 0: - return torch.empty((idx_q.shape[0], 0), device=idx_q.device, dtype=torch.long) - - top_chunks = [] - heads = max(1, idx_q.shape[1]) - max_score_elems = 16 * 1024 * 1024 - chunk_size = max(1, min(idx_q.shape[0], max_score_elems // max(1, heads * ncomp))) - for start in range(0, idx_q.shape[0], chunk_size): - end = min(idx_q.shape[0], start + chunk_size) - scores = torch.einsum("thd,nd->thn", idx_q[start:end].float(), idx_comp.float()) - scores = F.relu(scores) * self.indexer_score_scale - index_scores = (scores * idx_weight[start:end].unsqueeze(-1)).sum(dim=1) - if self.tp_world_size_ > 1: - all_reduce( - index_scores, - op=dist.ReduceOp.SUM, - group=infer_state.dist_group, - async_op=False, - ) - causal_threshold = positions_1based[start:end] // 4 - top_chunks.append(self._indexer_topk_kernel(index_scores, causal_threshold, k)) - top = torch.cat(top_chunks, dim=0) - valid = top >= 0 - return torch.where(valid, top + offset, torch.full_like(top, -1)) - - def _indexer_topk_kernel(self, index_scores, causal_threshold, topk): - if index_scores.is_cuda: - try: - import vllm._C # noqa: F401 - - scores = index_scores.contiguous() - lengths = causal_threshold.to(torch.int32).contiguous() - starts = torch.zeros_like(lengths, 
dtype=torch.int32) - top = torch.empty((scores.shape[0], topk), dtype=torch.int32, device=scores.device) - torch.ops._C.top_k_per_row_prefill( - scores, - starts, - lengths, - top, - scores.shape[0], - scores.stride(0), - scores.stride(1), - topk, - ) - return top.long() - except Exception: - pass + infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=att_func, after_graph=pre_capture_graph) + return o - entry_indices = torch.arange(index_scores.shape[1], device=index_scores.device) - index_scores = index_scores.masked_fill( - entry_indices.unsqueeze(0) >= causal_threshold.unsqueeze(1), float("-inf") + return self._context_attention_kernel(q, cache_kv, q_lora, x, infer_state, layer_weight) + + def _context_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + self._run_compressor_prefill(x, infer_state, layer_weight) + idx_q, idx_weight = self._indexer_q_weight(x, q_lora, infer_state, layer_weight) + att_control = AttControl( + nsa_prefill=True, + nsa_prefill_dict={ + "flashmla_kvcache": True, + "layer_index": self.layer_num_, + "compress_ratio": self.compress_ratio, + "head_dim_v": self.head_dim, + "softmax_scale": self.softmax_scale, + "cache_kv": cache_kv, + "q_lora": q_lora, + "hidden_states": x, + "attn_sink": layer_weight.attn_sink_.weight, + "idx_q": idx_q, + "idx_weight": idx_weight, + "index_topk": self.index_topk, + "indexer_score_scale": self.indexer_score_scale, + "tp_world_size": self.tp_world_size_, + }, + ) + return infer_state.prefill_att_state.prefill_att( + q=q, + k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_), + v=None, + att_control=att_control, ) - top = index_scores.topk(topk, dim=-1).indices - valid = top < causal_threshold.unsqueeze(1) - return torch.where(valid, top, torch.full_like(top, -1)) - - def _topk_idxs_decode( - self, - win_len, - comp_kv, - idx_q, - idx_comp, - idx_weight, - seq_len, - device, - infer_state: DeepseekV4InferStateInfo, - ): - win = 
torch.arange(win_len, device=device, dtype=torch.long) - if comp_kv is None or comp_kv.shape[0] == 0: - return win.view(1, 1, -1).int() - ncomp = comp_kv.shape[0] - if self.compress_ratio == 4 and ncomp > self.index_topk: - comp = self._indexer_topk( - idx_q, - idx_comp, - idx_weight, - torch.tensor([seq_len], device=device, dtype=torch.long), - win_len, - infer_state, - )[0] - else: - comp = torch.arange(ncomp, device=device, dtype=torch.long) + win_len - return torch.cat([win, comp], dim=0).view(1, 1, -1).int() - # ------------------------------------------------------------------ moe - def _fp4_experts(self, x, weights, indices, lw): - experts = lw.experts_ - if getattr(experts, "moe_backend", None) != "marlin": - err = getattr(experts, "moe_backend_error", "unknown") - raise RuntimeError(f"DeepSeek-V4 FP4 MoE requires vLLM Marlin backend, init_error={err}") - return self._fp4_experts_marlin(x, weights, indices, experts) - - def _fp4_experts_marlin(self, x, weights, indices, experts): - from vllm.model_executor.layers.fused_moe.activation import MoEActivation - from vllm.model_executor.layers.fused_moe.experts.marlin_moe import ( - fused_marlin_moe, + # ------------------------------------------------------------------ attention (decode) + def token_attention_forward(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + q, cache_kv, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, layer_weight) + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._token_attention_kernel(q, cache_kv, q_lora, x, infer_state, layer_weight) + return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, layer_weight) + + def _token_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + self._run_compressor_decode(x, infer_state, layer_weight) + idx_q, idx_weight = self._indexer_q_weight(x, q_lora, infer_state, layer_weight) + att_control = AttControl( + nsa_decode=True, + nsa_decode_dict={ + 
"flashmla_kvcache": True, + "layer_index": self.layer_num_, + "compress_ratio": self.compress_ratio, + "head_dim_v": self.head_dim, + "softmax_scale": self.softmax_scale, + "cache_kv": cache_kv, + "q_lora": q_lora, + "hidden_states": x, + "attn_sink": layer_weight.attn_sink_.weight, + "idx_q": idx_q, + "idx_weight": idx_weight, + "index_topk": self.index_topk, + "indexer_score_scale": self.indexer_score_scale, + "tp_world_size": self.tp_world_size_, + }, ) - from vllm.scalar_type import scalar_types - - return fused_marlin_moe( - hidden_states=x.contiguous(), - w1=experts.marlin_w13, - w2=experts.marlin_w2, - bias1=None, - bias2=None, - w1_scale=experts.marlin_w13_scale, - w2_scale=experts.marlin_w2_scale, - topk_weights=weights.to(torch.float32).contiguous(), - topk_ids=indices.to(torch.long).contiguous(), - quant_type_id=scalar_types.float4_e2m1f.id, - global_num_experts=experts.n_routed_experts, - activation=MoEActivation.SILU, + return infer_state.decode_att_state.decode_att( + q=q, + k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_), + v=None, + att_control=att_control, + ) + + # ------------------------------------------------------------------ moe + def _routed_experts(self, x, weights, indices, layer_weight): + return layer_weight.experts_.experts_with_preselected( + input_tensor=x, + topk_weights=weights, + topk_ids=indices, clamp_limit=float(self.swiglu_limit), ) - def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, lw): - gw = lw.gate_weight_.mm_param.weight + def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + gw = layer_weight.gate_weight_.mm_param.weight logits = F.linear(x.float(), gw.float()).contiguous() - weights, indices = self._select_experts(logits, infer_state, lw) - routed = self._fp4_experts(x, weights, indices, lw) - g = lw.shared_gate_.mm(x).float().clamp(max=self.swiglu_limit) - u = lw.shared_up_.mm(x).float().clamp(min=-self.swiglu_limit, max=self.swiglu_limit) - shared = 
lw.shared_down_.mm((F.silu(g) * u).to(x.dtype)) - if self.enable_ep_moe and getattr(lw.experts_, "is_ep", False): + weights, indices = self._select_experts(logits, infer_state, layer_weight) + routed = self._routed_experts(x, weights, indices, layer_weight) + g = layer_weight.shared_gate_.mm(x).float().clamp(max=self.swiglu_limit) + u = layer_weight.shared_up_.mm(x).float().clamp(min=-self.swiglu_limit, max=self.swiglu_limit) + shared = layer_weight.shared_down_.mm((F.silu(g) * u).to(x.dtype)) + if self.enable_ep_moe and getattr(layer_weight.experts_, "is_ep", False): if self.tp_world_size_ > 1: all_reduce( shared, @@ -975,10 +596,10 @@ def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, lw): all_reduce(out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return out - def _select_experts(self, logits, infer_state: DeepseekV4InferStateInfo, lw): - return self._select_experts_vllm(logits, infer_state, lw) + def _select_experts(self, logits, infer_state: DeepseekV4InferStateInfo, layer_weight): + return self._select_experts_vllm(logits, infer_state, layer_weight) - def _select_experts_vllm(self, logits, infer_state: DeepseekV4InferStateInfo, lw): + def _select_experts_vllm(self, logits, infer_state: DeepseekV4InferStateInfo, layer_weight): from vllm import _custom_ops as ops M = logits.shape[0] @@ -987,13 +608,13 @@ def _select_experts_vllm(self, logits, infer_state: DeepseekV4InferStateInfo, lw hash_indices_table = None indices_dtype = torch.int64 if self.is_hash: - hash_indices_table = lw.gate_tid2eid_.weight + hash_indices_table = layer_weight.gate_tid2eid_.weight if not hash_indices_table.is_contiguous(): hash_indices_table = hash_indices_table.contiguous() indices_dtype = hash_indices_table.dtype input_tokens = infer_state.input_ids.to(dtype=indices_dtype).contiguous() else: - bias = lw.gate_bias_.weight + bias = layer_weight.gate_bias_.weight weights = torch.empty((M, self.topk), dtype=torch.float32, device=logits.device) indices = 
torch.empty((M, self.topk), dtype=indices_dtype, device=logits.device) diff --git a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py index cdaaac2cdb..a95299628c 100644 --- a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py @@ -1,5 +1,3 @@ -import threading - import torch from lightllm.common.basemodel import TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( @@ -9,202 +7,16 @@ RMSNormWeight, ParameterWeight, TpAttSinkWeight, + FusedMoeWeight, ) -from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl -from lightllm.common.quantization.registry import QUANTMETHODS -from lightllm.utils.log_utils import init_logger from ..triton_kernel.quant_convert import dequant_fp8_block_to_bf16 -logger = init_logger(__name__) - - -class DeepseekV4FP4ExpertsWeight(BaseWeightTpl): - _marlin_pack_lock = threading.Lock() - - def __init__(self, weight_prefix, n_routed_experts, hidden_size, moe_intermediate_size, data_type): - super().__init__(data_type=data_type) - self.weight_prefix = weight_prefix - self.n_routed_experts = n_routed_experts - self.hidden_size = hidden_size - self.moe_intermediate_size = moe_intermediate_size - self.split_inter_size = moe_intermediate_size // self.tp_world_size_ - self.local_expert_ids = list(range(n_routed_experts)) - self.expert_idx_to_local_idx = {expert_idx: expert_idx for expert_idx in self.local_expert_ids} - self.moe_backend = None - self.moe_backend_error = None - self._marlin_checked = False - self._load_lock = threading.Lock() - self.load_ok = { - name: [False] * n_routed_experts for name in ("w1", "w1_scale", "w2", "w2_scale", "w3", "w3_scale") - } - - def _create_weight(self): - self._ensure_raw_fp4_weight() - - def _ensure_raw_fp4_weight(self): - if hasattr(self, "w1"): - return 
- device = "cpu" - n = self.n_routed_experts - h = self.hidden_size - inter = self.split_inter_size - self.w1 = torch.empty((n, inter, h // 2), dtype=torch.int8, device=device) - self.w3 = torch.empty((n, inter, h // 2), dtype=torch.int8, device=device) - self.w2 = torch.empty((n, h, inter // 2), dtype=torch.int8, device=device) - self.w1_scale = torch.empty((n, inter, h // 32), dtype=torch.float8_e8m0fnu, device=device) - self.w3_scale = torch.empty((n, inter, h // 32), dtype=torch.float8_e8m0fnu, device=device) - self.w2_scale = torch.empty((n, h, inter // 32), dtype=torch.float8_e8m0fnu, device=device) - - def _copy_expert_weight(self, dst, weight, expert_idx, name, is_down=False): - if is_down: - start = self.tp_rank_ * self.split_inter_size // 2 - end = (self.tp_rank_ + 1) * self.split_inter_size // 2 - src = weight[:, start:end] - else: - start = self.tp_rank_ * self.split_inter_size - end = (self.tp_rank_ + 1) * self.split_inter_size - src = weight[start:end, :] - dst[expert_idx].copy_(src) - self.load_ok[name][expert_idx] = True - - def _copy_expert_scale(self, dst, scale, expert_idx, name, is_down=False): - if is_down: - start = self.tp_rank_ * self.split_inter_size // 32 - end = (self.tp_rank_ + 1) * self.split_inter_size // 32 - src = scale[:, start:end] - else: - start = self.tp_rank_ * self.split_inter_size - end = (self.tp_rank_ + 1) * self.split_inter_size - src = scale[start:end, :] - dst[expert_idx].copy_(src) - self.load_ok[name][expert_idx] = True - - def load_hf_weights(self, weights): - if self._marlin_checked: - return - has_weight = False - for expert_idx in self.local_expert_ids: - prefix = f"{self.weight_prefix}.{expert_idx}" - if ( - f"{prefix}.w1.weight" in weights - or f"{prefix}.w1.scale" in weights - or f"{prefix}.w2.weight" in weights - or f"{prefix}.w2.scale" in weights - or f"{prefix}.w3.weight" in weights - or f"{prefix}.w3.scale" in weights - ): - has_weight = True - break - if not has_weight: - return - - with self._load_lock: - 
if self._marlin_checked: - return - self._ensure_raw_fp4_weight() - for expert_idx in self.local_expert_ids: - prefix = f"{self.weight_prefix}.{expert_idx}" - w1 = f"{prefix}.w1.weight" - w1_scale = f"{prefix}.w1.scale" - w2 = f"{prefix}.w2.weight" - w2_scale = f"{prefix}.w2.scale" - w3 = f"{prefix}.w3.weight" - w3_scale = f"{prefix}.w3.scale" - if w1 in weights: - self._copy_expert_weight(self.w1, weights[w1], expert_idx, "w1") - if w1_scale in weights: - self._copy_expert_scale(self.w1_scale, weights[w1_scale], expert_idx, "w1_scale") - if w3 in weights: - self._copy_expert_weight(self.w3, weights[w3], expert_idx, "w3") - if w3_scale in weights: - self._copy_expert_scale(self.w3_scale, weights[w3_scale], expert_idx, "w3_scale") - if w2 in weights: - self._copy_expert_weight(self.w2, weights[w2], expert_idx, "w2", is_down=True) - if w2_scale in weights: - self._copy_expert_scale(self.w2_scale, weights[w2_scale], expert_idx, "w2_scale", is_down=True) - if self._raw_load_complete(): - self._try_init_marlin() - - def verify_load(self): - with self._load_lock: - ok = self._raw_load_complete() - if ok and not self._marlin_checked: - self._try_init_marlin() - return ok - - def _raw_load_complete(self): - return all(all(ok_list) for ok_list in self.load_ok.values()) - - def _try_init_marlin(self): - try: - from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_moe_mxfp4_layer_for_marlin, - ) - - class _MarlinLayer: - pass - - with self._marlin_pack_lock: - torch.cuda.set_device(self.device_id_) - device = torch.device("cuda", self.device_id_) - layer = _MarlinLayer() - layer.params_dtype = self.data_type_ - w13_cpu, w13_scale_cpu = self._build_w13_weight() - w13 = w13_cpu.to(device=device, non_blocking=True).contiguous() - w2 = self.w2.view(torch.uint8).to(device=device, non_blocking=True).contiguous() - w13_scale = w13_scale_cpu.to(device=device, non_blocking=True).contiguous() - w2_scale = self.w2_scale.to(device=device, 
non_blocking=True).contiguous() - ( - self.marlin_w13, - self.marlin_w2, - self.marlin_w13_scale, - self.marlin_w2_scale, - _, - _, - ) = prepare_moe_mxfp4_layer_for_marlin(layer, w13, w2, w13_scale, w2_scale, None, None) - del w13_cpu, w13_scale_cpu, w13, w2, w13_scale, w2_scale - self.moe_backend = "marlin" - self._marlin_checked = True - self._release_raw_fp4_weight() - torch.cuda.empty_cache() - logger.info( - "DeepSeek-V4 FP4 experts use vLLM Marlin backend, prefix=%s, rank=%s", - self.weight_prefix, - self.tp_rank_, - ) - except Exception as e: - self.moe_backend_error = repr(e) - raise RuntimeError( - "DeepSeek-V4 FP4 experts require vLLM Marlin backend, " - f"prefix={self.weight_prefix}, rank={self.tp_rank_}, error={self.moe_backend_error}" - ) from e - - def _build_w13_weight(self): - n = self.n_routed_experts - h = self.hidden_size - inter = self.split_inter_size - w13 = torch.empty((n, 2 * inter, h // 2), dtype=torch.uint8, device=self.w1.device) - w13[:, :inter, :].copy_(self.w1.view(torch.uint8)) - w13[:, inter:, :].copy_(self.w3.view(torch.uint8)) - w13_scale = torch.empty((n, 2 * inter, h // 32), dtype=self.w1_scale.dtype, device=self.w1_scale.device) - w13_scale[:, :inter, :].copy_(self.w1_scale) - w13_scale[:, inter:, :].copy_(self.w3_scale) - return w13.contiguous(), w13_scale.contiguous() - - def _release_raw_fp4_weight(self): - for name in ("w1", "w1_scale", "w2", "w2_scale", "w3", "w3_scale"): - if hasattr(self, name): - delattr(self, name) - - class DeepseekV4TransformerLayerWeight(TransformerLayerWeight): """Per-layer weights for DeepSeek-V4-Flash. - The checkpoint stores most linears in FP8 (e4m3 + block-128 ue8m0 scale) and the routed - experts in FP4 (int8-packed e2m1 + group-32 ue8m0 scale). Hopper does not use the SM100 - MegaMoE path here, so routed experts are kept in packed FP4 and temporarily de-quantized only - for selected experts in the correctness-first torch MoE path. 
+ DS4 does not share DS2/DS3.2's ``model.layers.*.self_attn/mlp`` layout. Its attention is + HC + CSA, and routed experts are checkpointed as MXFP4. """ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): @@ -213,7 +25,6 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): def _parse_config(self): cfg = self.network_config_ - self.fp8_quant = QUANTMETHODS.get("deepgemm-fp8w8a8-b128") self.hidden = cfg["hidden_size"] self.n_heads = cfg["num_attention_heads"] self.head_dim = cfg["head_dim"] @@ -238,13 +49,10 @@ def _parse_config(self): assert self.index_n_heads % self.tp_world_size_ == 0 self.prefix = f"layers.{self.layer_num_}" - def _init_weight_names(self): - return - def _init_weight(self): - self._init_attn() + self._init_qkvo() if self.has_compressor: - self._init_compressor(f"{self.prefix}.attn.compressor", self.head_dim, self.compress_ratio) + self._init_compressor() if self.has_indexer: self._init_indexer() self._init_moe() @@ -252,7 +60,7 @@ def _init_weight(self): self._init_hyper_connection() # ------------------------------------------------------------------ attention - def _init_attn(self): + def _init_qkvo(self): p = f"{self.prefix}.attn" # q low-rank (a replicated, b column-parallel over heads), kv single head (replicated) self.wq_a_ = ROWMMWeight( @@ -260,7 +68,7 @@ def _init_attn(self): out_dims=[self.q_lora_rank], weight_names=f"{p}.wq_a.weight", data_type=self.data_type_, - quant_method=self.fp8_quant, + quant_method=self.get_quant_method("wq_a"), tp_rank=0, tp_world_size=1, ) @@ -269,14 +77,14 @@ def _init_attn(self): out_dims=[self.n_heads * self.head_dim], weight_names=f"{p}.wq_b.weight", data_type=self.data_type_, - quant_method=self.fp8_quant, + quant_method=self.get_quant_method("wq_b"), ) self.wkv_ = ROWMMWeight( in_dim=self.hidden, out_dims=[self.head_dim], weight_names=f"{p}.wkv.weight", data_type=self.data_type_, - quant_method=self.fp8_quant, + quant_method=self.get_quant_method("wkv"), 
tp_rank=0, tp_world_size=1, ) @@ -301,11 +109,15 @@ def _init_attn(self): out_dims=[self.hidden], weight_names=f"{p}.wo_b.weight", data_type=self.data_type_, - quant_method=self.fp8_quant, + quant_method=self.get_quant_method("wo_b"), ) # ------------------------------------------------------------------ compressor / indexer - def _init_compressor(self, prefix, head_dim, ratio): + def _init_compressor(self): + prefix = f"{self.prefix}.attn.compressor" + head_dim = self.head_dim + ratio = self.compress_ratio + coff = 2 if ratio == 4 else 1 # wkv/wgate are bf16 (no scale) and replicated (single KV head). self.compressor_wkv_ = ROWMMWeight( @@ -341,7 +153,7 @@ def _init_indexer(self): out_dims=[self.index_n_heads * self.index_head_dim], weight_names=f"{p}.wq_b.weight", data_type=self.data_type_, - quant_method=self.fp8_quant, + quant_method=self.get_quant_method("idx_wq_b"), ) self.idx_weights_proj_ = ROWMMWeight( in_dim=self.hidden, @@ -406,28 +218,35 @@ def _init_moe(self): out_dims=[self.moe_inter], weight_names=f"{sp}.w1.weight", data_type=self.data_type_, - quant_method=self.fp8_quant, + quant_method=self.get_quant_method("shared_gate"), ) self.shared_up_ = ROWMMWeight( in_dim=self.hidden, out_dims=[self.moe_inter], weight_names=f"{sp}.w3.weight", data_type=self.data_type_, - quant_method=self.fp8_quant, + quant_method=self.get_quant_method("shared_up"), ) self.shared_down_ = COLMMWeight( in_dim=self.moe_inter, out_dims=[self.hidden], weight_names=f"{sp}.w2.weight", data_type=self.data_type_, - quant_method=self.fp8_quant, + quant_method=self.get_quant_method("shared_down"), ) - self.experts_ = DeepseekV4FP4ExpertsWeight( + self.experts_ = FusedMoeWeight( + gate_proj_name="w1", + down_proj_name="w2", + up_proj_name="w3", + e_score_correction_bias_name="", weight_prefix=f"{p}.experts", n_routed_experts=self.n_routed_experts, hidden_size=self.hidden, moe_intermediate_size=self.moe_inter, data_type=self.data_type_, + 
quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, ) def _init_norm(self): @@ -439,62 +258,68 @@ def _init_norm(self): ) def _init_hyper_connection(self): - for which in ["attn", "ffn"]: - setattr( - self, - f"hc_{which}_fn_", - ParameterWeight( - weight_name=f"{self.prefix}.hc_{which}_fn", - data_type=torch.float32, - weight_shape=(self.mix_hc, self.hc_mult * self.hidden), - ), - ) - setattr( - self, - f"hc_{which}_base_", - ParameterWeight( - weight_name=f"{self.prefix}.hc_{which}_base", data_type=torch.float32, weight_shape=(self.mix_hc,) - ), - ) - setattr( - self, - f"hc_{which}_scale_", - ParameterWeight( - weight_name=f"{self.prefix}.hc_{which}_scale", data_type=torch.float32, weight_shape=(3,) - ), - ) + p = self.prefix + self.hc_attn_fn_ = ParameterWeight( + weight_name=f"{p}.hc_attn_fn", + data_type=torch.float32, + weight_shape=(self.mix_hc, self.hc_mult * self.hidden), + ) + self.hc_attn_base_ = ParameterWeight( + weight_name=f"{p}.hc_attn_base", data_type=torch.float32, weight_shape=(self.mix_hc,) + ) + self.hc_attn_scale_ = ParameterWeight( + weight_name=f"{p}.hc_attn_scale", data_type=torch.float32, weight_shape=(3,) + ) + self.hc_ffn_fn_ = ParameterWeight( + weight_name=f"{p}.hc_ffn_fn", + data_type=torch.float32, + weight_shape=(self.mix_hc, self.hc_mult * self.hidden), + ) + self.hc_ffn_base_ = ParameterWeight( + weight_name=f"{p}.hc_ffn_base", data_type=torch.float32, weight_shape=(self.mix_hc,) + ) + self.hc_ffn_scale_ = ParameterWeight( + weight_name=f"{p}.hc_ffn_scale", data_type=torch.float32, weight_shape=(3,) + ) # ------------------------------------------------------------------ loading def load_hf_weights(self, weights): self._dequant_in_place(weights) return super().load_hf_weights(weights) - def _direct_fp8_weight_names(self): - names = set() - for attr_name in dir(self): - attr = getattr(self, attr_name, None) - quant_method = getattr(attr, 
"quant_method", None) - if getattr(quant_method, "method_name", None) == "deepgemm-fp8w8a8-b128": - names.update(getattr(attr, "weight_names", [])) - return names + def _fp8_scale_renames(self): + """Map weight name -> the scale name its quant method loads (e.g. `weight_scale_inv` + for DeepGEMM). Read from each MM weight's own `weight_scale_names`, so the rename + target always matches what that weight will look up; no-quant weights have None + entries and are skipped.""" + renames = {} + for attr in self.__dict__.values(): + weight_names = getattr(attr, "weight_names", ()) + scale_names = getattr(attr, "weight_scale_names", ()) + for weight_name, scale_name in zip(weight_names, scale_names): + if scale_name is not None: + renames[weight_name] = scale_name + return renames def _dequant_in_place(self, weights): p = self.prefix + "." - direct_fp8_names = self._direct_fp8_weight_names() - # Convert every (weight, scale) pair belonging to this layer. Existing FP8 matmul - # weights stay quantized; bmm-only weights are expanded; routed FP4 experts stay packed. - for k in [k for k in list(weights.keys()) if k.startswith(p) and k.endswith(".weight")]: - scale_k = k[: -len(".weight")] + ".scale" - if scale_k not in weights: - continue - w, s = weights[k], weights[scale_k] - if w.dtype == torch.int8: # FP4 routed experts stay packed for DeepseekV4FP4ExpertsWeight. + scale_renames = self._fp8_scale_renames() + # Convert every `.scale` belonging to this layer. Weights are loaded incrementally + # per safetensors shard, so the paired weight may live in another shard: + # - routed FP4 experts keep `.scale` as-is (matches marlin-mxfp4w4a16-b32's suffix); + # - FP8 matmul scales only need renaming for DeepGEMM, no weight required; + # - FP8 pairs on no-quant paths (wo_a's ROWBMMWeight) are expanded to bf16, + # the only case that truly requires weight and scale in the same shard. 
+ for scale_k in [k for k in list(weights.keys()) if k.startswith(p) and k.endswith(".scale")]: + if scale_k.startswith(f"{p}ffn.experts."): continue - elif k in direct_fp8_names: # FP8 e4m3, block-128 scale, run by DeepGEMM directly - weights[k.replace("weight", "weight_scale_inv")] = s.to(torch.float32) + k = scale_k[: -len(".scale")] + ".weight" + target = scale_renames.get(k) + if target is not None: # FP8 e4m3, block-128 scale, run by DeepGEMM directly + weights[target] = weights[scale_k].to(torch.float32) del weights[scale_k] - else: # FP8 e4m3 for no-quant paths such as ROWBMMWeight - weights[k] = dequant_fp8_block_to_bf16(w, s).to(self.data_type_) + else: + weights[k] = dequant_fp8_block_to_bf16(weights[k], weights[scale_k]).to(self.data_type_) del weights[scale_k] # grouped-O: reshape [groups*o_lora, in] -> [groups, in, o_lora] for the batched matmul woa = f"{self.prefix}.attn.wo_a.weight" diff --git a/lightllm/models/deepseek_v4/mem_manager.py b/lightllm/models/deepseek_v4/mem_manager.py deleted file mode 100644 index 288d433380..0000000000 --- a/lightllm/models/deepseek_v4/mem_manager.py +++ /dev/null @@ -1,12 +0,0 @@ -from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager - - -class DeepseekV4MemoryManager(Deepseek2MemoryManager): - """Stores the per-token MLA KV (head_num=1, head_dim=512), reusing the deepseek2 layout/operator. - - The prefill path computes attention in-layer from the request's hidden states, so it does not read - this buffer. The decode/incremental path (M6) will add the sliding-window ring + compressed-KV + - per-request compressor-state buffers here. 
- """ - - pass diff --git a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py index 687d5f46f0..1a88e08977 100644 --- a/lightllm/models/deepseek_v4/model.py +++ b/lightllm/models/deepseek_v4/model.py @@ -7,11 +7,6 @@ from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.req_manager import DeepseekV4ReqManager from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager -from lightllm.common.basemodel.attention.base_att import ( - BaseAttBackend, - BasePrefillAttState, - BaseDecodeAttState, -) from lightllm.models.deepseek_v4.layer_weights.pre_and_post_layer_weight import ( DeepseekV4PreAndPostLayerWeight, ) @@ -27,12 +22,13 @@ from lightllm.models.deepseek_v4.layer_infer.transformer_layer_infer import ( DeepseekV4TransformerLayerInfer, ) +from lightllm.common.basemodel.attention.create_utils import nsa_data_type_to_backend from lightllm.models.deepseek_v4.infer_struct import DeepseekV4InferStateInfo from lightllm.models.llama.yarn_rotary_utils import ( find_correction_range, linear_ramp_mask, ) -from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num, get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -40,36 +36,6 @@ DSV4_DECODE_CUDAGRAPH_MAX_LEN = 8192 -class DeepseekV4DirectSparseAttBackend(BaseAttBackend): - """Lifecycle placeholder for V4 direct attention. - - V4 attention is currently driven inside the layer, not by the generic - `infer_state.prefill_att_state.prefill_att()` / `decode_att()` backend selector. 
- """ - - def create_att_prefill_state(self, infer_state: DeepseekV4InferStateInfo): - return DeepseekV4DirectSparsePrefillAttState(backend=self, infer_state=infer_state) - - def create_att_decode_state(self, infer_state: DeepseekV4InferStateInfo): - return DeepseekV4DirectSparseDecodeAttState(backend=self, infer_state=infer_state) - - -class DeepseekV4DirectSparsePrefillAttState(BasePrefillAttState): - def init_state(self): - return - - def prefill_att(self, *args, **kwargs): - raise RuntimeError("DeepSeek-V4 attention is executed directly in layer_infer.") - - -class DeepseekV4DirectSparseDecodeAttState(BaseDecodeAttState): - def init_state(self): - return - - def decode_att(self, *args, **kwargs): - raise RuntimeError("DeepSeek-V4 attention is executed directly in layer_infer.") - - @ModelRegistry("deepseek_v4") class DeepseekV4TpPartModel(LlamaTpPartModel): req_manager: DeepseekV4ReqManager @@ -107,6 +73,7 @@ def _init_req_manager(self): compress_rates=self._dsv4_compress_rates, head_dim=self.config["head_dim"], indexer_head_dim=self.config["index_head_dim"], + sliding_window=self.config["sliding_window"], ) return @@ -117,6 +84,10 @@ def _get_compress_rates(self, layer_num): def _init_mem_manager(self): layer_num = self.config["n_layer"] + get_added_mtp_kv_layer_num() compress_rates = getattr(self, "_dsv4_compress_rates", self._get_compress_rates(layer_num)) + sliding_window = int(self.config["sliding_window"]) + # 活跃窗口之外的 swa 余量: 在途 prefill chunk 的瞬时占用(出窗槽位到下一次 prep 才回收) + # + radix cache 持有的窗口尾部(每条缓存序列约一个 window)。 + swa_extra_token_num = int(self.batch_max_tokens or 0) + self.max_req_num * sliding_window self.mem_manager = DeepseekV4MemoryManager( self.max_total_token_num, dtype=self.data_type, @@ -126,7 +97,8 @@ def _init_mem_manager(self): compress_rates=compress_rates, indexer_head_dim=self.config["index_head_dim"], max_request_num=self.max_req_num, - sliding_window=self.config["sliding_window"], + sliding_window=sliding_window, + 
swa_extra_token_num=swa_extra_token_num, mem_fraction=self.mem_fraction, ) assert isinstance(self.req_manager, DeepseekV4ReqManager) @@ -137,7 +109,7 @@ def _init_cudagraph(self): if not self.disable_cudagraph and self.graph_max_len_in_batch > DSV4_DECODE_CUDAGRAPH_MAX_LEN: logger.info( "DeepSeek-V4 caps decode cudagraph max_len_in_batch from %s to %s for the current " - "graph-safe sparse-attention fallback; longer decode batches run eager.", + "graph-safe sparse-attention path; longer decode batches run eager.", self.graph_max_len_in_batch, DSV4_DECODE_CUDAGRAPH_MAX_LEN, ) @@ -150,8 +122,14 @@ def _can_run_prefill_cudagraph(self, infer_state: DeepseekV4InferStateInfo, hand return False def _init_att_backend(self): - self.prefill_att_backend = DeepseekV4DirectSparseAttBackend(model=self) - self.decode_att_backend = DeepseekV4DirectSparseAttBackend(model=self) + args = get_env_start_args() + if args.llm_kv_type == "None": + args.llm_kv_type = "fp8kv_dsa" + if args.llm_kv_type != "fp8kv_dsa": + raise RuntimeError("DeepSeek-V4 requires llm_kv_type=fp8kv_dsa for packed FlashMLA sparse attention") + backend_cls = nsa_data_type_to_backend["fp8kv_dsa"]["flashmla_sparse"] + self.prefill_att_backend = backend_cls(model=self) + self.decode_att_backend = backend_cls(model=self) return def _init_custom(self): @@ -165,11 +143,12 @@ def _init_custom(self): return def _init_to_get_rotary(self): - # Interleaved (GPT-J) rope. Build real cos/sin tables (_cos_cached_*/_sin_cached_*) following the - # gemma4 two-variant convention; the infer-struct slices them into position_cos_*/position_sin_* - # and apply_rotary_emb (interleaved, NOT the NeoX rotary_emb_fwd) applies them. Sliding-window - # layers use base rope_theta (no YaRN); compressed (CSA/HCA) layers use compress_rope_theta with - # YaRN. Tables kept fp32 for accuracy (the apply upcasts anyway). + # Interleaved (GPT-J) rope. 
Build complex64 freqs_cis tables (_freqs_cis_*) following the + # gemma4 two-variant convention; the fused sglang q kernel consumes them directly, while + # _cos_cached_*/_sin_cached_* are .real/.imag views of the same storage for the kv rope, + # inverse rope and compressor paths (apply_rotary_emb: interleaved, NOT the NeoX + # rotary_emb_fwd). Sliding-window layers use base rope_theta (no YaRN); compressed (CSA/HCA) + # layers use compress_rope_theta with YaRN. Kept fp32 for accuracy (the apply upcasts anyway). cfg = self.config rs = cfg.get("rope_scaling", {}) or {} dim = cfg["qk_rope_head_dim"] @@ -185,18 +164,29 @@ def build(base, factor, orig_max): smooth = 1 - linear_ramp_mask(low, high, dim // 2).cuda() freqs = freqs / factor * (1 - smooth) + freqs * smooth f = torch.outer(torch.arange(max_seq, dtype=torch.float32, device="cuda"), freqs) # [max_seq, dim//2] - return f.cos(), f.sin() + return torch.complex(f.cos(), f.sin()) - self._cos_cached_sliding, self._sin_cached_sliding = build( + self._freqs_cis_sliding = build( cfg["rope_theta"], rs.get("factor", 16), rs.get("original_max_position_embeddings", 65536), ) - self._cos_cached_compress, self._sin_cached_compress = build( + self._freqs_cis_compress = build( cfg["compress_rope_theta"], rs.get("factor", 16), rs.get("original_max_position_embeddings", 65536), ) + self._cos_cached_sliding = self._freqs_cis_sliding.real + self._sin_cached_sliding = self._freqs_cis_sliding.imag + self._cos_cached_compress = self._freqs_cis_compress.real + self._sin_cached_compress = self._freqs_cis_compress.imag + # Each layer uses exactly one rope variant; wire its table once here (layers are already + # built: _init_infer_layer runs before _init_custom) instead of relaying via infer_state. + # The compressor needs the full compress tables (entry rope positions != token positions). 
+ for layer in self.layers_infer: + layer.freqs_cis = self._freqs_cis_compress if layer.compress_ratio else self._freqs_cis_sliding + layer.cos_compress_table = self._cos_cached_compress + layer.sin_compress_table = self._sin_cached_compress return diff --git a/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py new file mode 100644 index 0000000000..3510b92c30 --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py @@ -0,0 +1,92 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_indexer_k_dsv4( + K, + Dest_loc, + O_fp8, + O_f32, + stride_k_bs, + stride_k_d, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, + SCALE_MIN: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BYTES_PER_PAGE: tl.constexpr, +): + cur_index = tl.program_id(0) + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + # negative dest (unmapped slot) is a no-op, not an OOB write into a neighboring page. 
+ if dest_index < 0: + return + + page = dest_index // PAGE_SIZE + token_in_page = dest_index % PAGE_SIZE + + offs_d = tl.arange(0, HEAD_DIM) + vals = tl.load(K + cur_index * stride_k_bs + offs_d * stride_k_d).to(tl.float32) + amax = tl.max(tl.abs(vals), axis=0) + # per-token plain fp32 scale (not ue8m0), matching DeepseekV4MemoryManager._pack_indexer_k + scale = tl.maximum(amax / FP8_MAX, SCALE_MIN) + k_fp8 = tl.clamp(vals / scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv) + + data_base = page * BYTES_PER_PAGE + token_in_page * HEAD_DIM + tl.store(O_fp8 + data_base + offs_d, k_fp8) + scale_idx = (page * BYTES_PER_PAGE + PAGE_SIZE * HEAD_DIM) // 4 + token_in_page + tl.store(O_f32 + scale_idx, scale) + return + + +@torch.no_grad() +def destindex_copy_indexer_k_dsv4( + K: torch.Tensor, + DestLoc: torch.Tensor, + O_buffer: torch.Tensor, + page_size: int, +): + """Packed indexer-K page-slab writer (DeepSeek-V4 c4/CSA layers). + + K: [T, 128] bf16 unquantized indexer keys. + DestLoc: [T] int — c4-pool-local token slots; must already be allocated by the caller. + Negative slots (unmapped) are skipped. + O_buffer: [num_pages, bytes_per_page] uint8 — one layer's slab from the c4 indexer + PackedPagePool (128B fp8 data region + 4B fp32 scale tail per token). + + Bit-compatible with DeepseekV4MemoryManager._pack_indexer_k + PackedPagePool.write. 
+ """ + seq_len = DestLoc.shape[0] + if seq_len == 0: + return + head_dim, scale_bytes = 128, 4 + + K = K.reshape(-1, head_dim) + assert K.shape[0] == seq_len, f"Expected K shape[0]={seq_len}, got {K.shape[0]}" + assert K.dtype == torch.bfloat16, f"Expected bf16 indexer K, got {K.dtype}" + bytes_per_page = O_buffer.shape[-1] + assert O_buffer.dtype == torch.uint8 and O_buffer.is_contiguous() + assert bytes_per_page % 4 == 0 + assert bytes_per_page >= page_size * (head_dim + scale_bytes) + + flat = O_buffer.view(-1) + _fwd_kernel_destindex_copy_indexer_k_dsv4[(seq_len,)]( + K, + DestLoc, + flat.view(torch.float8_e4m3fn), + flat.view(torch.float32), + K.stride(0), + K.stride(1), + FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + SCALE_MIN=1e-4, + HEAD_DIM=head_dim, + PAGE_SIZE=page_size, + BYTES_PER_PAGE=bytes_per_page, + num_warps=1, + num_stages=1, + ) + return diff --git a/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py new file mode 100644 index 0000000000..a3ec6ed8cf --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py @@ -0,0 +1,121 @@ +import torch + +import triton +import triton.language as tl +from triton.language.extra import libdevice + + +@triton.jit +def _fwd_kernel_destindex_copy_kv_flashmla_dsv4( + KV, + Dest_loc, + O_fp8, + O_bf16, + O_u8, + stride_kv_bs, + stride_kv_d, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, + SCALE_MIN: tl.constexpr, + NOPE_DIM: tl.constexpr, + ROPE_DIM: tl.constexpr, + GROUP_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, + SCALE_BYTES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BYTES_PER_PAGE: tl.constexpr, +): + cur_index = tl.program_id(0) + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + # negative dest (unmapped slot, e.g. 
full_to_c* rows that never closed a group) is a no-op, + # not an OOB write into a neighboring page. + if dest_index < 0: + return + + page = dest_index // PAGE_SIZE + token_in_page = dest_index % PAGE_SIZE + data_base = page * BYTES_PER_PAGE + token_in_page * (NOPE_DIM + ROPE_DIM * 2) + scale_base = page * BYTES_PER_PAGE + PAGE_SIZE * (NOPE_DIM + ROPE_DIM * 2) + token_in_page * SCALE_BYTES + + # nope: per-group ue8m0 quant. SCALE_BYTES(=NUM_GROUPS+1) lanes cover the exponent bytes + # plus the trailing zero pad byte in one store. libdevice.log2 (not tl.log2, which is the + # approx instruction) and the bit-packed 2**e keep this bit-exact with the torch oracle + # DeepseekV4MemoryManager._pack_mla_kv. + offs_g = tl.arange(0, SCALE_BYTES) + offs_e = tl.arange(0, GROUP_SIZE) + group_mask = offs_g < NUM_GROUPS + kv_ptrs = KV + cur_index * stride_kv_bs + (offs_g[:, None] * GROUP_SIZE + offs_e[None, :]) * stride_kv_d + vals = tl.load(kv_ptrs, mask=group_mask[:, None], other=0.0).to(tl.float32) + amax = tl.max(tl.abs(vals), axis=1) + scale_exp = tl.ceil(libdevice.log2(tl.maximum(amax / FP8_MAX, SCALE_MIN))).to(tl.int32) + scale = ((scale_exp + 127) << 23).to(tl.float32, bitcast=True) + kv_fp8 = tl.clamp(vals / scale[:, None], min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv) + tl.store(O_fp8 + data_base + offs_g[:, None] * GROUP_SIZE + offs_e[None, :], kv_fp8, mask=group_mask[:, None]) + scale_bytes = tl.where(group_mask, scale_exp + 127, 0).to(tl.uint8) + tl.store(O_u8 + scale_base + offs_g, scale_bytes) + + # rope: bf16 passthrough into the data region right after the nope bytes + offs_r = tl.arange(0, ROPE_DIM) + rope = tl.load(KV + cur_index * stride_kv_bs + (NOPE_DIM + offs_r) * stride_kv_d) + tl.store(O_bf16 + (data_base + NOPE_DIM) // 2 + offs_r, rope) + return + + +@torch.no_grad() +def destindex_copy_kv_flashmla_dsv4( + KV: torch.Tensor, + DestLoc: torch.Tensor, + O_buffer: torch.Tensor, + page_size: int, +): + """fp8_ds_mla packed page-slab writer (DeepSeek-V4 ABI, 
all latent pools). + + KV: [T, 512] bf16 — 448 normed-latent dims + 64 rope'd dims per token. + DestLoc: [T] int — pool-local token slots (page = slot // page_size); the pool HOLD slot is + a valid in-bounds row, negative slots (unmapped) are skipped. Slots must already be + resolved/allocated by the caller. + O_buffer: [num_pages, bytes_per_page] uint8 — one layer's slab from PackedPagePool + (swa page=128 / c4 page=64 / c128 page=2 all share this kernel). + + Per token: 448B fp8(e4m3) in 7x64 ue8m0 groups + 128B bf16 rope in the page data region; + 7 exponent bytes (e+127) + 1 zero pad at the page scale tail. Bit-compatible with + DeepseekV4MemoryManager._pack_mla_kv + PackedPagePool.write. + """ + seq_len = DestLoc.shape[0] + if seq_len == 0: + return + nope_dim, rope_dim, group_size = 448, 64, 64 + head_dim = nope_dim + rope_dim + scale_bytes = nope_dim // group_size + 1 + + KV = KV.reshape(-1, head_dim) + assert KV.shape[0] == seq_len, f"Expected KV shape[0]={seq_len}, got {KV.shape[0]}" + assert KV.dtype == torch.bfloat16, f"Expected bf16 KV (rope bytes are stored as-is), got {KV.dtype}" + bytes_per_page = O_buffer.shape[-1] + assert O_buffer.dtype == torch.uint8 and O_buffer.is_contiguous() + assert bytes_per_page % 2 == 0 + assert bytes_per_page >= page_size * (nope_dim + rope_dim * 2 + scale_bytes) + + flat = O_buffer.view(-1) + _fwd_kernel_destindex_copy_kv_flashmla_dsv4[(seq_len,)]( + KV, + DestLoc, + flat.view(torch.float8_e4m3fn), + flat.view(torch.bfloat16), + flat, + KV.stride(0), + KV.stride(1), + FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + SCALE_MIN=1e-4, + NOPE_DIM=nope_dim, + ROPE_DIM=rope_dim, + GROUP_SIZE=group_size, + NUM_GROUPS=nope_dim // group_size, + SCALE_BYTES=scale_bytes, + PAGE_SIZE=page_size, + BYTES_PER_PAGE=bytes_per_page, + num_warps=4, + num_stages=1, + ) + return diff --git a/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py 
b/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py index c7d2d59ec6..47d87d4932 100644 --- a/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py +++ b/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py @@ -1,15 +1,5 @@ import torch -# DeepSeek-V4-Flash ships weights in two quantized formats: -# * non-expert linears: FP8 e4m3 with block-[128,128] scales stored as float8_e8m0fnu (ue8m0) -# * routed experts: FP4 e2m1 packed 2-per-byte (stored as int8) with group-32 ue8m0 scales -# Hopper (H200) has no native SM100 MegaMoE path. Non-expert FP8 weights can run directly through -# DeepGEMM. Routed FP4 experts are converted blockwise to FP8, avoiding a full bf16 expansion. - -# OCP E2M1 magnitude table for the 3 low bits (sign = bit 3). torch.float4_e2m1fn_x2 packs two -# such codes per byte, low nibble = lower (even) logical index. -_E2M1_MAG = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] - def e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: """float8_e8m0fnu encodes 2**(byte-127); torch decodes it correctly on .to(float32).""" @@ -24,70 +14,3 @@ def dequant_fp8_block_to_bf16(weight_e4m3: torch.Tensor, scale_e8m0: torch.Tenso s = e8m0_to_fp32(scale_e8m0).cuda().contiguous() # weight_dequant runs with torch default dtype for the output; force bf16 result. return weight_dequant(w, s, block_size) - - -def cast_e2m1fn_to_e4m3fn(weight_int8: torch.Tensor, scale_e8m0: torch.Tensor): - """Cast packed FP4 e2m1 expert weights to FP8 e4m3 with block-128 fp32 scales. - - This follows the DeepSeek-V4 reference converter, but returns the scale in fp32 because - LightLLM's DeepGEMM FP8 weight pack stores block scales as fp32. 
- """ - assert weight_int8.dtype == torch.int8 - assert weight_int8.ndim == 2 - out_dim, packed_in = weight_int8.shape - in_dim = packed_in * 2 - fp8_block_size = 128 - fp4_block_size = 32 - assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0 - assert scale_e8m0.shape[0] == out_dim - assert scale_e8m0.shape[1] == in_dim // fp4_block_size - - table = torch.tensor( - [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], - dtype=torch.float32, - device=weight_int8.device, - ) - packed = weight_int8.view(torch.uint8) - low = packed & 0x0F - high = (packed >> 4) & 0x0F - vals = torch.stack([table[low.long()], table[high.long()]], dim=-1).reshape(out_dim, in_dim) - - # 6.0 * 2**6 fits in e4m3fn (384 < 448), while 6.0 * 2**7 would overflow. - max_offset_bits = 6 - block_out = out_dim // fp8_block_size - block_in = in_dim // fp8_block_size - - vals = vals.view(block_out, fp8_block_size, block_in, fp8_block_size).transpose(1, 2) - scale = scale_e8m0.float().view(block_out, fp8_block_size, block_in, -1).transpose(1, 2).flatten(2) - block_scale = scale.amax(dim=-1, keepdim=True) / (2**max_offset_bits) - offset = scale / block_scale - offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1) - vals = (vals * offset).transpose(1, 2).reshape(out_dim, in_dim) - block_scale = block_scale.squeeze(-1).to(torch.float8_e8m0fnu).to(torch.float32) - return vals.to(torch.float8_e4m3fn), block_scale - - -def dequant_fp4_group_to_bf16(weight_int8: torch.Tensor, scale_e8m0: torch.Tensor, group_size: int = 32): - """De-quantize an int8-packed FP4 e2m1 weight to bf16. - - weight_int8: [out, in // 2] int8 (two e2m1 codes per byte, low nibble = even index). - scale_e8m0: [out, in // group_size] ue8m0 (one scale per group_size logical elements along K). - returns: [out, in] bf16. 
- """ - w = weight_int8.cuda() - out, packed_in = w.shape - in_dim = packed_in * 2 - b = w.to(torch.int32).bitwise_and(0xFF) - lut = torch.tensor(_E2M1_MAG, dtype=torch.float32, device=w.device) - - def _decode(nib: torch.Tensor) -> torch.Tensor: - mag = lut[nib.bitwise_and(0x7)] - neg = nib.bitwise_and(0x8).bool() - return torch.where(neg, -mag, mag) - - lo = _decode(b.bitwise_and(0xF)) - hi = _decode(b.bitwise_right_shift(4).bitwise_and(0xF)) - vals = torch.stack([lo, hi], dim=-1).reshape(out, in_dim) # [out, in] - s = e8m0_to_fp32(scale_e8m0).cuda() # [out, in//group_size] - s = s.repeat_interleave(group_size, dim=1)[:, :in_dim] - return (vals * s).to(torch.bfloat16) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 2db6c67e77..70d5c72ac3 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -625,8 +625,11 @@ def make_argument_parser() -> argparse.ArgumentParser: type=str, default=None, choices=["fp8", "fp4"], - help="""Expert quantization dtype for EP MoE. Supported values are - fp8 and fp4. Note that fp4 is only supported on SM100 GPUs.""", + help="""Requested dtype for MoE expert weights, fp8 or fp4. Resolves the fused_moe + quant method: fp8 -> deepgemm-fp8w8a8-b128; fp4 -> deepgemm-fp4fp8-b32 (online + quantization) on SM100 GPUs, or marlin-mxfp4w4a16-b32 (Marlin W4A16, TP only) on other GPUs. + Defaults to `expert_dtype` in config.json if present. 
Per-layer override: + --quant_cfg mix_bits with name `fused_moe`.""", ) parser.add_argument( "--vit_quant_type", diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 8be5198eb3..ff09c018be 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -316,6 +316,9 @@ def _insert_helper_no_recursion( def match_prefix(self, key, update_refs=False): key = key[: self._align_len(len(key))] + if len(key) == 0: + return None, 0, None + key = self._trim_key_by_extra_value_validity(key) if len(key) == 0: return None, 0, None ans_value_list = [] @@ -331,6 +334,30 @@ def match_prefix(self, key, update_refs=False): self.dec_node_ref_counter(self.root_node) return None, 0, None + def _trim_key_by_extra_value_validity(self, key: torch.Tensor) -> torch.Tensor: + """命中有效性裁剪(extra_value_ops 提供 valid_match_length 时启用,如 DeepSeek-V4 的 + swa 按页 bitmap): 先做一次只读探测遍历得到自然命中与沿路 extra_value,按其有效边界截短 + key,随后的正常遍历(加引用/分裂)只走截短后的前缀 —— 引用计数与最终返回值在同一次遍历 + 内保持一致,不存在事后裁剪导致的漏减/多减。 + + 探测遍历可能分裂部分命中的节点(与正常遍历同语义,树不变式不受影响)。裁剪只会缩短命中, + 没有任何失败路径。""" + if self.extra_value_ops is None: + return key + valid_match_length = getattr(self.extra_value_ops, "valid_match_length", None) + if valid_match_length is None: + return key + probe_values = [] + probe_node = self._match_prefix_helper(self.root_node, key, probe_values, update_refs=False) + if probe_node == self.root_node or len(probe_values) == 0: + return key + natural_len = sum(len(v) for v in probe_values) + extra_value = self.get_extra_value_by_node(probe_node) + valid_len = int(valid_match_length(extra_value, natural_len)) + if valid_len < natural_len: + return key[:valid_len] + return key + def _match_prefix_helper( self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False ) -> TreeNode: @@ -595,6 +622,36 @@ def _print_helper(self, node: TreeNode, indent): self._print_helper(child, 
indent=indent + 2) return + def reclaim_unreferenced_swa_pages(self, need_pages: int) -> None: + """DeepSeek-V4 swa 压力阀: 页 allocator 触底时,沿 LRU 序(evict_tree_set)只对 + ref_count==0 的节点链回收其 swa 页(full 槽与压缩条目保留——节点仍可服务更长前缀的 + 中段命中),并清载荷 bitmap 位使后续命中按缩短语义裁剪。所有权判定直接复用 radix + 引用计数: 节点被任何活跃请求借用即 ref>0,其页不可达。不够时由 allocator 的 assert + 兜底(最后防线)。""" + if self.mem_manager is None or self.extra_value_ops is None: + return + invalidate = getattr(self.extra_value_ops, "invalidate_swa_pages", None) + if invalidate is None: + return + allocator = self.mem_manager.swa_page_allocator + target = allocator.can_use_mem_size + int(need_pages) + for leaf in list(self.evict_tree_set): + if allocator.can_use_mem_size >= target: + break + node = leaf + # 叶子起步沿父链回收: 引用计数向上累加(add_node_ref_counter 走父链), + # 因此 ref==0 的祖先必无任何活跃借用方。重复访问无害(evict_swa/-1 跳过)。 + # 每回收一个节点就复查目标,避免多回收(无谓削减命中可用性)。 + while node is not None and node is not self.root_node and node.ref_counter == 0: + if len(node.token_mem_index_value) > 0: + self.mem_manager.evict_swa(node.token_mem_index_value) + if node.token_extra_value is not None: + invalidate(node.token_extra_value) + if allocator.can_use_mem_size >= target: + return + node = node.parent + return + def free_radix_cache_to_get_enough_token(self, need_token_num): assert self.mem_manager is not None if need_token_num > self.mem_manager.allocator.can_use_mem_size: diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 3fd6e0463a..32b7f6b3d5 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -124,40 +124,19 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs - def free_a_req_mem( - self, - free_token_index: List, - req: "InferReq", - free_c4_index: Optional[List] = None, - free_c128_index: Optional[List] = None, - ): + def free_a_req_mem(self, free_token_index: List, req: "InferReq"): 
is_dsv4_req_manager = hasattr(self.req_manager, "build_prompt_cache_payload") - if hasattr(self.req_manager, "pop_compress_indices_for_req") and not is_dsv4_req_manager: - c4, c128 = self.req_manager.pop_compress_indices_for_req(req.req_idx) - if c4 is not None and free_c4_index is not None: - free_c4_index.append(c4) - if c128 is not None and free_c128_index is not None: - free_c128_index.append(c128) - self.req_manager.clear_runtime_state(req.req_idx) - if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) if is_dsv4_req_manager: - c4, c128 = self.req_manager.pop_compress_indices_for_req(req.req_idx) - if c4 is not None and free_c4_index is not None: - free_c4_index.append(c4) - if c128 is not None and free_c128_index is not None: - free_c128_index.append(c128) - self.req_manager.clear_runtime_state(req.req_idx) + # 槽位随 full 槽经 mem_manager.free 级联回收。pause 路径不释放 req_idx, + # 必须在此复位出窗水位线 + 清 c128 在途状态(恢复命中走 extend,不会再有 + # restore/zero 时机;c4 状态随 swa 页生灭,无需处理)。 + self.req_manager.init_compress_state(req.req_idx) else: if not self.is_linear_att_mixed_model: if is_dsv4_req_manager: - self._dsv4_full_att_free_req( - free_token_index=free_token_index, - req=req, - free_c4_index=free_c4_index, - free_c128_index=free_c128_index, - ) + self._dsv4_full_att_free_req(free_token_index=free_token_index, req=req) else: self._full_att_free_req(free_token_index=free_token_index, req=req) else: @@ -187,79 +166,60 @@ def _full_att_free_req(self, free_token_index: List, req: "InferReq"): req.shared_kv_node = None return - def _dsv4_full_att_free_req( - self, - free_token_index: List, - req: "InferReq", - free_c4_index: Optional[List] = None, - free_c128_index: Optional[List] = None, - ): + def _dsv4_full_att_free_req(self, free_token_index: List, req: "InferReq"): if req.cur_kv_len == 0: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0:0]) return old_prefix_len = 0 if req.shared_kv_node 
is None else req.shared_kv_node.node_prefix_total_len - cache_len = self.radix_cache.align_len(req.cur_kv_len) inserted_len = old_prefix_len duplicate_prefix_len = old_prefix_len - inserted_payload = None - pending_payload = getattr(req, "prompt_cache_snapshot_payload", None) - pending_cache_len = getattr(req, "prompt_cache_snapshot_len", 0) - - # The current V4 runtime state is only guaranteed to describe the current - # sequence end. Cache aligned current ends; leave unaligned tails uncached. - if pending_payload is not None and pending_cache_len > old_prefix_len: - cache_len = pending_cache_len - input_token_ids = req.get_input_token_ids() - key = torch.tensor(input_token_ids[0:cache_len], dtype=torch.int64, device="cpu") - value = self.req_manager.req_to_token_indexs[req.req_idx][:cache_len].detach().cpu() - duplicate_prefix_len, cache_node = self.radix_cache.insert(key, value, extra_value=pending_payload) - inserted_len = 0 if cache_node is None else cache_node.node_prefix_total_len - if inserted_len == cache_len: - inserted_payload = pending_payload - else: - self.req_manager.release_prompt_cache_detached_swa(pending_payload) - pending_payload = None - inserted_len = old_prefix_len - duplicate_prefix_len = old_prefix_len - elif cache_len == req.cur_kv_len and cache_len > old_prefix_len: - input_token_ids = req.get_input_token_ids() - key = torch.tensor(input_token_ids[0:cache_len], dtype=torch.int64, device="cpu") - value = self.req_manager.req_to_token_indexs[req.req_idx][:cache_len].detach().cpu() + + # 载荷只剩按页 bitmap(compressor 状态随 swa 页生灭/边界自然归零,不进载荷), + # 任意 128 对齐前缀皆可插入——含生成段(floor(cur_kv_len) 边界,回收保留尾页保证其驻留)。 + cache_len = self.radix_cache.align_len(req.cur_kv_len) + if cache_len > old_prefix_len: payload = self.req_manager.build_prompt_cache_payload(req.req_idx, cache_len) - duplicate_prefix_len, cache_node = self.radix_cache.insert(key, value, extra_value=payload) - inserted_len = 0 if cache_node is None else cache_node.node_prefix_total_len - if 
inserted_len == cache_len: - inserted_payload = payload - self.req_manager.detach_prompt_cache_payload_from_req(req.req_idx, inserted_payload) - else: - inserted_len = old_prefix_len - duplicate_prefix_len = old_prefix_len + value = self.req_manager.req_to_token_indexs[req.req_idx][:cache_len].detach().cpu() + # 按页有效性 bitmap 用插入时刻的映射写定(此后只会被阀清 0,不会复活)。水位线 + # 纯 CPU 推导,避免 router 关键路径上的 GPU gather 同步(每插入一次要等全部在途 + # decode kernel)。插入门: 截掉结尾的 invalid 页 —— 它们生来不可命中,还会永久 + # 挡住后续更长前缀复用同一段 token(全量重插会因前缀已存在而保留旧 bitmap)。 + page_size = self.req_manager.get_prompt_cache_page_size() + bitmap = self.req_manager.swa_page_valid_from_watermark(req.req_idx, cache_len) + n_pages = int(bitmap.numel()) + while n_pages > 0 and not bool(bitmap[n_pages - 1]): + n_pages -= 1 + gated_len = n_pages * page_size + if gated_len < cache_len: + logger.info( + f"DeepSeek-V4 prompt cache insert gate: trailing swa pages already evicted, " + f"shrink insert {cache_len} -> {gated_len}" + ) + cache_len = gated_len + payload.cache_len = cache_len + payload.swa_page_valid = bitmap[:n_pages].clone() + + if cache_len > old_prefix_len: + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0:cache_len], dtype=torch.int64, device="cpu") + duplicate_prefix_len, cache_node = self.radix_cache.insert(key, value[:cache_len], extra_value=payload) + inserted_len = 0 if cache_node is None else cache_node.node_prefix_total_len + if inserted_len != cache_len: + inserted_len = old_prefix_len + duplicate_prefix_len = old_prefix_len - if ( - pending_payload is not None - and inserted_payload is not pending_payload - and pending_cache_len <= old_prefix_len - ): - self.req_manager.release_prompt_cache_detached_swa(pending_payload) - req.prompt_cache_snapshot_payload = None - req.prompt_cache_snapshot_len = 0 dense_row = self.req_manager.req_to_token_indexs[req.req_idx] self._append_free_token_index(free_token_index, dense_row[old_prefix_len:duplicate_prefix_len]) 
self._append_free_token_index(free_token_index, dense_row[inserted_len : req.cur_kv_len]) if len(free_token_index) == 0: free_token_index.append(dense_row[0:0]) + # 释放的 full 槽经 mem_manager.free 级联回收 swa/c4/c128(映射键控,无需收集槽位)。 - c4, c128 = self.req_manager.pop_prompt_cache_free_compress_indices( - req.req_idx, - keep_len=inserted_len, - duplicate_start_len=old_prefix_len, - duplicate_end_len=duplicate_prefix_len, - ) - if c4 is not None and free_c4_index is not None: - free_c4_index.append(c4) - if c128 is not None and free_c128_index is not None: - free_c128_index.append(c128) + # pause 路径不会走 req_manager.free/init: 复位出窗水位线(残留水位线会破坏下一次 + # prefill 的共享前缀保护)并清 c128 在途状态(恢复命中走 extend 续算,若残留暂停前的 + # 半窗聚合会算错;c128 状态在 128 对齐命中边界本应为零)。 + self.req_manager.init_compress_state(req.req_idx) if req.shared_kv_node is not None: assert req.shared_kv_node.node_prefix_total_len <= max(inserted_len, old_prefix_len) @@ -375,13 +335,11 @@ def _filter(self, finished_request_ids: List[int]): free_req_index = [] free_token_index = [] - free_c4_index = [] - free_c128_index = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() - self.free_a_req_mem(free_token_index, req, free_c4_index, free_c128_index) + self.free_a_req_mem(free_token_index, req) free_req_index.append(req.req_idx) # logger.info(f"infer release req id {req.shm_req.request_id}") @@ -389,17 +347,7 @@ def _filter(self, finished_request_ids: List[int]): self.shm_req_manager.put_back_req_obj(req.shm_req) free_token_index = custom_cat(free_token_index) - if hasattr(self.req_manager, "free_compress_indices"): - free_c4_index = custom_cat(free_c4_index) if free_c4_index else None - free_c128_index = custom_cat(free_c128_index) if free_c128_index else None - self.req_manager.free( - free_req_index, - free_token_index, - free_c4_index=free_c4_index, - free_c128_index=free_c128_index, - ) - else: - self.req_manager.free(free_req_index, 
free_token_index) + self.req_manager.free(free_req_index, free_token_index) finished_req_ids_set = set(finished_request_ids) self.infer_req_ids = [_id for _id in self.infer_req_ids if _id not in finished_req_ids_set] @@ -428,13 +376,11 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): g_infer_state_lock.acquire() free_token_index = [] - free_c4_index = [] - free_c128_index = [] for req in pause_reqs: if self.args.diverse_mode: # 发生暂停的时候,需要清除 diverse 模式下的主从关系 req.clear_master_slave_state() - self.free_a_req_mem(free_token_index, req, free_c4_index, free_c128_index) + self.free_a_req_mem(free_token_index, req) assert req.wait_pause is True req.wait_pause = False req.paused = True @@ -445,13 +391,6 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): if len(free_token_index) != 0: free_token_index = custom_cat(free_token_index) self.req_manager.free_token(free_token_index) - if hasattr(self.req_manager, "free_compress_indices"): - free_c4_index = custom_cat(free_c4_index) if free_c4_index else None - free_c128_index = custom_cat(free_c128_index) if free_c128_index else None - self.req_manager.free_compress_indices( - free_c4_index=free_c4_index, - free_c128_index=free_c128_index, - ) g_infer_state_lock.release() return self @@ -738,8 +677,6 @@ def _init_all_state(self): g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self) if hasattr(g_infer_context.req_manager, "init_compress_state"): g_infer_context.req_manager.init_compress_state(req_idx=self.req_idx) - self.prompt_cache_snapshot_len = 0 - self.prompt_cache_snapshot_payload = None self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list() # token healing mode 才被使用的管理对象 @@ -775,11 +712,9 @@ def _match_radix_cache(self): ready_cache_len = share_node.node_prefix_total_len # 从 cpu 到 gpu 是流内阻塞操作 g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor - if 
hasattr(g_infer_context.req_manager, "restore_prompt_cache_payload"): - payload = g_infer_context.radix_cache.get_extra_value_by_node(share_node) - if payload is None: - raise RuntimeError("DeepSeek-V4 radix cache hit is missing prompt-cache payload") - g_infer_context.req_manager.restore_prompt_cache_payload(self.req_idx, payload) + # DeepSeek-V4 命中无需任何恢复: 槽位由 full_to_* 映射键控(radix 持有 full 槽即有效, + # 命中长度已在 match_prefix 内按 bitmap 裁剪),c4 compressor 状态随 swa 页常驻 + # (零拷贝续算),c128 状态在 128 对齐边界自然归零(init_compress_state 已清)。 self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 74fdb1e87b..4c5e1af222 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -212,6 +212,9 @@ def init_model(self, kvargs): page_size=radix_page_size, extra_value_ops=radix_extra_value_ops, ) + if radix_extra_value_ops is not None and hasattr(self.model.mem_manager, "set_swa_pressure_valve"): + # swa 页 allocator 触底时让 radix 对 ref==0 节点回收 swa 页(DeepSeek-V4)。 + self.model.mem_manager.set_swa_pressure_valve(self.radix_cache.reclaim_unreferenced_swa_pages) if "prompt_cache_kv_buffer" in model_cfg: assert self.use_dynamic_prompt_cache @@ -713,31 +716,6 @@ def _pre_handle_finished_reqs(self, finished_reqs: List[InferReq]): """ pass - def _maybe_capture_prompt_cache_payload(self, req_obj: InferReq): - if self.radix_cache is None: - return - req_manager = g_infer_context.req_manager - if not hasattr(req_manager, "build_prompt_cache_payload"): - return - if req_obj.sampling_param.disable_prompt_cache: - return - page_size = getattr(self.args, "dynamic_prompt_cache_page_size", 1) - cache_len = int(req_obj.cur_kv_len) - if page_size <= 1 or cache_len <= 0 or cache_len % 
page_size != 0: - return - if cache_len > req_obj.shm_req.input_len: - return - if getattr(req_obj, "prompt_cache_snapshot_len", 0) >= cache_len: - return - - payload = req_manager.build_prompt_cache_payload(req_obj.req_idx, cache_len, clone_swa=True) - old_payload = getattr(req_obj, "prompt_cache_snapshot_payload", None) - if old_payload is not None: - req_manager.release_prompt_cache_detached_swa(old_payload, keep_payload=payload) - req_obj.prompt_cache_snapshot_len = cache_len - req_obj.prompt_cache_snapshot_payload = payload - return - # 一些可以复用的通用功能函数 def _pre_post_handle(self, run_reqs: List[InferReq], is_chuncked_mode: bool) -> List[InferReqUpdatePack]: update_func_objs: List[InferReqUpdatePack] = [] @@ -785,7 +763,6 @@ def _post_handle( ): req_obj: InferReq = req_obj pack: InferReqUpdatePack = pack - self._maybe_capture_prompt_cache_payload(req_obj) pack.handle( next_token_id=next_token_id, next_token_logprob=next_token_logprob, From b3b81237cd886fbe697c8835b24aedbd55db4078 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 11 Jun 2026 07:31:59 +0000 Subject: [PATCH 09/30] fix --- .../deepseek4_mem_manager.py | 8 - .../layer_infer/hyper_connection.py | 4 +- .../layer_infer/post_layer_infer.py | 1 + .../layer_infer/transformer_layer_infer.py | 139 +++++++++++------- 4 files changed, 87 insertions(+), 65 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py index 47fdf76fd9..3df5c0e12e 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py @@ -34,7 +34,6 @@ # c4 compressor state ring(overlap 对: 每页 2 个分组槽 × ratio 4 行)。c128 state 在 128 边界 # 自然归零(在线聚合),无缓存常驻需求,保持 req 键控,不进 swa 派生池。 DSV4_C4_STATE_RING = 8 -DSV4_PROFILE_MAX_FULL_TOKENS = 1_500_000 # swa 池占 full token 空间的比例下限(sglang swa_full_tokens_ratio=0.1 的对应物)。 # lightllm 的调度准入只看 full 池,prefill 优先的波次会让"已 
prefill 未 decode"的请求整段 # prompt 占住 swa 槽(首次 decode prep 才批量出窗回收),峰值≈准入波次 prompt 总和。在 @@ -276,13 +275,6 @@ def profile_size(self, mem_fraction): dist.all_reduce(tensor, op=dist.ReduceOp.MIN) self.size = tensor.item() - if self.size > DSV4_PROFILE_MAX_FULL_TOKENS: - logger.info( - f"DeepseekV4MemoryManager cap profiled max_total_token_num from " - f"{self.size} to {DSV4_PROFILE_MAX_FULL_TOKENS} to keep runtime headroom" - ) - self.size = DSV4_PROFILE_MAX_FULL_TOKENS - logger.info( f"{str(available_memory)} GB space is available after load the model weight\n" f"{str(self.get_cell_size() / 1024 ** 2)} MB is the conservative size of one token kv cache\n" diff --git a/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py index b125e9ed06..080ebabd89 100644 --- a/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py +++ b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py @@ -58,9 +58,9 @@ def hc_post(x, residual, post_mix, res_mix): return torch.ops.vllm.mhc_post_tilelang(x, residual, post_mix, res_mix) -def hc_head(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, rms_eps, hc_eps): +def hc_head(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, rms_eps, hc_eps, alloc_func): """Final stream collapse before the lm_head. 
streams:[N, hc*dim] -> [N, dim].""" - out = torch.empty(streams.shape[0], dim, device=streams.device, dtype=streams.dtype) + out = alloc_func((streams.shape[0], dim), dtype=streams.dtype, device=streams.device) torch.ops.vllm.hc_head_fused_kernel_tilelang( streams.view(-1, hc_mult, dim).contiguous(), hc_fn, diff --git a/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py index bc95c249f7..8eddfb3b9d 100644 --- a/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py @@ -22,5 +22,6 @@ def token_forward(self, input_embdings, infer_state: DeepseekV4InferStateInfo, l cfg["hidden_size"], cfg["rms_norm_eps"], cfg.get("hc_eps", 1e-6), + self.alloc_tensor, ) return super().token_forward(collapsed, infer_state, layer_weight) diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 6f7c0de3fc..95119586fd 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -4,6 +4,7 @@ from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.distributed.communication_op import all_reduce +from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import DeepseekV4TransformerLayerWeight from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor from .hyper_connection import hc_pre, hc_fused_post_pre, hc_post @@ -66,7 +67,7 @@ def __init__(self, layer_num, network_config): self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5 # ------------------------------------------------------------------ forward (HC-threaded) - def _hc_attn_in(self, input_embdings, layer_weight): + 
def _hc_attn_in(self, input_embdings, layer_weight: DeepseekV4TransformerLayerWeight): """Layer input -> attention input (attn_norm fused). First layer gets the raw streams and runs a standalone hc_pre; later layers get (x, residual, post_mix, res_mix) and fuse the previous layer's ffn hc_post with this layer's attn hc_pre.""" @@ -99,7 +100,7 @@ def _hc_attn_in(self, input_embdings, layer_weight): self.eps_, ) - def _hc_ffn_in(self, x, residual, post_mix, res_mix, layer_weight): + def _hc_ffn_in(self, x, residual, post_mix, res_mix, layer_weight: DeepseekV4TransformerLayerWeight): """Attention output -> ffn input (ffn_norm fused): fused attn hc_post + ffn hc_pre.""" return hc_fused_post_pre( x, @@ -124,14 +125,18 @@ def _hc_ffn_out(self, x, residual, post_mix, res_mix): streams = hc_post(x, residual, post_mix, res_mix) return streams.reshape(streams.shape[0], -1) - def context_forward(self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight): + def context_forward( + self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): x, residual, post_mix, res_mix = self._hc_attn_in(input_embdings, layer_weight) x = self.context_attention_forward(x, infer_state, layer_weight) x, residual, post_mix, res_mix = self._hc_ffn_in(x, residual, post_mix, res_mix, layer_weight) x = self._ffn(x, infer_state, layer_weight) return self._hc_ffn_out(x, residual, post_mix, res_mix) - def token_forward(self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight): + def token_forward( + self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): x, residual, post_mix, res_mix = self._hc_attn_in(input_embdings, layer_weight) x = self.token_attention_forward(x, infer_state, layer_weight) x, residual, post_mix, res_mix = self._hc_ffn_in(x, residual, post_mix, res_mix, layer_weight) @@ -144,36 +149,39 @@ def _select_rope(self, infer_state: 
DeepseekV4InferStateInfo): return infer_state.position_cos_compress, infer_state.position_sin_compress return infer_state.position_cos_sliding, infer_state.position_sin_sliding - def _get_qkv(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _get_qkv(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight): from sglang.jit_kernel.dsv4 import fused_q_norm_rope + x = self._tpsp_allgather(input=x, infer_state=infer_state) cos_tok, sin_tok = self._select_rope(infer_state) T = x.shape[0] qa = layer_weight.q_norm_(layer_weight.wq_a_.mm(x), eps=self.eps_) q_in = layer_weight.wq_b_.mm(qa).view(T, self.tp_q_heads, self.head_dim) # per-(token, head) weightless self-RMSNorm + interleaved rope on the last rope_dim dims, # fused in one sglang dsv4 jit kernel (fp32 norm/rotation, bf16 in between -- same as eager). - q = torch.empty_like(q_in) + q = self.alloc_tensor(q_in.shape, dtype=q_in.dtype, device=q_in.device) fused_q_norm_rope(q_in, q, self.eps_, self.freqs_cis, infer_state.position_ids) - kv = layer_weight.kv_norm_(layer_weight.wkv_.mm(x), eps=self.eps_) - kv = torch.cat( - [ - kv[:, : -self.rope_dim], - apply_rotary_emb(kv[:, -self.rope_dim :], cos_tok, sin_tok), - ], - dim=1, + # kv: rmsnorm + rope + fp8 pack + scatter 进 swa 池,一个 sglang jit kernel 完成 + # (同 sglang _compute_kv_to_cache),替代 eager norm/rope/cat + _post_cache_kv。 + # bf16 kv 中间量没有其他消费者: flashmla 路径注意力读 cache,压缩器/indexer 取 x。 + infer_state.mem_manager.pack_mla_kv_to_cache_fused_norm_rope( + layer_index=self.layer_num_, + mem_index=infer_state.mem_index, + kv=layer_weight.wkv_.mm(x), + kv_weight=layer_weight.kv_norm_.weight, + eps=self.eps_, + freqs_cis=self.freqs_cis, + positions=infer_state.position_ids, ) - return q, kv, qa, cos_tok, sin_tok + return q, qa, cos_tok, sin_tok - def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, layer_weight: 
DeepseekV4TransformerLayerWeight): # o: [T, tp_q_heads, head_dim] after inverse rope -> grouped low-rank O -> [T, hidden] T = o.shape[0] o = o.reshape(T, self.tp_groups, -1).transpose(0, 1).contiguous() # [groups, T, per_group_in] o = layer_weight.wo_a_.bmm(o).transpose(0, 1).reshape(T, -1) # [T, groups*o_lora] o = layer_weight.wo_b_.mm(o) - if self.tp_world_size_ > 1: - all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - return o + return self._tpsp_reduce(input=o, infer_state=infer_state) def _inv_rope(self, o, cos_tok, sin_tok): return torch.cat( @@ -190,7 +198,9 @@ def _inv_rope(self, o, cos_tok, sin_tok): ) # ------------------------------------------------------------------ compressor / indexer - def _indexer_q_weight(self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _indexer_q_weight( + self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): if self.compress_ratio != 4: return None, None cos_tok = infer_state.position_cos_compress @@ -222,7 +232,7 @@ def _write_compressed_kv(self, infer_state: DeepseekV4InferStateInfo, req, entry infer_state.mem_manager.pack_compressed_kv_to_cache(self.layer_num_, slots, comp) return slots - def _compressor_weights(self, layer_weight, for_indexer: bool): + def _compressor_weights(self, layer_weight: DeepseekV4TransformerLayerWeight, for_indexer: bool): if for_indexer: return ( layer_weight.idx_cmp_wkv_.mm_param.weight, @@ -239,7 +249,9 @@ def _compressor_weights(self, layer_weight, for_indexer: bool): self.head_dim, ) - def _run_compressor_prefill(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _run_compressor_prefill( + self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): """Per-request compressor for the prefill chunk. Runs as part of the deferred attention func, before the attention metadata gathers the slot mappings. 
@@ -255,7 +267,9 @@ def _run_compressor_prefill(self, x, infer_state: DeepseekV4InferStateInfo, laye self._run_c128_compressor_prefill(x, infer_state, layer_weight) return - def _run_c4_compressor_prefill(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _run_c4_compressor_prefill( + self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): rm = infer_state.req_manager mem = infer_state.mem_manager wkv, wgate, norm, ape, _ = self._compressor_weights(layer_weight, for_indexer=False) @@ -316,7 +330,9 @@ def _run_c4_compressor_prefill(self, x, infer_state: DeepseekV4InferStateInfo, l infer_state.mem_manager.pack_indexer_k_to_cache(self.layer_num_, slots, idx_comp) return - def _run_c128_compressor_prefill(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _run_c128_compressor_prefill( + self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): rm = infer_state.req_manager wkv, wgate, norm, ape, _ = self._compressor_weights(layer_weight, for_indexer=False) b_req = infer_state.b_req_idx.tolist() @@ -365,7 +381,9 @@ def _run_c128_compressor_prefill(self, x, infer_state: DeepseekV4InferStateInfo, self._write_compressed_kv(infer_state, req, entry_start, entry.unsqueeze(0)) return - def _run_compressor_decode(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _run_compressor_decode( + self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): """Batched decode compressor (cuda-graph safe): state update for every request, cache write masked to the pool HOLD slot unless this token completes a window. Compressed-cache slots were pre-allocated by prepare_decode_compress_slots in the prep phase. 
@@ -460,24 +478,24 @@ def _run_compressor_decode(self, x, infer_state: DeepseekV4InferStateInfo, layer return # ------------------------------------------------------------------ attention (prefill) - def context_attention_forward(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): - q, cache_kv, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, layer_weight) - # template hook: write the chunk's packed latent into the swa pool before attention - # reads it back via full_to_swa indices (this custom forward bypasses the tpl path). - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_wrapper_run(q, cache_kv, q_lora, x, infer_state, layer_weight) + def context_attention_forward( + self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): + # _get_qkv writes the chunk's packed latent into the swa pool (fused kernel) before + # attention reads it back via full_to_swa indices (this custom forward bypasses the + # tpl _post_cache_kv path). 
+ q, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, layer_weight) + o = self._context_attention_wrapper_run(q, q_lora, x, infer_state, layer_weight) return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, layer_weight) def _context_attention_wrapper_run( - self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight + self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight ): if torch.cuda.is_current_stream_capturing(): q = q.contiguous() - cache_kv = cache_kv.contiguous() q_lora = q_lora.contiguous() x = x.contiguous() _q = tensor_to_no_ref_tensor(q) - _cache_kv = tensor_to_no_ref_tensor(cache_kv) _q_lora = tensor_to_no_ref_tensor(q_lora) _x = tensor_to_no_ref_tensor(x) @@ -486,11 +504,13 @@ def _context_attention_wrapper_run( infer_state.prefill_cuda_graph_create_graph_obj() infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__() - o = torch.empty((q.shape[0], self.tp_q_heads, self.head_dim), dtype=q.dtype, device=q.device) + # Same graph-split output handoff as the template, but avoid its dry-run because + # DSV4 attention mutates compressor/cache state before returning. 
+ o = self.alloc_tensor((q.shape[0], self.tp_q_heads, self.head_dim), dtype=q.dtype, device=q.device) _o = tensor_to_no_ref_tensor(o) def att_func(new_infer_state: DeepseekV4InferStateInfo): - tmp_o = self._context_attention_kernel(_q, _cache_kv, _q_lora, _x, new_infer_state, layer_weight) + tmp_o = self._context_attention_kernel(_q, _q_lora, _x, new_infer_state, layer_weight) assert tmp_o.shape == _o.shape _o.copy_(tmp_o) return @@ -498,9 +518,11 @@ def att_func(new_infer_state: DeepseekV4InferStateInfo): infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=att_func, after_graph=pre_capture_graph) return o - return self._context_attention_kernel(q, cache_kv, q_lora, x, infer_state, layer_weight) + return self._context_attention_kernel(q, q_lora, x, infer_state, layer_weight) - def _context_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _context_attention_kernel( + self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): self._run_compressor_prefill(x, infer_state, layer_weight) idx_q, idx_weight = self._indexer_q_weight(x, q_lora, infer_state, layer_weight) att_control = AttControl( @@ -511,7 +533,6 @@ def _context_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: Deepsee "compress_ratio": self.compress_ratio, "head_dim_v": self.head_dim, "softmax_scale": self.softmax_scale, - "cache_kv": cache_kv, "q_lora": q_lora, "hidden_states": x, "attn_sink": layer_weight.attn_sink_.weight, @@ -530,13 +551,16 @@ def _context_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: Deepsee ) # ------------------------------------------------------------------ attention (decode) - def token_attention_forward(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): - q, cache_kv, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, layer_weight) - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = 
self._token_attention_kernel(q, cache_kv, q_lora, x, infer_state, layer_weight) + def token_attention_forward( + self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): + q, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, layer_weight) + o = self._token_attention_kernel(q, q_lora, x, infer_state, layer_weight) return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, layer_weight) - def _token_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _token_attention_kernel( + self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): self._run_compressor_decode(x, infer_state, layer_weight) idx_q, idx_weight = self._indexer_q_weight(x, q_lora, infer_state, layer_weight) att_control = AttControl( @@ -547,7 +571,6 @@ def _token_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: DeepseekV "compress_ratio": self.compress_ratio, "head_dim_v": self.head_dim, "softmax_scale": self.softmax_scale, - "cache_kv": cache_kv, "q_lora": q_lora, "hidden_states": x, "attn_sink": layer_weight.attn_sink_.weight, @@ -566,7 +589,7 @@ def _token_attention_kernel(self, q, cache_kv, q_lora, x, infer_state: DeepseekV ) # ------------------------------------------------------------------ moe - def _routed_experts(self, x, weights, indices, layer_weight): + def _routed_experts(self, x, weights, indices, layer_weight: DeepseekV4TransformerLayerWeight): return layer_weight.experts_.experts_with_preselected( input_tensor=x, topk_weights=weights, @@ -574,7 +597,11 @@ def _routed_experts(self, x, weights, indices, layer_weight): clamp_limit=float(self.swiglu_limit), ) - def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight): + x = x.view(-1, self.hidden) + if not self.enable_ep_moe: + x = 
self._tpsp_allgather(input=x, infer_state=infer_state) + gw = layer_weight.gate_weight_.mm_param.weight logits = F.linear(x.float(), gw.float()).contiguous() weights, indices = self._select_experts(logits, infer_state, layer_weight) @@ -582,7 +609,7 @@ def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): g = layer_weight.shared_gate_.mm(x).float().clamp(max=self.swiglu_limit) u = layer_weight.shared_up_.mm(x).float().clamp(min=-self.swiglu_limit, max=self.swiglu_limit) shared = layer_weight.shared_down_.mm((F.silu(g) * u).to(x.dtype)) - if self.enable_ep_moe and getattr(layer_weight.experts_, "is_ep", False): + if self.enable_ep_moe: if self.tp_world_size_ > 1: all_reduce( shared, @@ -592,14 +619,16 @@ def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight): ) return routed + shared out = routed + shared - if self.tp_world_size_ > 1: - all_reduce(out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - return out + return self._tpsp_reduce(input=out, infer_state=infer_state) - def _select_experts(self, logits, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _select_experts( + self, logits, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): return self._select_experts_vllm(logits, infer_state, layer_weight) - def _select_experts_vllm(self, logits, infer_state: DeepseekV4InferStateInfo, layer_weight): + def _select_experts_vllm( + self, logits, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): from vllm import _custom_ops as ops M = logits.shape[0] @@ -616,9 +645,9 @@ def _select_experts_vllm(self, logits, infer_state: DeepseekV4InferStateInfo, la else: bias = layer_weight.gate_bias_.weight - weights = torch.empty((M, self.topk), dtype=torch.float32, device=logits.device) - indices = torch.empty((M, self.topk), dtype=indices_dtype, device=logits.device) - token_expert_indices = torch.empty((M, self.topk), dtype=torch.int32, 
device=logits.device) + weights = self.alloc_tensor((M, self.topk), dtype=torch.float32, device=logits.device) + indices = self.alloc_tensor((M, self.topk), dtype=indices_dtype, device=logits.device) + token_expert_indices = self.alloc_tensor((M, self.topk), dtype=torch.int32, device=logits.device) ops.topk_hash_softplus_sqrt( weights, indices, From 6002866d2fc7c7858c26f7383364eb75bf19edd8 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 11 Jun 2026 12:47:19 +0000 Subject: [PATCH 10/30] fix rope --- .../deepseek2/triton_kernel/rotary_emb.py | 66 ++++++++++++------- .../layer_infer/transformer_layer_infer.py | 25 ++----- lightllm/models/deepseek_v4/model.py | 58 ++++++++-------- .../deepseek_v4/triton_kernel/rotary_emb.py | 26 -------- 4 files changed, 75 insertions(+), 100 deletions(-) delete mode 100644 lightllm/models/deepseek_v4/triton_kernel/rotary_emb.py diff --git a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py index 30e5a59248..a8f851de2a 100644 --- a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py +++ b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py @@ -29,6 +29,8 @@ def _rotary_kernel( BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, NUM_STAGE: tl.constexpr, + HAS_K: tl.constexpr, + INVERSE: tl.constexpr, ): head_start_index = tl.program_id(0) seq_block_index = tl.program_id(1) @@ -44,6 +46,8 @@ def _rotary_kernel( off_dimcos_sin = seq_index * stride_cosbs + cos_range * stride_cosd cos = tl.load(Cos + off_dimcos_sin) sin = tl.load(Sin + off_dimcos_sin) + if INVERSE: + sin = -sin if HEAD_PARALLEL_NUM == 1: for q_head_index in tl.static_range(0, HEAD_Q, step=1): @@ -56,18 +60,19 @@ def _rotary_kernel( tl.store(Q + off_q0, out_q0) tl.store(Q + off_q1, out_q1) - for k_head_index in tl.static_range(0, HEAD_K, step=1): - off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd - off_k1 = seq_index * stride_kbs + k_head_index * 
stride_kh + dim_range1 * stride_kd + if HAS_K: + for k_head_index in tl.static_range(0, HEAD_K, step=1): + off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd + off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd - k0 = tl.load(K + off_k0) - k1 = tl.load(K + off_k1) + k0 = tl.load(K + off_k0) + k1 = tl.load(K + off_k1) - out_k0 = k0 * cos - k1 * sin - out_k1 = k0 * sin + k1 * cos + out_k0 = k0 * cos - k1 * sin + out_k1 = k0 * sin + k1 * cos - tl.store(K + off_k0, out_k0) - tl.store(K + off_k1, out_k1) + tl.store(K + off_k0, out_k0) + tl.store(K + off_k1, out_k1) else: for q_head_index in tl.range(head_start_index, HEAD_Q, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE): off_q0 = seq_index * stride_qbs + q_head_index * stride_qh + dim_range0 * stride_qd @@ -79,18 +84,19 @@ def _rotary_kernel( tl.store(Q + off_q0, out_q0) tl.store(Q + off_q1, out_q1) - for k_head_index in tl.range(head_start_index, HEAD_K, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE): - off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd - off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd + if HAS_K: + for k_head_index in tl.range(head_start_index, HEAD_K, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE): + off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd + off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd - k0 = tl.load(K + off_k0) - k1 = tl.load(K + off_k1) + k0 = tl.load(K + off_k0) + k1 = tl.load(K + off_k1) - out_k0 = k0 * cos - k1 * sin - out_k1 = k0 * sin + k1 * cos + out_k0 = k0 * cos - k1 * sin + out_k1 = k0 * sin + k1 * cos - tl.store(K + off_k0, out_k0) - tl.store(K + off_k1, out_k1) + tl.store(K + off_k0, out_k0) + tl.store(K + off_k1, out_k1) return @@ -109,7 +115,10 @@ def get_test_configs(): def get_static_key(q, k): - head_num_q, head_num_k, head_dim = q.shape[1], k.shape[1], q.shape[2] + 
assert q is not None, "q can not be None" + head_num_q = q.shape[1] + head_num_k = k.shape[1] if k is not None else 0 + head_dim = q.shape[2] return { "Q_HEAD_NUM": head_num_q, "K_HEAD_NUM": head_num_k, @@ -126,12 +135,17 @@ def get_static_key(q, k): mutates_args=["q", "k"], ) @torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin, run_config=None): +def rotary_emb_fwd(q, k, cos, sin, inverse=False, run_config=None): + assert q is not None, "q can not be None" + has_k = k is not None and k.shape[1] != 0 total_len = q.shape[0] - head_num_q, head_num_k = q.shape[1], k.shape[1] + head_num_q = q.shape[1] + head_num_k = k.shape[1] if k is not None else 0 head_dim = q.shape[2] assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" + if k is not None: + assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" + assert k.shape[2] == head_dim, f"k shape {k.shape} q head_dim {head_dim}" assert triton.next_power_of_2(head_dim) == head_dim if not run_config: @@ -157,9 +171,9 @@ def rotary_emb_fwd(q, k, cos, sin, run_config=None): stride_qbs=q.stride(0), stride_qh=q.stride(1), stride_qd=q.stride(2), - stride_kbs=k.stride(0), - stride_kh=k.stride(1), - stride_kd=k.stride(2), + stride_kbs=k.stride(0) if k is not None else 0, + stride_kh=k.stride(1) if k is not None else 0, + stride_kd=k.stride(2) if k is not None else 0, stride_cosbs=cos.stride(0), stride_cosd=cos.stride(1), stride_sinbs=sin.stride(0), @@ -171,6 +185,8 @@ def rotary_emb_fwd(q, k, cos, sin, run_config=None): BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=head_dim, NUM_STAGE=num_stages, + HAS_K=has_k, + INVERSE=inverse, num_warps=num_warps, num_stages=num_stages, ) diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 
95119586fd..e610473170 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -17,7 +17,7 @@ paged_prefill_compress_data, paged_decode_state_slots, ) -from ..triton_kernel.rotary_emb import apply_rotary_emb +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from ..infer_struct import DeepseekV4InferStateInfo @@ -184,18 +184,9 @@ def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, layer_weight: Deepsee return self._tpsp_reduce(input=o, infer_state=infer_state) def _inv_rope(self, o, cos_tok, sin_tok): - return torch.cat( - [ - o[..., : -self.rope_dim], - apply_rotary_emb( - o[..., -self.rope_dim :], - cos_tok.unsqueeze(1), - sin_tok.unsqueeze(1), - inverse=True, - ), - ], - dim=-1, - ) + # in-place; 单张量路径只需要旋转 rope 切片。 + rotary_emb_fwd(o[..., -self.rope_dim :], None, cos_tok, sin_tok, inverse=True) + return o # ------------------------------------------------------------------ compressor / indexer def _indexer_q_weight( @@ -206,13 +197,7 @@ def _indexer_q_weight( cos_tok = infer_state.position_cos_compress sin_tok = infer_state.position_sin_compress idx_q = layer_weight.idx_wq_b_.mm(q_lora).view(x.shape[0], self.tp_index_heads, self.index_head_dim) - idx_q = torch.cat( - [ - idx_q[..., : -self.rope_dim], - apply_rotary_emb(idx_q[..., -self.rope_dim :], cos_tok.unsqueeze(1), sin_tok.unsqueeze(1)), - ], - dim=-1, - ) + rotary_emb_fwd(idx_q[..., -self.rope_dim :], None, cos_tok, sin_tok) idx_weight = layer_weight.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale return idx_q, idx_weight diff --git a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py index 1a88e08977..63430e548b 100644 --- a/lightllm/models/deepseek_v4/model.py +++ b/lightllm/models/deepseek_v4/model.py @@ -22,7 +22,7 @@ from lightllm.models.deepseek_v4.layer_infer.transformer_layer_infer import ( DeepseekV4TransformerLayerInfer, ) 
-from lightllm.common.basemodel.attention.create_utils import nsa_data_type_to_backend +from lightllm.common.basemodel.attention import get_nsa_prefill_att_backend_class, get_nsa_decode_att_backend_class from lightllm.models.deepseek_v4.infer_struct import DeepseekV4InferStateInfo from lightllm.models.llama.yarn_rotary_utils import ( find_correction_range, @@ -125,11 +125,11 @@ def _init_att_backend(self): args = get_env_start_args() if args.llm_kv_type == "None": args.llm_kv_type = "fp8kv_dsa" + # TODO: 支持其他 kv type if args.llm_kv_type != "fp8kv_dsa": raise RuntimeError("DeepSeek-V4 requires llm_kv_type=fp8kv_dsa for packed FlashMLA sparse attention") - backend_cls = nsa_data_type_to_backend["fp8kv_dsa"]["flashmla_sparse"] - self.prefill_att_backend = backend_cls(model=self) - self.decode_att_backend = backend_cls(model=self) + self.prefill_att_backend = get_nsa_prefill_att_backend_class(index=0)(model=self) + self.decode_att_backend = get_nsa_decode_att_backend_class(index=0)(model=self) return def _init_custom(self): @@ -146,36 +146,36 @@ def _init_to_get_rotary(self): # Interleaved (GPT-J) rope. Build complex64 freqs_cis tables (_freqs_cis_*) following the # gemma4 two-variant convention; the fused sglang q kernel consumes them directly, while # _cos_cached_*/_sin_cached_* are .real/.imag views of the same storage for the kv rope, - # inverse rope and compressor paths (apply_rotary_emb: interleaved, NOT the NeoX - # rotary_emb_fwd). Sliding-window layers use base rope_theta (no YaRN); compressed (CSA/HCA) - # layers use compress_rope_theta with YaRN. Kept fp32 for accuracy (the apply upcasts anyway). + # inverse rope and compressor paths (deepseek2's interleaved triton rotary_emb_fwd). + # Sliding-window layers use base rope_theta (no YaRN); + # compressed (CSA/HCA) layers use compress_rope_theta with configured rope_scaling. + # Kept fp32 for accuracy (the apply upcasts anyway). 
cfg = self.config rs = cfg.get("rope_scaling", {}) or {} dim = cfg["qk_rope_head_dim"] - beta_fast = rs.get("beta_fast", 32) - beta_slow = rs.get("beta_slow", 1) max_seq = max(int(self.max_seq_length), int(cfg.get("max_position_embeddings", 8192))) max_seq = min(max_seq, 1 << 18) # cap table size (256K) for correctness-first - - def build(base, factor, orig_max): - freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim)) - if orig_max > 0: - low, high = find_correction_range(beta_fast, beta_slow, dim, base, orig_max) - smooth = 1 - linear_ramp_mask(low, high, dim // 2).cuda() - freqs = freqs / factor * (1 - smooth) + freqs * smooth - f = torch.outer(torch.arange(max_seq, dtype=torch.float32, device="cuda"), freqs) # [max_seq, dim//2] - return torch.complex(f.cos(), f.sin()) - - self._freqs_cis_sliding = build( - cfg["rope_theta"], - rs.get("factor", 16), - rs.get("original_max_position_embeddings", 65536), - ) - self._freqs_cis_compress = build( - cfg["compress_rope_theta"], - rs.get("factor", 16), - rs.get("original_max_position_embeddings", 65536), - ) + freq_exponents = torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim + positions = torch.arange(max_seq, dtype=torch.float32, device="cuda") + + sliding_freqs = 1.0 / (cfg["rope_theta"] ** freq_exponents) + f = torch.outer(positions, sliding_freqs) # [max_seq, dim//2] + self._freqs_cis_sliding = torch.complex(f.cos(), f.sin()) + + compress_freqs = 1.0 / (cfg["compress_rope_theta"] ** freq_exponents) + rope_type = rs.get("rope_type", rs.get("type", "default")) + orig_max = rs.get("original_max_position_embeddings", 0) + if rope_type == "yarn" and orig_max > 0: + beta_fast = rs.get("beta_fast", 32) + beta_slow = rs.get("beta_slow", 1) + factor = rs.get("factor", 1) + if factor is None: + factor = cfg.get("max_position_embeddings", max_seq) / orig_max + low, high = find_correction_range(beta_fast, beta_slow, dim, cfg["compress_rope_theta"], orig_max) + smooth = 1 - 
linear_ramp_mask(low, high, dim // 2).cuda() + compress_freqs = compress_freqs / factor * (1 - smooth) + compress_freqs * smooth + f = torch.outer(positions, compress_freqs) # [max_seq, dim//2] + self._freqs_cis_compress = torch.complex(f.cos(), f.sin()) self._cos_cached_sliding = self._freqs_cis_sliding.real self._sin_cached_sliding = self._freqs_cis_sliding.imag self._cos_cached_compress = self._freqs_cis_compress.real diff --git a/lightllm/models/deepseek_v4/triton_kernel/rotary_emb.py b/lightllm/models/deepseek_v4/triton_kernel/rotary_emb.py deleted file mode 100644 index cb50977446..0000000000 --- a/lightllm/models/deepseek_v4/triton_kernel/rotary_emb.py +++ /dev/null @@ -1,26 +0,0 @@ -import torch - -# Interleaved (GPT-J) rotary application for DeepSeek-V4. Unlike llama/gemma's NeoX-style -# rotary_emb_fwd (rotate-half: pairs channel i with i+d/2 over a real cos/sin table), V4 rotates -# adjacent pairs (x0,x1),(x2,x3),... — a different channel pairing — so it cannot reuse -# rotary_emb_fwd, but it consumes the same real cos/sin tables (built in model.py:_init_to_get_rotary -# as _cos_cached_*/_sin_cached_*, gemma4-style). Correctness-first pure-torch; a fused triton port is -# a perf follow-up. - - -def apply_rotary_emb(x, cos, sin, inverse=False): - """Apply interleaved rope to the LAST dim of x (size = 2*cos.size(-1)). - - x: [..., rope_dim] (real). cos/sin: [..., rope_dim//2], broadcastable to x's paired view. - For x of shape [N, H, rope_dim], pass cos/sin [N, 1, rope_dim//2]; for [N, rope_dim] pass [N, rope_dim//2]. - Returns a new tensor of x's dtype (not in-place). inverse=True applies the conjugate rotation. 
- """ - dtype = x.dtype - x = x.float().reshape(*x.shape[:-1], -1, 2) - x0, x1 = x[..., 0], x[..., 1] - cos = cos.float() - sin = sin.float() - if inverse: - sin = -sin - out = torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1) - return out.flatten(-2).to(dtype) From 6bc34adb3896473a8eda46609c232b571faa7d8e Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 11 Jun 2026 06:53:24 +0000 Subject: [PATCH 11/30] dsv4: enable decode cudagraph; fix warmup-baked FlashMLASchedMeta Root cause of the historical cudagraph accuracy drop (gsm8k 0.96 -> 0.74, coherent-but-runaway generations; same 0.75 the pre-v5 fullslot_decode experiments worked around): _capture_decode warms up via copy.copy(infer_state), which SHARES decode_att_state. FlashMLASchedMeta is lazily planned at the first kernel call and written back onto that shared state, so the warmup pass locks a schedule planned for the dummy batch (seq=2); the capture pass then binds those stale scheduler tensors and every replay runs real requests with a tile schedule planned for near-empty kv (systematically under-read attention). Fix: reset_sched_meta_for_capture() hook on the nsa decode att state, invoked in both capture paths after warmup, so planning happens INSIDE the captured region and re-plans on every replay from live tensors. Validation (tp4, H200, prompt cache on): batch-1 greedy decode is now character-identical to eager; per-layer probe shows embed+swa layers bitwise equal under replay, benign rounding-class deltas only in compress layers, argmax unchanged. gsm8k 100q/128: cold 0.960/111s, warm 0.960/23.3s 100% hits (eager: 0.95-0.97, cold 141s / warm 50s). Batch-1 decode 20.4ms/token vs 142ms eager. 41/41 unit tests green. Codex review GO (incl. overlap-path symmetry). launch.sh: drop --disable_cudagraph, derive PYTHONPATH from the script dir (hardcoded tree path made a worktree launch silently serve main-tree code). 
Co-Authored-By: Claude Fable 5 --- launch.sh | 39 +++++++++++++++++++ .../attention/nsa/fp8_flashmla_sparse.py | 8 ++++ lightllm/common/basemodel/cuda_graph.py | 19 +++++++++ 3 files changed, 66 insertions(+) create mode 100644 launch.sh diff --git a/launch.sh b/launch.sh new file mode 100644 index 0000000000..ce016efc96 --- /dev/null +++ b/launch.sh @@ -0,0 +1,39 @@ +# DeepSeek-V4-Flash serving (run inside the lightllm container, repo mounted at /data/wanzihao/lightllm-ds4). +# Verified 2026-06-11: smoke + gsm8k pass with this configuration (prompt cache ENABLED, decode +# cudagraph ENABLED; gsm8k 100q/128: cold 0.960/112s, warm 0.970/23.5s with 100% cache hits — +# vs eager cold 0.970/141s, warm 0.960/50s; batch-1 decode 20.4ms/token vs 142ms eager). +# +# Required env/flags and why: +# LOADWORKER=16 - parallel weight loading (~5x faster startup). +# Optional sizing knobs (defaults shown): LIGHTLLM_DSV4_SWA_FULL_TOKENS_RATIO=0.1 (swa pool floor +# as a fraction of full tokens; raise for long-prompt x high-parallel workloads), +# LIGHTLLM_DSV4_PROFILE_MAX_FULL_TOKENS=1500000 (auto-profile cap on max_total_token_num). +# PYTHONPATH sglang - _get_qkv / compressor reuse sglang.jit_kernel.dsv4 (fused_q_norm_rope, compress_old). +# --batch_max_tokens 8192 - FlashMLA get_decoding_sched_meta rejects >8192 rows per call (probed: 8192 OK, 12288 fails). +# decode cudagraph ENABLED - the v5 decode path is graph-safe: slot alloc/scatter in prep (outside +# graph), forward is pure gathers, HOLD padding rows redirect to HOLD slots. CORRECTNESS NOTE: +# FlashMLASchedMeta is lazily planned at first kernel call and written back onto the (shared) +# decode att state; the capture warmup pass would bake a dummy-content plan into the graph +# (gsm8k dropped to 0.74 with coherent-but-runaway generations). reset_sched_meta_for_capture() +# in cuda_graph._capture_decode re-plans INSIDE the captured region so every replay re-plans. 
+# DSV4 caps graph max_len_in_batch at 8192; longer decode batches fall back to eager. +# --disable_flashinfer_allreduce - flashinfer cuda_ipc resolves libcudart to tilelang's stub (undefined cudaDeviceReset); symm-mem allreduce is used instead. +# +# One-time container setup already applied (survives until container rebuild): +# pip install ipython (sglang import dependency) +# site-packages/vllm: layers/mhc.py + kernels/mhc/ + _tilelang_ops.py overlaid from /data/wanzihao/vllm (mhc_pre_tilelang ops; original kept at layers/mhc.py.bak) +# +# original: python -m lightllm.server.api_server --model_dir /data/models/DeepSeek-V4-Flash --tp 4 --enable_prefill_cudagraph + +# repo root = this script's directory, so the same file works in the main tree and in worktrees +# (a hardcoded tree path here once made a worktree launch silently serve main-tree code). +REPO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +LOADWORKER=16 \ +PYTHONPATH="${REPO_DIR}":/data/wanzihao/sglang/python \ +python -m lightllm.server.api_server \ + --model_dir /data/models/DeepSeek-V4-Flash \ + --tp 4 \ + --batch_max_tokens 8192 \ + --disable_flashinfer_allreduce \ + --port 8000 diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py index 0570adea83..14b1b3307d 100644 --- a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py @@ -464,6 +464,14 @@ def init_state(self): self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)} return + def reset_sched_meta_for_capture(self): + # cuda-graph capture hook: the warmup pass already locked/stored sched meta on this + # (shared) state object; reset so the capture pass re-plans INSIDE the graph and every + # replay re-plans from the live tensors instead of binding warmup leftovers. 
+ flash_mla = self.backend.flash_mla() + self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)} + return + def decode_att( self, q: Tuple[torch.Tensor, torch.Tensor], diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 782150661e..e1d96b744e 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -14,6 +14,20 @@ logger = init_logger(__name__) +def _reset_att_state_sched_meta(infer_state: InferStateInfo): + # capture 前调用: warmup 趟用 copy.copy 浅拷贝共享 decode_att_state,其内部惰性初始化的 + # 调度对象(如 FlashMLASchedMeta,首次内核调用时按当时数据规划并回写)会被 warmup 的 + # dummy 负载锁定;若不重置,捕获趟将绑定为 dummy 规划的调度张量,所有 replay 都用错误 + # 的 tile schedule(DSV4 实测 gsm8k 0.96 -> 0.74)。重置后规划发生在捕获区内,随 replay 重算。 + for att_state in (infer_state.decode_att_state, infer_state.decode_att_state1): + if att_state is None: + continue + reset_fn = getattr(att_state, "reset_sched_meta_for_capture", None) + if reset_fn is not None: + reset_fn() + return + + class CudaGraph: # CudaGraph forward pass for the decoding stage. 
@@ -94,6 +108,8 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): if param_name not in pure_para_set: delattr(infer_state, param_name) + _reset_att_state_sched_meta(infer_state) + with torch.cuda.graph(graph_obj, pool=self.mempool): model_output = decode_func(infer_state) self.graph[batch_size] = (graph_obj, infer_state, model_output) @@ -128,6 +144,9 @@ def _capture_decode_overlap( if para_name not in pure_para_set1: delattr(infer_state1, para_name) + _reset_att_state_sched_meta(infer_state) + _reset_att_state_sched_meta(infer_state1) + with torch.cuda.graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(infer_state, infer_state1) self.graph[batch_size] = ( From e78e0d429dff4a58aad5cb3e9ac8158b215d870f Mon Sep 17 00:00:00 2001 From: wanzihao Date: Thu, 11 Jun 2026 08:32:14 +0000 Subject: [PATCH 12/30] dsv4: enable prefill cudagraph; zero pad-row attention output Graph-sandwich prefill (graphs capture dense ops only; attention/compressor run eagerly between segments) was already in-tree; enabling it exposed that HOLD-pad rows read the racing HOLD slot, making their hiddens nondeterministic and perturbing real rows via MoE expert batching (ulp-level, amplified ~1.9x/layer). Zero the pad rows' attention output. Residual greedy-trajectory divergence vs eager equals the fp4 marlin MoE kernel's own run-to-run reduction-order noise (eager-vs-eager control: 0/4 match), accepted statistically: gsm8k 100q cold 0.980/115.5s warm 0.960/25.9s (eager-baseline parity); batch-1 TTFT 1.86x at 46 tokens. --- launch.sh | 17 ++++++++++++++--- lightllm/models/deepseek_v4/infer_struct.py | 7 +++++++ .../layer_infer/transformer_layer_infer.py | 7 ++++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/launch.sh b/launch.sh index ce016efc96..ccf870dc8c 100644 --- a/launch.sh +++ b/launch.sh @@ -5,9 +5,6 @@ # # Required env/flags and why: # LOADWORKER=16 - parallel weight loading (~5x faster startup). 
-# Optional sizing knobs (defaults shown): LIGHTLLM_DSV4_SWA_FULL_TOKENS_RATIO=0.1 (swa pool floor -# as a fraction of full tokens; raise for long-prompt x high-parallel workloads), -# LIGHTLLM_DSV4_PROFILE_MAX_FULL_TOKENS=1500000 (auto-profile cap on max_total_token_num). # PYTHONPATH sglang - _get_qkv / compressor reuse sglang.jit_kernel.dsv4 (fused_q_norm_rope, compress_old). # --batch_max_tokens 8192 - FlashMLA get_decoding_sched_meta rejects >8192 rows per call (probed: 8192 OK, 12288 fails). # decode cudagraph ENABLED - the v5 decode path is graph-safe: slot alloc/scatter in prep (outside @@ -17,6 +14,18 @@ # (gsm8k dropped to 0.74 with coherent-but-runaway generations). reset_sched_meta_for_capture() # in cuda_graph._capture_decode re-plans INSIDE the captured region so every replay re-plans. # DSV4 caps graph max_len_in_batch at 8192; longer decode batches fall back to eager. +# --enable_prefill_cudagraph + --prefill_cudagraph_max_handle_token 2048 - graph-sandwich prefill: +# graphs capture only the per-token dense ops; attention/compressor/indexer run eagerly between +# graph segments (att_func), so host-side planning and .tolist() prep never enter capture. Only +# cold prefills (prefix_total_token_num == 0, model gate) of <= 2048 new tokens replay; cache-hit +# and large batched prefills stay eager. Buckets are padded with a HOLD tail request whose +# attention output MUST be zeroed (infer_struct._dsv4_prefill_pad_q_len): pad rows read the +# racing HOLD slot, and nondeterministic pad hiddens perturb real rows via MoE expert batching +# (ulp-level, chaotically amplified ~1.9x/layer to O(1) by layer ~16 -> greedy token flips). +# Residual caveat: padded-vs-unpadded expert-batch composition still shifts reductions by ulps, +# same class as decode bucket padding; run-to-run determinism is anyway bounded by the fp4 +# marlin MoE kernel itself (probabilistic 1-ulp reduction-order noise measured eager-vs-eager). 
+# Acceptance is therefore statistical (gsm8k parity), not bitwise. # --disable_flashinfer_allreduce - flashinfer cuda_ipc resolves libcudart to tilelang's stub (undefined cudaDeviceReset); symm-mem allreduce is used instead. # # One-time container setup already applied (survives until container rebuild): @@ -36,4 +45,6 @@ python -m lightllm.server.api_server \ --tp 4 \ --batch_max_tokens 8192 \ --disable_flashinfer_allreduce \ + --enable_prefill_cudagraph \ + --prefill_cudagraph_max_handle_token 2048 \ --port 8000 diff --git a/lightllm/models/deepseek_v4/infer_struct.py b/lightllm/models/deepseek_v4/infer_struct.py index 39a6889d72..caf8ca6fa8 100644 --- a/lightllm/models/deepseek_v4/infer_struct.py +++ b/lightllm/models/deepseek_v4/infer_struct.py @@ -26,3 +26,10 @@ def init_some_extra_state(self, model): self.position_sin_sliding = torch.index_select(model._sin_cached_sliding, 0, pos) self.position_cos_compress = torch.index_select(model._cos_cached_compress, 0, pos) self.position_sin_compress = torch.index_select(model._sin_cached_compress, 0, pos) + # prefill-cudagraph 桶填充的 HOLD 尾请求的 q 行数。其注意力读 HOLD 槽位(内容被并发写 + # 竞争,每轮不同),输出必须清零,否则 pad 行 hidden 不确定 -> MoE 路由抖动 -> 共享 expert + # 批次组成变化 -> 真实行 GEMM 归约顺序变化(ulp 级),44 层放大后翻转低置信 token。 + self._dsv4_prefill_pad_q_len = 0 + if self.is_prefill and self.b_req_idx.numel() > 0: + if int(self.b_req_idx[-1].item()) == self.req_manager.HOLD_REQUEST_ID: + self._dsv4_prefill_pad_q_len = int((self.b_seq_len[-1] - self.b_ready_cache_len[-1]).item()) diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index e610473170..761558c95d 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -528,12 +528,17 @@ def _context_attention_kernel( "tp_world_size": self.tp_world_size_, }, ) - return infer_state.prefill_att_state.prefill_att( + out = 
infer_state.prefill_att_state.prefill_att( q=q, k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_), v=None, att_control=att_control, ) + pad_q_len = getattr(infer_state, "_dsv4_prefill_pad_q_len", 0) + if pad_q_len: + # pad 行读 HOLD 槽位(参见 infer_struct._dsv4_prefill_pad_q_len),清零以保持确定性 + out[-pad_q_len:] = 0 + return out # ------------------------------------------------------------------ attention (decode) def token_attention_forward( From c09dc6aa1d90a6de9b63135dd80748bbb2b7b27d Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 11 Jun 2026 14:29:49 +0000 Subject: [PATCH 13/30] fix profile --- launch.sh | 3 + .../deepseek4_mem_manager.py | 43 ++++++++++-- lightllm/common/quantization/deepgemm.py | 70 +++++++++++++++++-- 3 files changed, 106 insertions(+), 10 deletions(-) diff --git a/launch.sh b/launch.sh index ccf870dc8c..b9c10d3f0a 100644 --- a/launch.sh +++ b/launch.sh @@ -7,6 +7,9 @@ # LOADWORKER=16 - parallel weight loading (~5x faster startup). # PYTHONPATH sglang - _get_qkv / compressor reuse sglang.jit_kernel.dsv4 (fused_q_norm_rope, compress_old). # --batch_max_tokens 8192 - FlashMLA get_decoding_sched_meta rejects >8192 rows per call (probed: 8192 OK, 12288 fails). +# kv pool sizing: auto-profiled from mem_fraction. The fp4 marlin MoE weights materialize their +# CUDA marlin-layout buffers at construction (MXFP4MoEQuantizationMethod._create_weight), so the +# profile sees the true weight footprint on any GPU/config. --max_total_token_num overrides. # decode cudagraph ENABLED - the v5 decode path is graph-safe: slot alloc/scatter in prep (outside # graph), forward is pure gathers, HOLD padding rows redirect to HOLD slots. 
CORRECTNESS NOTE: # FlashMLASchedMeta is lazily planned at first kernel call and written back onto the (shared) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py index 3df5c0e12e..f87a11704d 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py @@ -34,11 +34,10 @@ # c4 compressor state ring(overlap 对: 每页 2 个分组槽 × ratio 4 行)。c128 state 在 128 边界 # 自然归零(在线聚合),无缓存常驻需求,保持 req 键控,不进 swa 派生池。 DSV4_C4_STATE_RING = 8 -# swa 池占 full token 空间的比例下限(sglang swa_full_tokens_ratio=0.1 的对应物)。 -# lightllm 的调度准入只看 full 池,prefill 优先的波次会让"已 prefill 未 decode"的请求整段 -# prompt 占住 swa 槽(首次 decode prep 才批量出窗回收),峰值≈准入波次 prompt 总和。在 -# v5 的 swa 压力阀/准入耦合落地前,用比 sglang 更宽的 0.3 兜住该瞬时峰值。 -DSV4_SWA_FULL_TOKENS_RATIO = 0.3 +# swa 池占 full token 空间的比例下限(sglang swa_full_tokens_ratio=0.1 同值)。 +# v5 的 swa 压力阀(借页/驱逐)已覆盖 radix 树与准入波次的瞬时增长,结构性预算 +# (max_req×window + batch_max_tokens 余量)另行叠加,0.1 仅作 full 池比例下限。 +DSV4_SWA_FULL_TOKENS_RATIO = 0.1 def _ceil_div(a: int, b: int) -> int: @@ -702,6 +701,40 @@ def pack_mla_kv_to_cache(self, layer_index: int, mem_index: torch.Tensor, kv: to ) return + def pack_mla_kv_to_cache_fused_norm_rope( + self, + layer_index: int, + mem_index: torch.Tensor, + kv: torch.Tensor, + kv_weight: torch.Tensor, + eps: float, + freqs_cis: torch.Tensor, + positions: torch.Tensor, + ): + """同 pack_mla_kv_to_cache,但 rmsnorm + 尾部交错 rope 融合进写入 kernel + (sglang fused_k_norm_rope_flashmla,即 sglang _compute_kv_to_cache 的池侧), + 省掉 bf16 kv 中间量。kv 为 wkv 投影原始输出 [T, head_dim+rope_dim]。""" + if kv.shape[0] == 0: + return + from sglang.jit_kernel.dsv4 import fused_k_norm_rope_flashmla + + swa_slots = self.full_to_swa_indexs[mem_index.cuda().long().reshape(-1)] + # 未映射槽位(-1, 如 decode 图 warmup 的 HOLD 行: prep 跳过 alloc_swa)对老 triton + # 写入核是显式 no-op;sglang fused 核无负槽位防护(负页偏移=非法访存),mask 到 + # swa HOLD 槽(垃圾桶语义,与 padding 行写入一致)。 + 
swa_slots = torch.where(swa_slots < 0, torch.full_like(swa_slots, self.swa_pool.HOLD_TOKEN_MEMINDEX), swa_slots) + fused_k_norm_rope_flashmla( + kv=kv, + kv_weight=kv_weight, + eps=eps, + freqs_cis=freqs_cis, + positions=positions, + out_loc=swa_slots, + kvcache=self.swa_pool.get_layer_buffer(layer_index), + page_size=self.swa_pool.page_size, + ) + return + def pack_compressed_kv_to_cache(self, layer_index: int, slots: torch.Tensor, comp: torch.Tensor): if comp.shape[0] == 0: return diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 677d3b7dd7..bedf22ee95 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -227,18 +227,68 @@ def apply( ) -> torch.Tensor: raise NotImplementedError("marlin-mxfp4w4a16-b32 is only implemented for fused MoE expert weights") + def _probe_marlin_layout(self, size_n: int, size_k: int, dtype: torch.dtype, device_id: int): + """用零输入走一遍真实的 per-expert repack 路径,探出 marlin 终态布局的形状与类型。 + 只调用 finalize 同款的 vllm 函数,不复刻其内部公式,杜绝形状漂移。结果按维度缓存 + (各 MoE 层同维,全程只探两次: w13 一次、w2 一次)。""" + cache_key = (size_n, size_k, dtype) + cache = getattr(self, "_marlin_layout_cache", None) + if cache is None: + cache = self._marlin_layout_cache = {} + if cache_key in cache: + return cache[cache_key] + + import vllm._custom_ops as ops + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + get_marlin_input_dtype, + marlin_permute_scales, + ) + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + mxfp4_marlin_process_scales, + ) + + input_dtype = get_marlin_input_dtype() + is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1 + device = f"cuda:{device_id}" + qweight = torch.zeros((size_n, size_k // 2), dtype=torch.int8, device=device).view(torch.int32).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + 
size_n=size_n, + num_bits=4, + is_a_8bit=is_a_8bit, + ) + scale = torch.zeros((size_k // self.block_size, size_n), dtype=dtype, device=device) + marlin_scale = marlin_permute_scales( + s=scale, size_k=size_k, size_n=size_n, group_size=self.block_size, is_a_8bit=is_a_8bit + ) + marlin_scale = mxfp4_marlin_process_scales(marlin_scale, input_dtype=input_dtype) + layout = ( + (tuple(marlin_qweight.shape), marlin_qweight.dtype), + (tuple(marlin_scale.shape), marlin_scale.dtype), + ) + cache[cache_key] = layout + return layout + def _create_weight( self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims - assert in_dim % 2 == 0, "MXFP4 packed weight requires even input dimension" assert in_dim % self.block_size == 0, "MXFP4 scale dimension must be divisible by block_size" expert_prefix = (num_experts,) if num_experts > 1 else () + # CPU 暂存区: load_hf_weights 灌入原始预打包 MXFP4,finalize 时 repack 进 CUDA 终态。 weight = torch.empty(expert_prefix + (out_dim, in_dim // 2), dtype=torch.int8, device="cpu") weight_scale = torch.empty( expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.float8_e8m0fnu, device="cpu" ) mm_param = WeightPack(weight=weight, weight_scale=weight_scale) + # CUDA 终态(marlin 布局)在构造期物化,使 mem manager 的 profile 看到真实权重占用 + # ("构造即分配、load 只灌数"的框架契约,与其它 quant 方法一致;惰性到 finalize 才 + # 进卡会让空卡 profile 把 kv 池撑到挤爆权重加载)。finalize 时 repack 结果拷入。 + (w_shape, w_dtype), (s_shape, s_dtype) = self._probe_marlin_layout(out_dim, in_dim, dtype, device_id) + mm_param.marlin_weight = torch.empty((num_experts,) + w_shape, dtype=w_dtype, device=f"cuda:{device_id}") + mm_param.marlin_weight_scale = torch.empty((num_experts,) + s_shape, dtype=s_dtype, device=f"cuda:{device_id}") mm_param_list = self._split_weight_pack( mm_param, weight_out_dims=out_dims, @@ -267,13 +317,23 @@ class _MXFP4Layer: w13_scale = 
moe_weight.w13.weight_scale.to(device=device, non_blocking=True).contiguous() w2_scale = moe_weight.w2.weight_scale.to(device=device, non_blocking=True).contiguous() ( - moe_weight.w13.weight, - moe_weight.w2.weight, - moe_weight.w13.weight_scale, - moe_weight.w2.weight_scale, + w13_new, + w2_new, + w13_scale_new, + w2_scale_new, _, _, ) = prepare_moe_mxfp4_layer_for_marlin(layer, w13, w2, w13_scale, w2_scale, None, None) + # repack 结果拷入构造期预分配的 marlin 终态 buffer(与 AWQ marlin 路径同形态), + # CPU 暂存与 repack 临时随引用释放;shape 失配会在 copy_ 处显式报错(探针保证一致)。 + moe_weight.w13.marlin_weight.copy_(w13_new) + moe_weight.w13.marlin_weight_scale.copy_(w13_scale_new) + moe_weight.w2.marlin_weight.copy_(w2_new) + moe_weight.w2.marlin_weight_scale.copy_(w2_scale_new) + moe_weight.w13.weight = moe_weight.w13.marlin_weight + moe_weight.w13.weight_scale = moe_weight.w13.marlin_weight_scale + moe_weight.w2.weight = moe_weight.w2.marlin_weight + moe_weight.w2.weight_scale = moe_weight.w2.marlin_weight_scale def _deepgemm_fp8_nt(a_tuple, b_tuple, out): From c07e38c5979baf3af04ef77773edb7fc90d34f1c Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 12 Jun 2026 02:00:32 +0000 Subject: [PATCH 14/30] support fp8 --- .../meta_weights/fused_moe/impl/deepgemm_impl.py | 2 ++ .../meta_weights/fused_moe/impl/marlin_impl.py | 2 ++ .../meta_weights/fused_moe/impl/triton_impl.py | 3 +++ .../triton_kernel/fused_moe/moe_silu_and_mul.py | 11 ++++++++++- .../layer_infer/transformer_layer_infer.py | 4 +++- .../layer_weights/transformer_layer_weight.py | 10 ++++++++-- 6 files changed, 28 insertions(+), 4 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index 4d4614c007..72acf2430a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ 
b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -76,7 +76,9 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + clamp_limit: Optional[float] = None, ): + assert clamp_limit is None, "EP deepgemm fused MoE does not support clamp_limit yet" output = fused_experts( hidden_states=input_tensor, w13=w13, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 0094b09b1c..1fdfd94d0d 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -30,7 +30,9 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + clamp_limit: Optional[float] = None, ): + assert clamp_limit is None, "awq_marlin fused MoE does not support clamp_limit yet" w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index 09ce88e3fd..8967dda34e 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -94,6 +94,7 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: bool = False, + clamp_limit: Optional[float] = None, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale @@ -111,6 +112,7 @@ def _fused_experts( 
use_fp8_w8a8=use_fp8_w8a8, w1_scale=w13_scale, w2_scale=w2_scale, + limit=clamp_limit, ) return input_tensor @@ -131,6 +133,7 @@ def fused_experts_with_topk( topk_weights=topk_weights, topk_ids=topk_ids, is_prefill=is_prefill, + clamp_limit=clamp_limit, ) def __call__( diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py index 45c7ea73c6..82fc9131c1 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py @@ -24,6 +24,7 @@ def _silu_and_mul_kernel_fast( NEED_MASK: tl.constexpr, layout: tl.constexpr = "blocked", # "blocked" or "interleaved" USE_LIMIT_AND_ALPHA: tl.constexpr = False, + USE_LIMIT_ONLY: tl.constexpr = False, USE_TANH_APPROXIMATE_GELU: tl.constexpr = False, ): stride_input_m = tl.cast(stride_input_m, dtype=tl.int64) @@ -76,6 +77,11 @@ def _silu_and_mul_kernel_fast( mask=mask, ) else: + if USE_LIMIT_ONLY: + # clamped swiglu (DeepSeek-V4 swiglu_limit): clamp 后接标准 silu, + # 无 gpt-oss 的 alpha 缩放与 (up+1)。 + gate = tl.minimum(gate, limit) + up = tl.minimum(tl.maximum(up, -limit), limit) if USE_TANH_APPROXIMATE_GELU: # tanh-approx GELU, matching Gemma's gelu_pytorch_tanh MLP. 
gate_cubed = gate * gate * gate @@ -124,7 +130,8 @@ def silu_and_mul_fwd( ): assert input.is_contiguous() assert output.is_contiguous() - assert (limit is None and alpha is None) or (limit is not None and alpha is not None) + # limit+alpha: gpt-oss 语义 (up+1)*silu(alpha*gate); 仅 limit: clamp 后标准 silu (DeepSeek-V4) + assert alpha is None or limit is not None stride_input_m = input.stride(0) stride_input_n = input.stride(1) @@ -147,6 +154,7 @@ def silu_and_mul_fwd( while triton.cdiv(size_m, BLOCK_M) > 8192: BLOCK_M *= 2 USE_LIMIT_AND_ALPHA = limit is not None and alpha is not None + USE_LIMIT_ONLY = limit is not None and alpha is None grid = ( triton.cdiv(size_n, BLOCK_N), @@ -171,6 +179,7 @@ def silu_and_mul_fwd( num_warps=num_warps, layout=layout, USE_LIMIT_AND_ALPHA=USE_LIMIT_AND_ALPHA, + USE_LIMIT_ONLY=USE_LIMIT_ONLY, USE_TANH_APPROXIMATE_GELU=ffn_use_tanh_approximate_gelu(), ) return diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 761558c95d..613fb58097 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -595,10 +595,12 @@ def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV gw = layer_weight.gate_weight_.mm_param.weight logits = F.linear(x.float(), gw.float()).contiguous() weights, indices = self._select_experts(logits, infer_state, layer_weight) - routed = self._routed_experts(x, weights, indices, layer_weight) + # shared expert 必须先于 routed 计算: fp8 路径 (FuseMoeTriton) 的 fused_experts + # 是 inplace 的,_routed_experts 返回后 x 已被覆盖为 routed 输出。 g = layer_weight.shared_gate_.mm(x).float().clamp(max=self.swiglu_limit) u = layer_weight.shared_up_.mm(x).float().clamp(min=-self.swiglu_limit, max=self.swiglu_limit) shared = layer_weight.shared_down_.mm((F.silu(g) * u).to(x.dtype)) + routed = self._routed_experts(x, weights, indices, 
layer_weight) if self.enable_ep_moe: if self.tp_world_size_ > 1: all_reduce( diff --git a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py index a95299628c..7b2f67e123 100644 --- a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py @@ -16,7 +16,8 @@ class DeepseekV4TransformerLayerWeight(TransformerLayerWeight): """Per-layer weights for DeepSeek-V4-Flash. DS4 does not share DS2/DS3.2's ``model.layers.*.self_attn/mlp`` layout. Its attention is - HC + CSA, and routed experts are checkpointed as MXFP4. + HC + CSA, and routed experts are checkpointed as MXFP4 (fp4 release) or + FP8 block-128 (fp8 release, same layout as the dense fp8 weights). """ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): @@ -306,12 +307,17 @@ def _dequant_in_place(self, weights): scale_renames = self._fp8_scale_renames() # Convert every `.scale` belonging to this layer. Weights are loaded incrementally # per safetensors shard, so the paired weight may live in another shard: - # - routed FP4 experts keep `.scale` as-is (matches marlin-mxfp4w4a16-b32's suffix); + # - routed expert `.scale` follows the fused_moe quant method's weight_scale_suffix: + # MXFP4 consumes `.scale` as-is, FP8 DeepGEMM expects `.weight_scale_inv` (rename only); # - FP8 matmul scales only need renaming for DeepGEMM, no weight required; # - FP8 pairs on no-quant paths (wo_a's ROWBMMWeight) are expanded to bf16, # the only case that truly requires weight and scale in the same shard. 
+ expert_scale_suffix = self.experts_.quant_method.weight_scale_suffix for scale_k in [k for k in list(weights.keys()) if k.startswith(p) and k.endswith(".scale")]: if scale_k.startswith(f"{p}ffn.experts."): + if expert_scale_suffix is not None and expert_scale_suffix != "scale": + weights[scale_k[: -len("scale")] + expert_scale_suffix] = weights[scale_k].to(torch.float32) + del weights[scale_k] continue k = scale_k[: -len(".scale")] + ".weight" target = scale_renames.get(k) From ff717061ec5d664218762a785ad752e923f9f50f Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 12 Jun 2026 03:23:35 +0000 Subject: [PATCH 15/30] optimize --- .../layer_infer/transformer_layer_infer.py | 115 +++++++++--------- 1 file changed, 57 insertions(+), 58 deletions(-) diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 613fb58097..23f742b7d3 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -4,6 +4,7 @@ from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.distributed.communication_op import all_reduce +from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import DeepseekV4TransformerLayerWeight from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor @@ -21,27 +22,26 @@ from ..infer_struct import DeepseekV4InferStateInfo -class DeepseekV4TransformerLayerInfer(TransformerLayerInferTpl): +class DeepseekV4TransformerLayerInfer(Deepseek3_2TransformerLayerInfer): def __init__(self, layer_num, network_config): - super().__init__(layer_num, network_config) - cfg = network_config - self.eps_ 
= cfg["rms_norm_eps"] - self.hidden = cfg["hidden_size"] - self.n_heads = cfg["num_attention_heads"] - self.head_dim = cfg["head_dim"] - self.rope_dim = cfg["qk_rope_head_dim"] - self.index_n_heads = cfg["index_n_heads"] - self.index_head_dim = cfg["index_head_dim"] - self.index_topk = cfg["index_topk"] - self.o_groups = cfg["o_groups"] - self.o_lora = cfg["o_lora_rank"] - self.hc_mult = cfg["hc_mult"] - self.sinkhorn_iters = cfg["hc_sinkhorn_iters"] - self.hc_eps = cfg["hc_eps"] - self.window = cfg["sliding_window"] - self.compress_ratio = cfg["compress_ratios"][layer_num] - self.is_hash = layer_num < cfg["num_hash_layers"] - self.is_last_layer = layer_num == cfg["n_layer"] - 1 + TransformerLayerInferTpl.__init__(self, layer_num, network_config) + self.eps_ = network_config["rms_norm_eps"] + self.embed_dim_ = network_config["hidden_size"] + self.num_heads = network_config["num_attention_heads"] + self.head_dim_ = network_config["head_dim"] + self.qk_rope_head_dim = network_config["qk_rope_head_dim"] + self.qk_nope_head_dim = self.head_dim_ - self.qk_rope_head_dim + self.v_head_dim = self.head_dim_ + self.index_n_heads = network_config["index_n_heads"] + self.index_head_dim = network_config["index_head_dim"] + self.index_topk = network_config["index_topk"] + self.o_groups = network_config["o_groups"] + self.hc_mult = network_config["hc_mult"] + self.sinkhorn_iters = network_config["hc_sinkhorn_iters"] + self.hc_eps = network_config["hc_eps"] + self.compress_ratio = network_config["compress_ratios"][layer_num] + self.is_hash = layer_num < network_config["num_hash_layers"] + self.is_last_layer = layer_num == network_config["n_layer"] - 1 # complex64 rope table for this layer's variant (sliding / compressed); set by # DeepseekV4TpPartModel._init_to_get_rotary once the tables are built. 
The full compress # cos/sin tables (compressor entry rope uses entry positions, not token positions) are @@ -49,19 +49,13 @@ def __init__(self, layer_num, network_config): self.freqs_cis = None self.cos_compress_table = None self.sin_compress_table = None - self.topk = cfg["num_experts_per_tok"] - self.route_scale = cfg["routed_scaling_factor"] - self.swiglu_limit = cfg["swiglu_limit"] - self.softmax_scale = self.head_dim ** -0.5 - self.tp_q_heads = self.n_heads // self.tp_world_size_ - self.tp_index_heads = self.index_n_heads // self.tp_world_size_ + self.num_experts_per_tok = network_config["num_experts_per_tok"] + self.routed_scaling_factor = network_config["routed_scaling_factor"] + self.swiglu_limit = network_config["swiglu_limit"] + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.tp_q_head_num_ = self.num_heads // self.tp_world_size_ + self.tp_index_n_heads = self.index_n_heads // self.tp_world_size_ self.tp_groups = self.o_groups // self.tp_world_size_ - self.tp_q_head_num_ = self.tp_q_heads - self.tp_k_head_num_ = 1 - self.tp_v_head_num_ = 1 - self.tp_o_head_num_ = self.tp_q_heads - self.head_dim_ = self.head_dim - self.embed_dim_ = self.hc_mult * self.hidden self.enable_ep_moe = get_env_start_args().enable_ep_moe self.indexer_score_scale = self.index_head_dim ** -0.5 self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5 @@ -72,7 +66,7 @@ def _hc_attn_in(self, input_embdings, layer_weight: DeepseekV4TransformerLayerWe and runs a standalone hc_pre; later layers get (x, residual, post_mix, res_mix) and fuse the previous layer's ffn hc_post with this layer's attn hc_pre.""" if torch.is_tensor(input_embdings): - residual = input_embdings.view(-1, self.hc_mult, self.hidden) + residual = input_embdings.view(-1, self.hc_mult, self.embed_dim_) return hc_pre( residual, layer_weight.hc_attn_fn_.weight, @@ -149,14 +143,19 @@ def _select_rope(self, infer_state: DeepseekV4InferStateInfo): return 
infer_state.position_cos_compress, infer_state.position_sin_compress return infer_state.position_cos_sliding, infer_state.position_sin_sliding - def _get_qkv(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight): + def _get_qkv( + self, + input: torch.Tensor, + infer_state: DeepseekV4InferStateInfo, + layer_weight: DeepseekV4TransformerLayerWeight, + ): from sglang.jit_kernel.dsv4 import fused_q_norm_rope - x = self._tpsp_allgather(input=x, infer_state=infer_state) + input = self._tpsp_allgather(input=input, infer_state=infer_state) cos_tok, sin_tok = self._select_rope(infer_state) - T = x.shape[0] - qa = layer_weight.q_norm_(layer_weight.wq_a_.mm(x), eps=self.eps_) - q_in = layer_weight.wq_b_.mm(qa).view(T, self.tp_q_heads, self.head_dim) + T = input.shape[0] + qa = layer_weight.q_norm_(layer_weight.wq_a_.mm(input), eps=self.eps_) + q_in = layer_weight.wq_b_.mm(qa).view(T, self.tp_q_head_num_, self.head_dim_) # per-(token, head) weightless self-RMSNorm + interleaved rope on the last rope_dim dims, # fused in one sglang dsv4 jit kernel (fp32 norm/rotation, bf16 in between -- same as eager). 
q = self.alloc_tensor(q_in.shape, dtype=q_in.dtype, device=q_in.device) @@ -167,7 +166,7 @@ def _get_qkv(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: Deeps infer_state.mem_manager.pack_mla_kv_to_cache_fused_norm_rope( layer_index=self.layer_num_, mem_index=infer_state.mem_index, - kv=layer_weight.wkv_.mm(x), + kv=layer_weight.wkv_.mm(input), kv_weight=layer_weight.kv_norm_.weight, eps=self.eps_, freqs_cis=self.freqs_cis, @@ -176,7 +175,7 @@ def _get_qkv(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: Deeps return q, qa, cos_tok, sin_tok def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight): - # o: [T, tp_q_heads, head_dim] after inverse rope -> grouped low-rank O -> [T, hidden] + # o: [T, tp_q_head_num_, head_dim_] after inverse rope -> grouped low-rank O -> [T, embed_dim_] T = o.shape[0] o = o.reshape(T, self.tp_groups, -1).transpose(0, 1).contiguous() # [groups, T, per_group_in] o = layer_weight.wo_a_.bmm(o).transpose(0, 1).reshape(T, -1) # [T, groups*o_lora] @@ -185,7 +184,7 @@ def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, layer_weight: Deepsee def _inv_rope(self, o, cos_tok, sin_tok): # in-place; 单张量路径只需要旋转 rope 切片。 - rotary_emb_fwd(o[..., -self.rope_dim :], None, cos_tok, sin_tok, inverse=True) + rotary_emb_fwd(o[..., -self.qk_rope_head_dim :], None, cos_tok, sin_tok, inverse=True) return o # ------------------------------------------------------------------ compressor / indexer @@ -196,8 +195,8 @@ def _indexer_q_weight( return None, None cos_tok = infer_state.position_cos_compress sin_tok = infer_state.position_sin_compress - idx_q = layer_weight.idx_wq_b_.mm(q_lora).view(x.shape[0], self.tp_index_heads, self.index_head_dim) - rotary_emb_fwd(idx_q[..., -self.rope_dim :], None, cos_tok, sin_tok) + idx_q = layer_weight.idx_wq_b_.mm(q_lora).view(x.shape[0], self.tp_index_n_heads, self.index_head_dim) + rotary_emb_fwd(idx_q[..., -self.qk_rope_head_dim :], None, 
cos_tok, sin_tok) idx_weight = layer_weight.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale return idx_q, idx_weight @@ -231,7 +230,7 @@ def _compressor_weights(self, layer_weight: DeepseekV4TransformerLayerWeight, fo layer_weight.compressor_wgate_.mm_param.weight, layer_weight.compressor_norm_.weight, layer_weight.compressor_ape_.weight, - self.head_dim, + self.head_dim_, ) def _run_compressor_prefill( @@ -286,7 +285,7 @@ def _run_c4_compressor_prefill( wgate, norm, ape, - self.head_dim, + self.head_dim_, self.cos_compress_table, self.sin_compress_table, self.eps_, @@ -337,7 +336,7 @@ def _run_c128_compressor_prefill( norm, ape, self.compress_ratio, - self.head_dim, + self.head_dim_, self.cos_compress_table, self.sin_compress_table, self.eps_, @@ -354,7 +353,7 @@ def _run_c128_compressor_prefill( norm, ape, self.compress_ratio, - self.head_dim, + self.head_dim_, self.cos_compress_table, self.sin_compress_table, self.eps_, @@ -406,7 +405,7 @@ def _run_compressor_decode( wgate, norm, ape, - self.head_dim, + self.head_dim_, self.cos_compress_table, self.sin_compress_table, self.eps_, @@ -424,8 +423,8 @@ def _run_compressor_decode( norm, ape, ratio, - self.head_dim, - self.rope_dim, + self.head_dim_, + self.qk_rope_head_dim, self.cos_compress_table, self.sin_compress_table, self.eps_, @@ -491,7 +490,7 @@ def _context_attention_wrapper_run( infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__() # Same graph-split output handoff as the template, but avoid its dry-run because # DSV4 attention mutates compressor/cache state before returning. 
- o = self.alloc_tensor((q.shape[0], self.tp_q_heads, self.head_dim), dtype=q.dtype, device=q.device) + o = self.alloc_tensor((q.shape[0], self.tp_q_head_num_, self.head_dim_), dtype=q.dtype, device=q.device) _o = tensor_to_no_ref_tensor(o) def att_func(new_infer_state: DeepseekV4InferStateInfo): @@ -516,7 +515,7 @@ def _context_attention_kernel( "flashmla_kvcache": True, "layer_index": self.layer_num_, "compress_ratio": self.compress_ratio, - "head_dim_v": self.head_dim, + "head_dim_v": self.v_head_dim, "softmax_scale": self.softmax_scale, "q_lora": q_lora, "hidden_states": x, @@ -559,7 +558,7 @@ def _token_attention_kernel( "flashmla_kvcache": True, "layer_index": self.layer_num_, "compress_ratio": self.compress_ratio, - "head_dim_v": self.head_dim, + "head_dim_v": self.v_head_dim, "softmax_scale": self.softmax_scale, "q_lora": q_lora, "hidden_states": x, @@ -588,7 +587,7 @@ def _routed_experts(self, x, weights, indices, layer_weight: DeepseekV4Transform ) def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight): - x = x.view(-1, self.hidden) + x = x.view(-1, self.embed_dim_) if not self.enable_ep_moe: x = self._tpsp_allgather(input=x, infer_state=infer_state) @@ -637,16 +636,16 @@ def _select_experts_vllm( else: bias = layer_weight.gate_bias_.weight - weights = self.alloc_tensor((M, self.topk), dtype=torch.float32, device=logits.device) - indices = self.alloc_tensor((M, self.topk), dtype=indices_dtype, device=logits.device) - token_expert_indices = self.alloc_tensor((M, self.topk), dtype=torch.int32, device=logits.device) + weights = self.alloc_tensor((M, self.num_experts_per_tok), dtype=torch.float32, device=logits.device) + indices = self.alloc_tensor((M, self.num_experts_per_tok), dtype=indices_dtype, device=logits.device) + token_expert_indices = self.alloc_tensor((M, self.num_experts_per_tok), dtype=torch.int32, device=logits.device) ops.topk_hash_softplus_sqrt( weights, indices, token_expert_indices, logits, 
True, - self.route_scale, + self.routed_scaling_factor, bias, input_tokens, hash_indices_table, From d7dd6e057acd19053f9ee49b6b983aa401920880 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 12 Jun 2026 03:27:16 +0000 Subject: [PATCH 16/30] fix --- .../layer_infer/transformer_layer_infer.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 23f742b7d3..cf6c18bfaa 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -152,7 +152,6 @@ def _get_qkv( from sglang.jit_kernel.dsv4 import fused_q_norm_rope input = self._tpsp_allgather(input=input, infer_state=infer_state) - cos_tok, sin_tok = self._select_rope(infer_state) T = input.shape[0] qa = layer_weight.q_norm_(layer_weight.wq_a_.mm(input), eps=self.eps_) q_in = layer_weight.wq_b_.mm(qa).view(T, self.tp_q_head_num_, self.head_dim_) @@ -172,21 +171,18 @@ def _get_qkv( freqs_cis=self.freqs_cis, positions=infer_state.position_ids, ) - return q, qa, cos_tok, sin_tok + return q, qa def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight): # o: [T, tp_q_head_num_, head_dim_] after inverse rope -> grouped low-rank O -> [T, embed_dim_] + position_cos, position_sin = self._select_rope(infer_state) + rotary_emb_fwd(o[..., -self.qk_rope_head_dim :], None, position_cos, position_sin, inverse=True) T = o.shape[0] o = o.reshape(T, self.tp_groups, -1).transpose(0, 1).contiguous() # [groups, T, per_group_in] o = layer_weight.wo_a_.bmm(o).transpose(0, 1).reshape(T, -1) # [T, groups*o_lora] o = layer_weight.wo_b_.mm(o) return self._tpsp_reduce(input=o, infer_state=infer_state) - def _inv_rope(self, o, cos_tok, sin_tok): - # in-place; 单张量路径只需要旋转 rope 切片。 - rotary_emb_fwd(o[..., 
-self.qk_rope_head_dim :], None, cos_tok, sin_tok, inverse=True) - return o - # ------------------------------------------------------------------ compressor / indexer def _indexer_q_weight( self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight @@ -468,9 +464,9 @@ def context_attention_forward( # _get_qkv writes the chunk's packed latent into the swa pool (fused kernel) before # attention reads it back via full_to_swa indices (this custom forward bypasses the # tpl _post_cache_kv path). - q, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, layer_weight) + q, q_lora = self._get_qkv(x, infer_state, layer_weight) o = self._context_attention_wrapper_run(q, q_lora, x, infer_state, layer_weight) - return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, layer_weight) + return self._get_o(o, infer_state, layer_weight) def _context_attention_wrapper_run( self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight @@ -543,9 +539,9 @@ def _context_attention_kernel( def token_attention_forward( self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight ): - q, q_lora, cos_tok, sin_tok = self._get_qkv(x, infer_state, layer_weight) + q, q_lora = self._get_qkv(x, infer_state, layer_weight) o = self._token_attention_kernel(q, q_lora, x, infer_state, layer_weight) - return self._get_o(self._inv_rope(o, cos_tok, sin_tok), infer_state, layer_weight) + return self._get_o(o, infer_state, layer_weight) def _token_attention_kernel( self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight From 3a5dcdc1122d555fe30013759b6e6636beb154f6 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Sun, 14 Jun 2026 10:30:50 +0000 Subject: [PATCH 17/30] compress infer --- .../deepseek_v4/layer_infer/compressor.py | 430 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 342 
+++----------- .../layer_weights/transformer_layer_weight.py | 15 +- 3 files changed, 503 insertions(+), 284 deletions(-) diff --git a/lightllm/models/deepseek_v4/layer_infer/compressor.py b/lightllm/models/deepseek_v4/layer_infer/compressor.py index 2256ecd1a9..129a93ae0d 100644 --- a/lightllm/models/deepseek_v4/layer_infer/compressor.py +++ b/lightllm/models/deepseek_v4/layer_infer/compressor.py @@ -1,4 +1,10 @@ +from dataclasses import dataclass +from typing import Optional + import torch +import triton +import triton.language as tl +from triton.language.extra import libdevice _SGLANG_COMPRESS_ERR = None @@ -7,6 +13,430 @@ _FREQ_CIS_CACHE = {} +@dataclass +class CoreCompressorMetadata: + layer_idx: int + compress_ratio: int + out_slots: torch.Tensor + mem_index: torch.Tensor + state_buffer: torch.Tensor + out_buffer: torch.Tensor + out_page_size: int + position_ids: torch.Tensor + b_req_idx: torch.Tensor + b_seq_len: torch.Tensor + b_ready_cache_len: Optional[torch.Tensor] + b_q_start_loc: Optional[torch.Tensor] + req_to_token_indexs: torch.Tensor + full_to_swa_indexs: torch.Tensor + token_to_batch_idx: Optional[torch.Tensor] + kv_score: Optional[torch.Tensor] + is_prefill: bool + + +@triton.jit +def _add_ape_to_kv_score_kernel( + kv_score, + kv_score_stride0, + kv_score_stride1, + ape, + ape_stride0, + positions, + STATE_WIDTH: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + BLOCK: tl.constexpr, +): + token_idx = tl.program_id(0) + offs = tl.arange(0, BLOCK) + mask = offs < STATE_WIDTH + + position = tl.load(positions + token_idx) + ape_row = position % COMPRESS_RATIO + score = tl.load(kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1, mask=mask) + ape_value = tl.load(ape + ape_row * ape_stride0 + offs, mask=mask) + tl.store( + kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1, + score + ape_value, + mask=mask, + ) + return + + +@triton.jit +def _save_partial_states_kernel( + kv_score, + 
kv_score_stride0, + kv_score_stride1, + positions, + token_to_batch_idx, + b_req_idx, + b_seq_len, + mem_index, + full_to_swa, + state_buffer, + STATE_WIDTH: tl.constexpr, + STATE_LAST_DIM: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + IS_C4: tl.constexpr, + IS_PREFILL: tl.constexpr, + SWA_PAGE_SIZE: tl.constexpr, + C4_STATE_RING: tl.constexpr, + BLOCK: tl.constexpr, +): + token_idx = tl.program_id(0) + batch_idx = tl.load(token_to_batch_idx + token_idx) if IS_PREFILL else token_idx + position = tl.load(positions + token_idx) + seq_len = tl.load(b_seq_len + batch_idx) + + if IS_C4: + same_page_next = (position % SWA_PAGE_SIZE) + C4_STATE_RING < SWA_PAGE_SIZE + if same_page_next and position + C4_STATE_RING < seq_len: + return + full_slot = tl.load(mem_index + token_idx).to(tl.int64) + swa_slot = tl.load(full_to_swa + full_slot).to(tl.int64) + if swa_slot < 0: + return + state_row = (swa_slot // SWA_PAGE_SIZE) * C4_STATE_RING + (swa_slot % C4_STATE_RING) + else: + if position + COMPRESS_RATIO < seq_len: + return + req_idx = tl.load(b_req_idx + batch_idx).to(tl.int64) + state_row = req_idx * COMPRESS_RATIO + (position % COMPRESS_RATIO) + + offs = tl.arange(0, BLOCK) + mask = offs < STATE_WIDTH + kv = tl.load(kv_score + token_idx * kv_score_stride0 + offs * kv_score_stride1, mask=mask) + score = tl.load(kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1, mask=mask) + state_base = state_buffer + state_row * STATE_LAST_DIM + tl.store(state_base + offs, kv, mask=mask) + tl.store(state_base + STATE_WIDTH + offs, score, mask=mask) + return + + +@triton.jit +def _fused_compress_norm_rope_insert_kernel( + kv_score, + kv_score_stride0, + kv_score_stride1, + state_buffer, + positions, + token_to_batch_idx, + b_req_idx, + b_seq_len, + b_ready_cache_len, + b_q_start_loc, + req_to_token, + req_to_token_stride0, + full_to_swa, + out_slots, + norm_weight, + rms_eps, + cos_table, + cos_stride0, + sin_table, + sin_stride0, + out_buffer, + HEAD_DIM: 
tl.constexpr, + STATE_WIDTH: tl.constexpr, + STATE_LAST_DIM: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + WINDOW_SIZE: tl.constexpr, + IS_C4: tl.constexpr, + IS_PREFILL: tl.constexpr, + SWA_PAGE_SIZE: tl.constexpr, + C4_STATE_RING: tl.constexpr, + ROPE_HEAD_DIM: tl.constexpr, + FP8_MAX: tl.constexpr, + SCALE_MIN: tl.constexpr, + NOPE_DIM: tl.constexpr, + QUANT_BLOCK: tl.constexpr, + SCALE_BYTES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BYTES_PER_PAGE: tl.constexpr, + BLOCK: tl.constexpr, +): + token_idx = tl.program_id(0) + out_slot = tl.load(out_slots + token_idx).to(tl.int64) + if out_slot < 0: + return + + position = tl.load(positions + token_idx) + if (position + 1) % COMPRESS_RATIO != 0: + return + + batch_idx = tl.load(token_to_batch_idx + token_idx) if IS_PREFILL else token_idx + req_idx = tl.load(b_req_idx + batch_idx).to(tl.int64) + seq_len = tl.load(b_seq_len + batch_idx) + if IS_PREFILL: + ready_len = tl.load(b_ready_cache_len + batch_idx) + q_start = tl.load(b_q_start_loc + batch_idx) + else: + ready_len = position + q_start = token_idx + + token_offsets = tl.arange(0, WINDOW_SIZE) + start = position - WINDOW_SIZE + 1 + gather_pos = start + token_offsets + valid_pos = (gather_pos >= 0) & (gather_pos < seq_len) + use_current = (gather_pos >= ready_len) & valid_pos if IS_PREFILL else gather_pos == position + current_idx = q_start + (gather_pos - ready_len) if IS_PREFILL else token_idx + token_offsets * 0 + + if IS_C4: + full_slot = tl.load( + req_to_token + req_idx * req_to_token_stride0 + gather_pos, + mask=valid_pos & (~use_current), + other=0, + ).to(tl.int64) + swa_slot = tl.load(full_to_swa + full_slot, mask=valid_pos & (~use_current), other=-1).to(tl.int64) + state_row = (swa_slot // SWA_PAGE_SIZE) * C4_STATE_RING + (swa_slot % C4_STATE_RING) + state_valid = valid_pos & (~use_current) & (swa_slot >= 0) + head_offset = tl.where(token_offsets >= COMPRESS_RATIO, HEAD_DIM, 0) + else: + state_row = req_idx * COMPRESS_RATIO + (gather_pos % 
COMPRESS_RATIO) + state_valid = valid_pos & (~use_current) + head_offset = token_offsets * 0 + + offs = tl.arange(0, BLOCK) + dim_mask = offs < HEAD_DIM + current_mask = use_current[:, None] & dim_mask[None, :] + state_mask = state_valid[:, None] & dim_mask[None, :] + + cur_kv = tl.load( + kv_score + current_idx[:, None] * kv_score_stride0 + (head_offset[:, None] + offs[None, :]) * kv_score_stride1, + mask=current_mask, + other=0.0, + ) + cur_score = tl.load( + kv_score + + current_idx[:, None] * kv_score_stride0 + + (STATE_WIDTH + head_offset[:, None] + offs[None, :]) * kv_score_stride1, + mask=current_mask, + other=float("-inf"), + ) + state_kv = tl.load( + state_buffer + state_row[:, None] * STATE_LAST_DIM + head_offset[:, None] + offs[None, :], + mask=state_mask, + other=0.0, + ) + state_score = tl.load( + state_buffer + state_row[:, None] * STATE_LAST_DIM + STATE_WIDTH + head_offset[:, None] + offs[None, :], + mask=state_mask, + other=float("-inf"), + ) + + kv = tl.where(current_mask, cur_kv, state_kv) + score = tl.where(current_mask, cur_score, state_score) + score = tl.softmax(score, dim=0) + compressed_kv = tl.sum(kv * score, axis=0) + + rms_w = tl.load(norm_weight + offs, mask=dim_mask, other=0.0) + variance = tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_DIM + rrms = tl.rsqrt(variance + rms_eps) + normed = compressed_kv * rrms * rms_w + + num_pairs: tl.constexpr = BLOCK // 2 + nope_pairs: tl.constexpr = NOPE_DIM // 2 + pair_2d = tl.reshape(normed, (num_pairs, 2)) + even, odd = tl.split(pair_2d) + pair_idx = tl.arange(0, num_pairs) + rope_pair_local = pair_idx - nope_pairs + is_rope_pair = rope_pair_local >= 0 + cs_idx = tl.maximum(rope_pair_local, 0) + compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO + cos_v = tl.load(cos_table + compressed_pos * cos_stride0 + cs_idx, mask=is_rope_pair, other=1.0) + sin_v = tl.load(sin_table + compressed_pos * sin_stride0 + cs_idx, mask=is_rope_pair, other=0.0) + new_even = even * cos_v - odd * sin_v 
+ new_odd = odd * cos_v + even * sin_v + rotated = tl.interleave(new_even, new_odd) + + page = out_slot // PAGE_SIZE + token_in_page = out_slot % PAGE_SIZE + data_base = page * BYTES_PER_PAGE + token_in_page * (NOPE_DIM + ROPE_HEAD_DIM * 2) + scale_base = page * BYTES_PER_PAGE + PAGE_SIZE * (NOPE_DIM + ROPE_HEAD_DIM * 2) + token_in_page * SCALE_BYTES + + n_quant_blocks: tl.constexpr = BLOCK // QUANT_BLOCK + n_nope_blocks: tl.constexpr = NOPE_DIM // QUANT_BLOCK + quant_input = normed.to(tl.bfloat16).to(tl.float32) + quant_2d = tl.reshape(quant_input, (n_quant_blocks, QUANT_BLOCK)) + abs_2d = tl.abs(quant_2d) + block_absmax = tl.max(abs_2d, axis=1) + scale_exp = tl.ceil(libdevice.log2(tl.maximum(block_absmax / FP8_MAX, SCALE_MIN))).to(tl.int32) + scale = ((scale_exp + 127) << 23).to(tl.float32, bitcast=True) + kv_fp8 = tl.clamp(quant_2d / scale[:, None], -FP8_MAX, FP8_MAX).to(tl.float8e4nv) + kv_u8 = tl.reshape(kv_fp8.to(tl.uint8, bitcast=True), (BLOCK,)) + tl.store(out_buffer + data_base + offs, kv_u8, mask=offs < NOPE_DIM) + + scale_idx = tl.arange(0, SCALE_BYTES) + scale_bytes = tl.where(scale_idx < n_nope_blocks, scale_exp + 127, 0).to(tl.uint8) + tl.store(out_buffer + scale_base + scale_idx, scale_bytes) + + rope_local = offs - NOPE_DIM + rope_mask = (offs >= NOPE_DIM) & dim_mask + rope_ptr = (out_buffer + data_base + NOPE_DIM).to(tl.pointer_type(tl.bfloat16)) + tl.store(rope_ptr + rope_local, rotated.to(tl.bfloat16), mask=rope_mask) + return + + +def prepare_compress_states(*, infer_state, layer_idx: int, compress_ratio: int): + if compress_ratio == 0: + return None + + mem_manager = infer_state.mem_manager + if compress_ratio == 4: + out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)] + state_buffer = mem_manager.get_c4_state_buffer(layer_idx) + out_pool = mem_manager.c4_pool + elif compress_ratio == 128: + out_slots = mem_manager.full_to_c128_indexs[infer_state.mem_index.long().reshape(-1)] + state_buffer = 
infer_state.req_manager.get_compress_state_pool(layer_idx) + out_pool = mem_manager.c128_pool + else: + raise AssertionError(f"invalid DeepSeek-V4 compress ratio {compress_ratio}") + + token_to_batch_idx = infer_state.b_req_idx + if infer_state.is_prefill: + token_to_batch_idx = getattr(infer_state, "_dsv4_token_to_batch_idx", None) + if token_to_batch_idx is None or token_to_batch_idx.numel() != infer_state.position_ids.numel(): + q_lens = (infer_state.b_seq_len - infer_state.b_ready_cache_len).to(torch.long) + batch_idx = torch.arange(infer_state.b_req_idx.shape[0], device=infer_state.b_req_idx.device) + token_to_batch_idx = torch.repeat_interleave(batch_idx, q_lens).to(torch.int32) + infer_state._dsv4_token_to_batch_idx = token_to_batch_idx + + return CoreCompressorMetadata( + layer_idx=layer_idx, + compress_ratio=compress_ratio, + out_slots=out_slots, + mem_index=infer_state.mem_index, + state_buffer=state_buffer, + out_buffer=mem_manager.get_compressed_kv_buffer(layer_idx), + out_page_size=out_pool.page_size, + position_ids=infer_state.position_ids, + b_req_idx=infer_state.b_req_idx, + b_seq_len=infer_state.b_seq_len, + b_ready_cache_len=infer_state.b_ready_cache_len, + b_q_start_loc=infer_state.b_q_start_loc, + req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, + full_to_swa_indexs=mem_manager.full_to_swa_indexs, + token_to_batch_idx=token_to_batch_idx, + kv_score=None, + is_prefill=infer_state.is_prefill, + ) + + +def prepare_partial_states( + *, + kv_score: torch.Tensor, + metadata: Optional[CoreCompressorMetadata], + ape: torch.Tensor, + compress_ratio: int, +): + if metadata is None or kv_score.shape[0] == 0: + return + state_width = kv_score.shape[-1] // 2 + _add_ape_to_kv_score_kernel[(kv_score.shape[0],)]( + kv_score, + kv_score.stride(0), + kv_score.stride(1), + ape, + ape.stride(0), + metadata.position_ids, + STATE_WIDTH=state_width, + COMPRESS_RATIO=compress_ratio, + BLOCK=triton.next_power_of_2(state_width), + num_warps=4, + ) + 
return + + +def fused_compress( + *, + kv_score: torch.Tensor, + metadata: Optional[CoreCompressorMetadata], + norm_weight: torch.Tensor, + ape: torch.Tensor, + eps: float, + head_dim: int, + qk_rope_head_dim: int, + compress_ratio: int, + cos_table: torch.Tensor, + sin_table: torch.Tensor, +): + if metadata is None or kv_score.shape[0] == 0: + return + + state_width = kv_score.shape[-1] // 2 + state_last_dim = metadata.state_buffer.shape[-1] + is_c4 = compress_ratio == 4 + block_state = triton.next_power_of_2(state_width) + block_head = triton.next_power_of_2(head_dim) + + _fused_compress_norm_rope_insert_kernel[(kv_score.shape[0],)]( + kv_score, + kv_score.stride(0), + kv_score.stride(1), + metadata.state_buffer, + metadata.position_ids, + metadata.token_to_batch_idx, + metadata.b_req_idx, + metadata.b_seq_len, + metadata.b_ready_cache_len if metadata.b_ready_cache_len is not None else metadata.b_seq_len, + metadata.b_q_start_loc if metadata.b_q_start_loc is not None else metadata.b_seq_len, + metadata.req_to_token_indexs, + metadata.req_to_token_indexs.stride(0), + metadata.full_to_swa_indexs, + metadata.out_slots, + norm_weight, + eps, + cos_table, + cos_table.stride(0), + sin_table, + sin_table.stride(0), + metadata.out_buffer, + HEAD_DIM=head_dim, + STATE_WIDTH=state_width, + STATE_LAST_DIM=state_last_dim, + COMPRESS_RATIO=compress_ratio, + WINDOW_SIZE=compress_ratio * (2 if is_c4 else 1), + IS_C4=is_c4, + IS_PREFILL=metadata.is_prefill, + SWA_PAGE_SIZE=128, + C4_STATE_RING=8, + ROPE_HEAD_DIM=qk_rope_head_dim, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + SCALE_MIN=1e-4, + NOPE_DIM=head_dim - qk_rope_head_dim, + QUANT_BLOCK=64, + SCALE_BYTES=(head_dim - qk_rope_head_dim) // 64 + 1, + PAGE_SIZE=metadata.out_page_size, + BYTES_PER_PAGE=metadata.out_buffer.shape[-1], + BLOCK=block_head, + num_warps=4, + ) + + _save_partial_states_kernel[(kv_score.shape[0],)]( + kv_score, + kv_score.stride(0), + kv_score.stride(1), + metadata.position_ids, + 
metadata.token_to_batch_idx, + metadata.b_req_idx, + metadata.b_seq_len, + metadata.mem_index, + metadata.full_to_swa_indexs, + metadata.state_buffer, + STATE_WIDTH=state_width, + STATE_LAST_DIM=state_last_dim, + COMPRESS_RATIO=compress_ratio, + IS_C4=is_c4, + IS_PREFILL=metadata.is_prefill, + SWA_PAGE_SIZE=128, + C4_STATE_RING=8, + BLOCK=block_state, + num_warps=4, + ) + return + + def _load_sglang_compressor(): global _SGLANG_COMPRESS_ERR, _SGLANG_COMPRESS_MOD, _SGLANG_LINEAR_BF16_FP32 if _SGLANG_COMPRESS_MOD is not None: diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index cf6c18bfaa..1b5c5f4c3f 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -9,15 +9,9 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor from .hyper_connection import hc_pre, hc_fused_post_pre, hc_post -from .compressor import ( - compressor_prefill_state, - compressor_decode_step_single, - compressor_decode_step_batch, - compressor_paged_prefill, - compressor_paged_decode_batch, - paged_prefill_compress_data, - paged_decode_state_slots, -) +from .compressor import fused_compress as fused_compress_op +from .compressor import prepare_partial_states +from .compressor import prepare_compress_states from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from ..infer_struct import DeepseekV4InferStateInfo @@ -59,6 +53,9 @@ def __init__(self, layer_num, network_config): self.enable_ep_moe = get_env_start_args().enable_ep_moe self.indexer_score_scale = self.index_head_dim ** -0.5 self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5 + self.compressor = CompressorInfer( + layer_idx=self.layer_num_, network_config=self.network_config_, tp_world_size=self.tp_world_size_ + ) # 
------------------------------------------------------------------ forward (HC-threaded) def _hc_attn_in(self, input_embdings, layer_weight: DeepseekV4TransformerLayerWeight): @@ -196,267 +193,6 @@ def _indexer_q_weight( idx_weight = layer_weight.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale return idx_q, idx_weight - def _gather_compress_slots(self, infer_state: DeepseekV4InferStateInfo, req, entry_start, entry_count): - """组末 token 的 full 槽位 -> 压缩槽(条目 [entry_start, entry_start+entry_count))。 - 槽位已由 prep 阶段(prepare_*_compress_slots)分配并 scatter 进 full_to_c4/c128_indexs。""" - ratio = self.compress_ratio - mem = infer_state.mem_manager - mapping = mem.full_to_c4_indexs if ratio == 4 else mem.full_to_c128_indexs - last = entry_start + entry_count - ends = infer_state.req_manager.req_to_token_indexs[req, ratio - 1 : last * ratio : ratio][entry_start:] - return mapping[ends.long()] - - def _write_compressed_kv(self, infer_state: DeepseekV4InferStateInfo, req, entry_start, comp): - slots = self._gather_compress_slots(infer_state, req, entry_start, comp.shape[0]) - if comp.shape[0]: - infer_state.mem_manager.pack_compressed_kv_to_cache(self.layer_num_, slots, comp) - return slots - - def _compressor_weights(self, layer_weight: DeepseekV4TransformerLayerWeight, for_indexer: bool): - if for_indexer: - return ( - layer_weight.idx_cmp_wkv_.mm_param.weight, - layer_weight.idx_cmp_wgate_.mm_param.weight, - layer_weight.idx_cmp_norm_.weight, - layer_weight.idx_cmp_ape_.weight, - self.index_head_dim, - ) - return ( - layer_weight.compressor_wkv_.mm_param.weight, - layer_weight.compressor_wgate_.mm_param.weight, - layer_weight.compressor_norm_.weight, - layer_weight.compressor_ape_.weight, - self.head_dim_, - ) - - def _run_compressor_prefill( - self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight - ): - """Per-request compressor for the prefill chunk. 
Runs as part of the deferred attention - func, before the attention metadata gathers the slot mappings. - - c4: paged state (swa-page-derived group slots, translation #3) — one fused extend-aware - call per request; the (write_loc, extra_data, plan) tuple is layer-independent and cached - on infer_state across all c4 layers. c128: req-keyed state (zero at every 128 boundary by - construction, nothing cache-resident), original jit paths.""" - if not self.compress_ratio: - return - if self.compress_ratio == 4: - self._run_c4_compressor_prefill(x, infer_state, layer_weight) - else: - self._run_c128_compressor_prefill(x, infer_state, layer_weight) - return - - def _run_c4_compressor_prefill( - self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight - ): - rm = infer_state.req_manager - mem = infer_state.mem_manager - wkv, wgate, norm, ape, _ = self._compressor_weights(layer_weight, for_indexer=False) - iwkv, iwgate, inorm, iape, _ = self._compressor_weights(layer_weight, for_indexer=True) - state_buf = mem.get_c4_state_buffer(self.layer_num_) - idx_state_buf = mem.get_c4_indexer_state_buffer(self.layer_num_) - data_cache = getattr(infer_state, "_dsv4_c4_prefill_data", None) - if data_cache is None: - data_cache = {} - infer_state._dsv4_c4_prefill_data = data_cache - b_req = infer_state.b_req_idx.tolist() - starts = infer_state.b_q_start_loc.tolist() - lens = infer_state.b_q_seq_len.tolist() - ready_lens = infer_state.b_ready_cache_len.tolist() - for req, st, ln, ready_len in zip(b_req, starts, lens, ready_lens): - if req == rm.HOLD_REQUEST_ID or ln == 0: - continue - seq_len = ready_len + ln - data = data_cache.get(req) - if data is None: - data = paged_prefill_compress_data( - rm.req_to_token_indexs, mem.full_to_swa_indexs, req, ready_len, seq_len, ring=8 - ) - data_cache[req] = data - x_r = x[st : st + ln] - comp = compressor_paged_prefill( - x_r, - wkv, - wgate, - norm, - ape, - self.head_dim_, - self.cos_compress_table, - 
self.sin_compress_table, - self.eps_, - state_buf, - data, - ready_len, - seq_len, - ) - slots = self._write_compressed_kv(infer_state, req, ready_len // 4, comp) - idx_comp = compressor_paged_prefill( - x_r, - iwkv, - iwgate, - inorm, - iape, - self.index_head_dim, - self.cos_compress_table, - self.sin_compress_table, - self.eps_, - idx_state_buf, - data, - ready_len, - seq_len, - ) - if idx_comp.shape[0]: - infer_state.mem_manager.pack_indexer_k_to_cache(self.layer_num_, slots, idx_comp) - return - - def _run_c128_compressor_prefill( - self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight - ): - rm = infer_state.req_manager - wkv, wgate, norm, ape, _ = self._compressor_weights(layer_weight, for_indexer=False) - b_req = infer_state.b_req_idx.tolist() - starts = infer_state.b_q_start_loc.tolist() - lens = infer_state.b_q_seq_len.tolist() - ready_lens = infer_state.b_ready_cache_len.tolist() - for req, st, ln, ready_len in zip(b_req, starts, lens, ready_lens): - if req == rm.HOLD_REQUEST_ID: - continue - x_r = x[st : st + ln] - state_pool = rm.get_compress_state_pool_for_req(self.layer_num_, req) - if ready_len == 0: - comp = compressor_prefill_state( - x_r, - wkv, - wgate, - norm, - ape, - self.compress_ratio, - self.head_dim_, - self.cos_compress_table, - self.sin_compress_table, - self.eps_, - state_pool, - ) - self._write_compressed_kv(infer_state, req, 0, comp) - else: - for j in range(ln): - start_pos = ready_len + j - entry = compressor_decode_step_single( - x_r[j], - wkv, - wgate, - norm, - ape, - self.compress_ratio, - self.head_dim_, - self.cos_compress_table, - self.sin_compress_table, - self.eps_, - state_pool, - start_pos, - ) - if entry is not None: - entry_start = (start_pos + 1) // self.compress_ratio - 1 - self._write_compressed_kv(infer_state, req, entry_start, entry.unsqueeze(0)) - return - - def _run_compressor_decode( - self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: 
DeepseekV4TransformerLayerWeight - ): - """Batched decode compressor (cuda-graph safe): state update for every request, cache write - masked to the pool HOLD slot unless this token completes a window. Compressed-cache slots - were pre-allocated by prepare_decode_compress_slots in the prep phase. - - c4: paged state — group slots derived from full_to_swa (translation #3) via pure tensor - ops (graph-safe), shared across all c4 layers per step. c128: req-keyed state.""" - if not self.compress_ratio: - return - rm = infer_state.req_manager - mem = infer_state.mem_manager - req = infer_state.b_req_idx - ratio = self.compress_ratio - wkv, wgate, norm, ape, _ = self._compressor_weights(layer_weight, for_indexer=False) - - if ratio == 4: - mapping, hold = mem.full_to_c4_indexs, mem.c4_pool.HOLD_TOKEN_MEMINDEX - slot_meta = getattr(infer_state, "_dsv4_c4_decode_slots", None) - if slot_meta is None: - slot_meta = paged_decode_state_slots( - rm.req_to_token_indexs, - mem.full_to_swa_indexs, - req, - infer_state.b_seq_len, - page_size=128, - ring=8, - ratio=4, - hold_req_id=rm.HOLD_REQUEST_ID, - num_swa_pages=mem.swa_num_pages, - ) - infer_state._dsv4_c4_decode_slots = slot_meta - write_slot, overlap_slot = slot_meta - entry, should = compressor_paged_decode_batch( - x, - wkv, - wgate, - norm, - ape, - self.head_dim_, - self.cos_compress_table, - self.sin_compress_table, - self.eps_, - mem.get_c4_state_buffer(self.layer_num_), - write_slot, - overlap_slot, - infer_state.b_seq_len, - ) - else: - mapping, hold = mem.full_to_c128_indexs, mem.c128_pool.HOLD_TOKEN_MEMINDEX - entry, should = compressor_decode_step_batch( - x, - wkv, - wgate, - norm, - ape, - ratio, - self.head_dim_, - self.qk_rope_head_dim, - self.cos_compress_table, - self.sin_compress_table, - self.eps_, - rm.get_compress_state_pool(self.layer_num_), - req, - infer_state.b_seq_len.long() - 1, - ) - - should = should & (req != rm.HOLD_REQUEST_ID) - # 本步 token 即组末 token(should 为真时),其 full 槽 = mem_index,映射在 prep 已 
scatter。 - slots = mapping[infer_state.mem_index.long()].long() - slots = torch.where(should, slots, torch.full_like(slots, hold)) - mem.pack_compressed_kv_to_cache(self.layer_num_, slots, entry) - - if ratio == 4: - iwkv, iwgate, inorm, iape, _ = self._compressor_weights(layer_weight, for_indexer=True) - idx_entry, idx_should = compressor_paged_decode_batch( - x, - iwkv, - iwgate, - inorm, - iape, - self.index_head_dim, - self.cos_compress_table, - self.sin_compress_table, - self.eps_, - mem.get_c4_indexer_state_buffer(self.layer_num_), - write_slot, - overlap_slot, - infer_state.b_seq_len, - ) - idx_should = idx_should & (req != rm.HOLD_REQUEST_ID) - idx_slots = torch.where(idx_should, slots, torch.full_like(slots, hold)) - mem.pack_indexer_k_to_cache(self.layer_num_, idx_slots, idx_entry) - return - # ------------------------------------------------------------------ attention (prefill) def context_attention_forward( self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight @@ -503,7 +239,8 @@ def att_func(new_infer_state: DeepseekV4InferStateInfo): def _context_attention_kernel( self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight ): - self._run_compressor_prefill(x, infer_state, layer_weight) + self.compressor.prepare_states(x, infer_state, layer_weight) + self.compressor.fused_compress(infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table) idx_q, idx_weight = self._indexer_q_weight(x, q_lora, infer_state, layer_weight) att_control = AttControl( nsa_prefill=True, @@ -546,7 +283,8 @@ def token_attention_forward( def _token_attention_kernel( self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight ): - self._run_compressor_decode(x, infer_state, layer_weight) + self.compressor.prepare_states(x, infer_state, layer_weight) + self.compressor.fused_compress(infer_state, layer_weight, self.cos_compress_table, 
self.sin_compress_table) idx_q, idx_weight = self._indexer_q_weight(x, q_lora, infer_state, layer_weight) att_control = AttControl( nsa_decode=True, @@ -647,3 +385,63 @@ def _select_experts_vllm( hash_indices_table, ) return weights, indices.long() + + +class CompressorInfer: + def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int): + super().__init__() + self.layer_idx_ = layer_idx + self.network_config_ = network_config + self.tp_world_size_ = tp_world_size + self.compress_ratio = network_config["compress_ratios"][layer_idx] + self.head_dim = network_config["head_dim"] + self.index_head_dim = network_config["index_head_dim"] + self.qk_rope_head_dim = network_config["qk_rope_head_dim"] + self.eps = network_config["rms_norm_eps"] + self._metadata = None + + def prepare_states( + self, + x: torch.Tensor, + infer_state: DeepseekV4InferStateInfo, + layer_weight: DeepseekV4TransformerLayerWeight, + ): + self._metadata = prepare_compress_states( + infer_state=infer_state, + layer_idx=self.layer_idx_, + compress_ratio=self.compress_ratio, + ) + if self._metadata is not None: + self._metadata.kv_score = layer_weight.compressor_wkv_gate_.mm(x).float() + prepare_partial_states( + kv_score=self._metadata.kv_score, + metadata=self._metadata, + ape=layer_weight.compressor_ape_.weight, + compress_ratio=self.compress_ratio, + ) + return self._metadata + + def fused_compress( + self, + infer_state: DeepseekV4InferStateInfo, + layer_weight: DeepseekV4TransformerLayerWeight, + cos_table: torch.Tensor, + sin_table: torch.Tensor, + ): + if self.compress_ratio == 0: + return None + metadata = self._metadata + if metadata is None: + raise RuntimeError("DeepSeek-V4 compressor.prepare_states must run before fused_compress") + return fused_compress_op( + kv_score=metadata.kv_score, + metadata=metadata, + norm_weight=layer_weight.compressor_norm_.weight, + ape=layer_weight.compressor_ape_.weight, + eps=self.eps, + head_dim=self.head_dim, + 
qk_rope_head_dim=self.qk_rope_head_dim, + compress_ratio=self.compress_ratio, + cos_table=cos_table, + sin_table=sin_table, + ) diff --git a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py index 7b2f67e123..d42b20f6e5 100644 --- a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py @@ -121,19 +121,10 @@ def _init_compressor(self): coff = 2 if ratio == 4 else 1 # wkv/wgate are bf16 (no scale) and replicated (single KV head). - self.compressor_wkv_ = ROWMMWeight( + self.compressor_wkv_gate_ = ROWMMWeight( in_dim=self.hidden, - out_dims=[coff * head_dim], - weight_names=f"{prefix}.wkv.weight", - data_type=self.data_type_, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) - self.compressor_wgate_ = ROWMMWeight( - in_dim=self.hidden, - out_dims=[coff * head_dim], - weight_names=f"{prefix}.wgate.weight", + out_dims=[coff * head_dim, coff * head_dim], + weight_names=[f"{prefix}.wkv.weight", f"{prefix}.wgate.weight"], data_type=self.data_type_, quant_method=None, tp_rank=0, From d76450f9c92af47857f33f3615cfdfbcc48ef55b Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Sun, 14 Jun 2026 13:27:05 +0000 Subject: [PATCH 18/30] add c128 to mem_manager --- .../deepseek4_mem_manager.py | 60 +++++++--- lightllm/common/req_manager.py | 42 +------ .../deepseek_v4/layer_infer/compressor.py | 104 +++++++++++------- 3 files changed, 116 insertions(+), 90 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py index f87a11704d..8d172ec758 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py @@ -31,9 +31,9 @@ DSV4_SWA_PAGE_SIZE = 128 DSV4_C4_PAGE_SIZE = 64 DSV4_C128_PAGE_SIZE = 2 -# c4 compressor 
state ring(overlap 对: 每页 2 个分组槽 × ratio 4 行)。c128 state 在 128 边界 -# 自然归零(在线聚合),无缓存常驻需求,保持 req 键控,不进 swa 派生池。 +# compressor state ring: c4 overlap 对为每页 2 个分组槽 × ratio 4 行;c128 离线聚合为每页 1 组。 DSV4_C4_STATE_RING = 8 +DSV4_C128_STATE_RING = 128 # swa 池占 full token 空间的比例下限(sglang swa_full_tokens_ratio=0.1 同值)。 # v5 的 swa 压力阀(借页/驱逐)已覆盖 radix 树与准入波次的瞬时增长,结构性预算 # (max_req×window + batch_max_tokens 余量)另行叠加,0.1 仅作 full 池比例下限。 @@ -154,6 +154,7 @@ def __init__( max_request_num: Optional[int] = None, sliding_window: Optional[int] = None, swa_extra_token_num: int = 0, + swa_full_tokens_ratio: float = DSV4_SWA_FULL_TOKENS_RATIO, always_copy=False, mem_fraction=0.9, ): @@ -174,6 +175,7 @@ def __init__( # 活跃窗口(max_request_num * sliding_window)之外的余量: 在途 prefill chunk 的瞬时占用 # (出窗槽位要到下一次 prep 才回收) + radix cache 持有的窗口尾部。 self.swa_extra_token_num = int(swa_extra_token_num) + self.swa_full_tokens_ratio = float(swa_full_tokens_ratio) # 全局层号 -> 各压缩池内的压实层号(同 qwen3next 的层号压实手法) self.layer_to_c4_idx = {} @@ -200,7 +202,7 @@ def _planned_swa_size(self, full_size: int) -> int: if self.max_request_num is None or self.sliding_window is None: return _ceil_div(full_size, DSV4_SWA_PAGE_SIZE) * DSV4_SWA_PAGE_SIZE cap = int(self.max_request_num) * self._swa_per_req_budget() + self.swa_extra_token_num - cap = max(cap, int(full_size * DSV4_SWA_FULL_TOKENS_RATIO)) + cap = max(cap, int(full_size * self.swa_full_tokens_ratio)) cap = max(1, min(full_size, cap)) return _ceil_div(cap, DSV4_SWA_PAGE_SIZE) * DSV4_SWA_PAGE_SIZE @@ -209,18 +211,32 @@ def _slab_bytes_per_slot(page_size: int, data_bytes: int, scale_bytes: int, alig bytes_per_page = _ceil_div(page_size * (data_bytes + scale_bytes), align_bytes) * align_bytes return bytes_per_page / page_size - def _c4_state_bytes_per_swa_slot(self) -> float: - """c4 compressor state(attention + indexer,swa 页派生寻址)摊到每个 swa 槽的字节数。""" - if self.n_c4 == 0: - return 0.0 - per_page = DSV4_C4_STATE_RING * (4 * self.head_dim + 4 * self.indexer_head_dim) * 4 # fp32 - return 
per_page * self.n_c4 / DSV4_SWA_PAGE_SIZE + @staticmethod + def _paged_state_rows(num_swa_pages: int, ring: int, ratio: int) -> int: + rows = num_swa_pages * ring + ring + 1 + return _ceil_div(rows, ratio) * ratio + + @staticmethod + def _init_state_sentinel(buffer: torch.Tensor) -> None: + half = buffer.shape[-1] // 2 + buffer[:, -1, :half].zero_() + buffer[:, -1, half:].fill_(float("-inf")) + return + + def _paged_state_bytes_per_swa_slot(self) -> float: + """c4/c128 compressor state(swa 页派生寻址)摊到每个 swa 槽的字节数。""" + per_page = 0.0 + if self.n_c4 > 0: + per_page += DSV4_C4_STATE_RING * (4 * self.head_dim + 4 * self.indexer_head_dim) * 4 * self.n_c4 + if self.n_c128 > 0: + per_page += DSV4_C128_STATE_RING * (2 * self.head_dim) * 4 * self.n_c128 + return per_page / DSV4_SWA_PAGE_SIZE def _swa_slot_bytes(self) -> float: per_layer = self._slab_bytes_per_slot( DSV4_SWA_PAGE_SIZE, DSV4_MLA_DATA_BYTES_PER_TOKEN, self.mla_scale_bytes, DSV4_MLA_PAGE_ALIGN_BYTES ) - return per_layer * self.layer_num + self._c4_state_bytes_per_swa_slot() + return per_layer * self.layer_num + self._paged_state_bytes_per_swa_slot() def _compressed_cell_size(self) -> float: """每个 full token 摊到压缩池上的精确字节数(按 page-slab 对齐后)。""" @@ -259,10 +275,10 @@ def profile_size(self, mem_fraction): self.size = max(1, int(available_bytes / full_cell)) else: size_budget = max(1, int((available_bytes - swa_slot_bytes * swa_budget) / compressed_cell)) - if size_budget * DSV4_SWA_FULL_TOKENS_RATIO > swa_budget: + if size_budget * self.swa_full_tokens_ratio > swa_budget: # 比例下限生效(_planned_swa_size 会取 ratio*full),按该机制反解 full。 self.size = max( - 1, int(available_bytes / (swa_slot_bytes * DSV4_SWA_FULL_TOKENS_RATIO + compressed_cell)) + 1, int(available_bytes / (swa_slot_bytes * self.swa_full_tokens_ratio + compressed_cell)) ) else: self.size = size_budget @@ -321,6 +337,9 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.c4_allocator: Optional[KvCacheAllocator] = None self.c128_pool: 
Optional[PackedPagePool] = None self.c128_allocator: Optional[KvCacheAllocator] = None + self.c4_state_buffer: Optional[torch.Tensor] = None + self.c4_indexer_state_buffer: Optional[torch.Tensor] = None + self.c128_state_buffer: Optional[torch.Tensor] = None # 压缩槽映射: 键 = 组末 token(位置 (g+1)%ratio==0)的 full 槽位,值 = 压缩池槽位。 # 与 full_to_swa_indexs 同构: radix 持有 full 槽 => 映射行存活,free 级联回收。 self.full_to_c4_indexs: Optional[torch.Tensor] = None @@ -350,8 +369,7 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # 生灭 -> radix 命中零拷贝续算。行数 = 页数*ring + ring(HOLD 页) + 1(哨兵), # 取整到 ratio;末行哨兵 kv=0/score=-inf(KVAndScore.clear 语义),其余行由内核在 # 组起点覆写,无需按页清零。last_dim = 2*coff*head_dim(overlap coff=2)。 - state_rows = self.swa_num_pages * DSV4_C4_STATE_RING + DSV4_C4_STATE_RING + 1 - state_rows = _ceil_div(state_rows, 4) * 4 + state_rows = self._paged_state_rows(self.swa_num_pages, DSV4_C4_STATE_RING, 4) self.c4_state_buffer = torch.zeros( (self.n_c4, state_rows, 4 * self.head_dim), dtype=torch.float32, device="cuda" ) @@ -359,8 +377,7 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): (self.n_c4, state_rows, 4 * self.indexer_head_dim), dtype=torch.float32, device="cuda" ) for buf in (self.c4_state_buffer, self.c4_indexer_state_buffer): - half = buf.shape[-1] // 2 - buf[:, -1, half:].fill_(float("-inf")) + self._init_state_sentinel(buf) if self.n_c128 > 0: self.c128_pool = PackedPagePool( size=self.c128_size, @@ -375,6 +392,13 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): ) self.full_to_c128_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda") self.full_to_c128_indexs[size] = self.c128_pool.HOLD_TOKEN_MEMINDEX + # c128 compressor 在途状态: 与 c4 同样由 full->swa 推导行号,但 ring=128 且无 overlap。 + # last_dim = 2*head_dim;末行是 swa 缺失/出窗时读取的哨兵。 + state_rows = self._paged_state_rows(self.swa_num_pages, DSV4_C128_STATE_RING, 128) + self.c128_state_buffer = torch.zeros( + (self.n_c128, state_rows, 2 * self.head_dim), 
dtype=torch.float32, device="cuda" + ) + self._init_state_sentinel(self.c128_state_buffer) logger.info( f"DeepseekV4MemoryManager pools: full_tokens={size} swa={self.swa_size}({self.swa_num_pages}p) " @@ -410,6 +434,10 @@ def get_c4_indexer_state_buffer(self, layer_index: int) -> torch.Tensor: assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 paged indexer state" return self.c4_indexer_state_buffer[self.layer_to_c4_idx[layer_index]] + def get_c128_state_buffer(self, layer_index: int) -> torch.Tensor: + assert self.compress_rates[layer_index] == 128, "只有 c128(HCA) 层有 paged compressor state" + return self.c128_state_buffer[self.layer_to_c128_idx[layer_index]] + # ------------------------------------------------------------------ swa slot lifecycle def set_swa_pressure_valve(self, valve) -> None: """valve(need_pages): 在页 allocator 不足时尝试腾页(radix 对 ref==0 节点回收 swa)。""" diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 24b39b71ce..a0faccb00c 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -29,8 +29,8 @@ class DeepseekV4PromptCachePayload: """prompt cache 载荷: 只剩 swa 按页有效性 bitmap。 槽位与 compressor 状态都不进载荷: full_to_swa/full_to_c4/full_to_c128 以 full token 槽位 - 为键(radix 持有 full 槽 ⇒ 映射行存活,free 级联回收);c4 compressor 状态以 swa 页派生 - 寻址(随 swa 页生灭,命中零拷贝续算);c128 状态在 128 边界自然归零,无需恢复。 + 为键(radix 持有 full 槽 ⇒ 映射行存活,free 级联回收);c4/c128 compressor 状态以 swa + 页派生寻址(随 swa 页生灭,命中零拷贝续算)。c128 partial state 不跨 radix 的 128 边界保存。 * ``swa_page_valid``: cpu bool [cache_len // page],插入时按当下 full_to_swa 映射写定 (页内 128 个映射全有效才为 True)。匹配层据此把命中裁剪到"结尾页有效"的 128 边界, @@ -368,9 +368,8 @@ class DeepseekV4ReqManager(ReqManager): 本类只负责 prep 阶段的分配与 scatter(``prepare_prefill_compress_slots`` / ``prepare_decode_compress_slots``)——必须先于 attention metadata 构建/图捕获; 条目内容由 layer-infer 的 compressor 前向写入。 - * ``req_to_c128_state_pool`` —— c128 compressor 的在途状态(per req、per c128 层)。 - c128 在线聚合在 128 边界自然归零(命中边界必 128 对齐),无缓存常驻需求,保持 req 键控。 - c4 状态(跨边界 
overlap)在 mem manager 的 swa 页派生池,随页生灭,命中零拷贝续算。 + * compressor 在途状态不在本类: c4/c128 都在 mem manager 的 swa 页派生池, + 随页生灭,命中零拷贝续算。 * SWA 槽位分配/出窗回收(``prepare_prefill_swa`` / ``prepare_decode_swa``): 每步 prep 阶段 为新 token 调 mem_manager.alloc_swa,并按 per-req 水位线(``_swa_evict_marks``)惰性回收 已出窗位置的 swa 槽。水位线首次置为该请求首个 chunk 的 ready_cache_len(radix 共享前缀 @@ -419,16 +418,6 @@ def __init__( self.layer_to_c128_idx[lid] = c128 c128 += 1 - # c128 compressor 在途状态(fp32): 在线聚合在 128 边界自然归零(命中边界必 128 对齐), - # 无缓存常驻需求,保持 req 键控。c4 状态(有跨边界 overlap)在 mem manager 的 - # swa 页派生池(c4_state_buffer / c4_indexer_state_buffer)。 - self.req_to_c128_state_pool = LayerCache( - size=max_request_num + 1, - dtype=torch.float32, - shape=(1, 128, 2 * head_dim), - layer_num=self.n_c128, - device="cuda", - ) return def bind_mem_manager(self, mem_manager: DeepseekV4MemoryManager): @@ -524,20 +513,11 @@ def prepare_decode_swa( self.mem_manager.alloc_swa_decode(b_req_idx, b_seq_len, mem_indexes, self.req_to_token_indexs) return - def _reset_state_pool_req(self, cache: LayerCache, req_idx: int): - if cache.layer_num == 0: - return - cache.buffer[:, req_idx, ...].fill_(0) - return - def init_compress_state(self, req_idx: int): - """新请求开始时重置其 compressor 在途状态(对应 mamba 的 init_linear_att_state)。 + """新请求开始时重置 runtime 水位线(对应 mamba 的 init_linear_att_state 调用点)。 - 只有 c128 状态是 req 键控的(c4 状态随 swa 页生灭,内核组起点覆写,无需重置; - 压缩槽位以 full 槽位为键,随请求 full 槽的释放级联回收)。""" + c4/c128 compressor state 都随 swa 页寻址,由内核按组覆写;请求复用时不做 per-req 清零。""" self.clear_runtime_state(req_idx) - if self.n_c128 > 0: - self._reset_state_pool_req(self.req_to_c128_state_pool, req_idx) return # ------------------------------------------------------------------ compress slot prep (per step) @@ -628,14 +608,6 @@ def clear_runtime_state(self, req_idx: int): self._swa_evict_marks[req_idx] = -1 return - def get_compress_state_pool_for_req(self, layer_index: int, req_idx: int): - assert self.compress_rates[layer_index] == 128, "c4 state 在 mem manager 的 swa 页派生池" - return 
self.req_to_c128_state_pool.buffer[self.layer_to_c128_idx[layer_index], req_idx] - - def get_compress_state_pool(self, layer_index: int): - assert self.compress_rates[layer_index] == 128, "c4 state 在 mem manager 的 swa 页派生池" - return self.req_to_c128_state_pool.buffer[self.layer_to_c128_idx[layer_index]] - def get_prompt_cache_value_ops(self): return DeepseekV4PromptCacheValueOps(self) @@ -712,6 +684,4 @@ def free_req(self, free_req_index: int): def free_all(self): super().free_all() self._swa_evict_marks = [-1 for _ in range(self.max_request_num + 1)] - if self.n_c128 > 0: - self.req_to_c128_state_pool.buffer.fill_(0) return diff --git a/lightllm/models/deepseek_v4/layer_infer/compressor.py b/lightllm/models/deepseek_v4/layer_infer/compressor.py index 129a93ae0d..4857740992 100644 --- a/lightllm/models/deepseek_v4/layer_infer/compressor.py +++ b/lightllm/models/deepseek_v4/layer_infer/compressor.py @@ -6,6 +6,12 @@ import triton.language as tl from triton.language.extra import libdevice +from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import ( + DSV4_C4_STATE_RING, + DSV4_C128_STATE_RING, + DSV4_SWA_PAGE_SIZE, +) + _SGLANG_COMPRESS_ERR = None _SGLANG_COMPRESS_MOD = None @@ -80,7 +86,7 @@ def _save_partial_states_kernel( IS_C4: tl.constexpr, IS_PREFILL: tl.constexpr, SWA_PAGE_SIZE: tl.constexpr, - C4_STATE_RING: tl.constexpr, + STATE_RING: tl.constexpr, BLOCK: tl.constexpr, ): token_idx = tl.program_id(0) @@ -89,19 +95,18 @@ def _save_partial_states_kernel( seq_len = tl.load(b_seq_len + batch_idx) if IS_C4: - same_page_next = (position % SWA_PAGE_SIZE) + C4_STATE_RING < SWA_PAGE_SIZE - if same_page_next and position + C4_STATE_RING < seq_len: - return - full_slot = tl.load(mem_index + token_idx).to(tl.int64) - swa_slot = tl.load(full_to_swa + full_slot).to(tl.int64) - if swa_slot < 0: + same_page_next = (position % SWA_PAGE_SIZE) + STATE_RING < SWA_PAGE_SIZE + if same_page_next and position + STATE_RING < seq_len: return - state_row = (swa_slot // 
SWA_PAGE_SIZE) * C4_STATE_RING + (swa_slot % C4_STATE_RING) else: if position + COMPRESS_RATIO < seq_len: return - req_idx = tl.load(b_req_idx + batch_idx).to(tl.int64) - state_row = req_idx * COMPRESS_RATIO + (position % COMPRESS_RATIO) + + full_slot = tl.load(mem_index + token_idx).to(tl.int64) + swa_slot = tl.load(full_to_swa + full_slot).to(tl.int64) + if swa_slot < 0: + return + state_row = (swa_slot // SWA_PAGE_SIZE) * STATE_RING + (swa_slot % STATE_RING) offs = tl.arange(0, BLOCK) mask = offs < STATE_WIDTH @@ -144,7 +149,7 @@ def _fused_compress_norm_rope_insert_kernel( IS_C4: tl.constexpr, IS_PREFILL: tl.constexpr, SWA_PAGE_SIZE: tl.constexpr, - C4_STATE_RING: tl.constexpr, + STATE_RING: tl.constexpr, ROPE_HEAD_DIM: tl.constexpr, FP8_MAX: tl.constexpr, SCALE_MIN: tl.constexpr, @@ -188,12 +193,18 @@ def _fused_compress_norm_rope_insert_kernel( other=0, ).to(tl.int64) swa_slot = tl.load(full_to_swa + full_slot, mask=valid_pos & (~use_current), other=-1).to(tl.int64) - state_row = (swa_slot // SWA_PAGE_SIZE) * C4_STATE_RING + (swa_slot % C4_STATE_RING) + state_row = (swa_slot // SWA_PAGE_SIZE) * STATE_RING + (swa_slot % STATE_RING) state_valid = valid_pos & (~use_current) & (swa_slot >= 0) head_offset = tl.where(token_offsets >= COMPRESS_RATIO, HEAD_DIM, 0) else: - state_row = req_idx * COMPRESS_RATIO + (gather_pos % COMPRESS_RATIO) - state_valid = valid_pos & (~use_current) + full_slot = tl.load( + req_to_token + req_idx * req_to_token_stride0 + gather_pos, + mask=valid_pos & (~use_current), + other=0, + ).to(tl.int64) + swa_slot = tl.load(full_to_swa + full_slot, mask=valid_pos & (~use_current), other=-1).to(tl.int64) + state_row = (swa_slot // SWA_PAGE_SIZE) * STATE_RING + (swa_slot % STATE_RING) + state_valid = valid_pos & (~use_current) & (swa_slot >= 0) head_offset = token_offsets * 0 offs = tl.arange(0, BLOCK) @@ -288,7 +299,7 @@ def prepare_compress_states(*, infer_state, layer_idx: int, compress_ratio: int) out_pool = mem_manager.c4_pool elif 
compress_ratio == 128: out_slots = mem_manager.full_to_c128_indexs[infer_state.mem_index.long().reshape(-1)] - state_buffer = infer_state.req_manager.get_compress_state_pool(layer_idx) + state_buffer = mem_manager.get_c128_state_buffer(layer_idx) out_pool = mem_manager.c128_pool else: raise AssertionError(f"invalid DeepSeek-V4 compress ratio {compress_ratio}") @@ -367,6 +378,7 @@ def fused_compress( state_width = kv_score.shape[-1] // 2 state_last_dim = metadata.state_buffer.shape[-1] is_c4 = compress_ratio == 4 + state_ring = DSV4_C4_STATE_RING if is_c4 else DSV4_C128_STATE_RING block_state = triton.next_power_of_2(state_width) block_head = triton.next_power_of_2(head_dim) @@ -399,8 +411,8 @@ def fused_compress( WINDOW_SIZE=compress_ratio * (2 if is_c4 else 1), IS_C4=is_c4, IS_PREFILL=metadata.is_prefill, - SWA_PAGE_SIZE=128, - C4_STATE_RING=8, + SWA_PAGE_SIZE=DSV4_SWA_PAGE_SIZE, + STATE_RING=state_ring, ROPE_HEAD_DIM=qk_rope_head_dim, FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, SCALE_MIN=1e-4, @@ -429,8 +441,8 @@ def fused_compress( COMPRESS_RATIO=compress_ratio, IS_C4=is_c4, IS_PREFILL=metadata.is_prefill, - SWA_PAGE_SIZE=128, - C4_STATE_RING=8, + SWA_PAGE_SIZE=DSV4_SWA_PAGE_SIZE, + STATE_RING=state_ring, BLOCK=block_state, num_warps=4, ) @@ -623,7 +635,7 @@ def compressor_decode_step_batch( return out.to(x_new.dtype), should_compress -# ---------------------------------------------------------------------------- paged state (c4) +# ---------------------------------------------------------------------------- paged state # 与 sglang srt compressor 的 paged 路径同构(compress_old 内核 + 分组槽 indices + overlap # extra_data): state 槽位由 swa 槽位算术派生(翻译③ state_loc = page*ring + swa_loc%ring, # 分组槽 = state_loc//ratio),state 随 swa 页生灭,radix 命中零拷贝续算。 @@ -667,35 +679,49 @@ def paged_decode_state_slots( ratio: int, hold_req_id: int, num_swa_pages: int, + overlap: bool = True, ): - """decode 步的 state 分组槽(写槽 = 当前组 clip_down(seq-1) 的槽,overlap 伙伴 = 前一组)。 + """decode 步的 state 分组槽(写槽 = 
当前组 clip_down(seq-1) 的槽,可选 overlap 前一组)。 纯张量算术(prep 已写本步 req_to_token),图安全。padding(HOLD)行重定向到 HOLD 页的 state 槽,隔离其垃圾累加。""" seq = b_seq_len.long() write_positions = torch.div(seq - 1, ratio, rounding_mode="floor") * ratio write_slot = _paged_state_group_slot(req_to_token, full_to_swa, b_req_idx, write_positions, page_size, ring, ratio) - overlap_slot = _paged_state_group_slot( - req_to_token, full_to_swa, b_req_idx, write_positions - ratio, page_size, ring, ratio - ) + overlap_slot = None + if overlap: + overlap_slot = _paged_state_group_slot( + req_to_token, full_to_swa, b_req_idx, write_positions - ratio, page_size, ring, ratio + ) hold_slot = num_swa_pages * ring // ratio # HOLD 页区域([pages*ring, pages*ring+ring))的首个分组槽 is_hold = b_req_idx.long() == hold_req_id write_slot = torch.where(is_hold, torch.full_like(write_slot, hold_slot), write_slot) - overlap_slot = torch.where(is_hold, torch.full_like(overlap_slot, hold_slot), overlap_slot) + if overlap_slot is not None: + overlap_slot = torch.where(is_hold, torch.full_like(overlap_slot, hold_slot), overlap_slot) return write_slot, overlap_slot -def paged_prefill_compress_data(req_to_token, full_to_swa, req_idx: int, ready_len: int, seq_len: int, ring: int): +def paged_prefill_compress_data( + req_to_token, + full_to_swa, + req_idx: int, + ready_len: int, + seq_len: int, + ring: int, + ratio: int = 4, + page_size: int = DSV4_SWA_PAGE_SIZE, + overlap: bool = True, +): """单请求 prefill chunk 的 (indices, extra_data, plan): 与 sglang 同走 - triton_create_paged_compress_data(按请求产出,内核经 plan 逐 token 步进)。仅 c4(overlap)。 + triton_create_paged_compress_data(按请求产出,内核经 plan 逐 token 步进)。 三者都与层无关,同一 forward 内可跨全部 c4 层复用。""" mod, _ = _load_sglang_compressor() fn = _load_paged_compress_data_fn() device = req_to_token.device n_new = seq_len - ready_len write_loc, extra_data = fn( - compress_ratio=4, - is_overlap=True, - swa_page_size=128, + compress_ratio=ratio, + is_overlap=overlap, + swa_page_size=page_size, ring_size=ring, 
req_pool_indices=torch.tensor([req_idx], device=device, dtype=torch.int64), seq_lens=torch.tensor([seq_len], device=device, dtype=torch.int64), @@ -704,7 +730,7 @@ def paged_prefill_compress_data(req_to_token, full_to_swa, req_idx: int, ready_l full_to_swa_index_mapping=full_to_swa, ) plan = mod.CompressorPrefillPlan.generate( - 4, + ratio, n_new, torch.tensor([seq_len], dtype=torch.int64), torch.tensor([n_new], dtype=torch.int64), @@ -727,15 +753,16 @@ def compressor_paged_prefill( compress_data, ready_len, seq_len, + ratio: int = 4, ): - """单请求 prefill/extend chunk(c4 paged): x [n_new, dim] 为位置 [ready, seq) 的 hidden, + """单请求 prefill/extend chunk(paged): x [n_new, dim] 为位置 [ready, seq) 的 hidden, state 写到 swa 派生的分组槽(compress_data 来自 paged_prefill_compress_data,跨层复用)。 - 返回本 chunk 完结组的压缩条目 [seq//4 - ready//4, head_dim](rope 已施加)。""" + 返回本 chunk 完结组的压缩条目 [seq//ratio - ready//ratio, head_dim](rope 已施加)。""" mod, _ = _load_sglang_compressor() - ratio = 4 kv_score = _project_kv_score(x, wkv_w, wgate_w) pool = state_buffer.view(-1, ratio, state_buffer.shape[-1]) write_loc, extra_data, plan = compress_data + kwargs = {"extra_data": extra_data} if extra_data is not None else {} out = mod.compress_forward( pool, kv_score, @@ -744,7 +771,7 @@ def compressor_paged_prefill( plan, head_dim=head_dim, compress_ratio=ratio, - extra_data=extra_data, + **kwargs, ) ncomp = seq_len // ratio - ready_len // ratio if ncomp == 0: @@ -768,15 +795,16 @@ def compressor_paged_decode_batch( write_slot, overlap_slot, b_seq_len, + ratio: int = 4, ): - """批量 decode 一步(c4 paged): state 槽位为 swa 派生分组槽(paged_decode_state_slots, + """批量 decode 一步(paged): state 槽位为 swa 派生分组槽(paged_decode_state_slots, 可跨层复用)。返回 (entries [bs, head_dim], should_compress [bs])。""" mod, _ = _load_sglang_compressor() - ratio = 4 kv_score = _project_kv_score(x_new, wkv_w, wgate_w) pool = state_buffer.view(-1, ratio, state_buffer.shape[-1]) seq_lens = b_seq_len.to(torch.int32).contiguous() plan = mod.CompressorDecodePlan(ratio, 
seq_lens) + kwargs = {"extra_data": overlap_slot.view(-1, 1)} if overlap_slot is not None else {} out = mod.compress_forward( pool, kv_score, @@ -785,7 +813,7 @@ def compressor_paged_decode_batch( plan, head_dim=head_dim, compress_ratio=ratio, - extra_data=overlap_slot.view(-1, 1), + **kwargs, ) should_compress = (seq_lens % ratio) == 0 mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) From 07d230865f598d7d5079c757e7d65457e1d389b2 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 15 Jun 2026 06:37:11 +0000 Subject: [PATCH 19/30] refact --- .../attention/nsa/fp8_flashmla_sparse.py | 279 ++---------- lightllm/models/deepseek_v4/infer_struct.py | 23 + .../deepseek_v4/layer_infer/compressor.py | 429 ++---------------- .../layer_infer/transformer_layer_infer.py | 247 ++++++++-- .../triton_kernel/build_swa_index_dsv4.py | 75 +++ 5 files changed, 371 insertions(+), 682 deletions(-) create mode 100644 lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py index 14b1b3307d..03595c1bce 100644 --- a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py @@ -1,5 +1,4 @@ import dataclasses -import inspect import torch from typing import TYPE_CHECKING, Tuple @@ -10,7 +9,6 @@ from lightllm.common.basemodel.infer_struct import InferStateInfo -FLASHMLA_INDEX_ALIGN = 64 # this flash_mla extra-cache fork only instantiates h_q in {64, 128}; pad TP-split q heads up # to the nearest supported count (zero heads are discarded from the output slice). 
FLASHMLA_SUPPORTED_HEADS = (64, 128) @@ -39,15 +37,6 @@ def _missing_attention_op(feature: str) -> None: ) -def _pad_last_dim(x: torch.Tensor, multiple: int = FLASHMLA_INDEX_ALIGN, value: int = -1) -> torch.Tensor: - pad = (-x.shape[-1]) % multiple - if pad == 0: - return x.contiguous() - out = torch.full((*x.shape[:-1], x.shape[-1] + pad), value, dtype=x.dtype, device=x.device) - out[..., : x.shape[-1]] = x - return out.contiguous() - - def _view_dsv4_flashmla_cache(layer_buffer: torch.Tensor, page_size: int) -> torch.Tensor: from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_MLA_BYTES_PER_TOKEN @@ -55,191 +44,6 @@ def _view_dsv4_flashmla_cache(layer_buffer: torch.Tensor, page_size: int) -> tor return layer_buffer[:, :usable].view(layer_buffer.shape[0], page_size, 1, DSV4_MLA_BYTES_PER_TOKEN) -def _load_flash_mla_with_extra(): - try: - import flash_mla - except Exception as exc: - raise DeepseekV4MissingOperatorError( - "DeepSeek-V4 packed FlashMLA requires the flash_mla package with compiled CUDA extension. " - f"Import failed with: {type(exc).__name__}: {exc}" - ) from exc - - fn = getattr(flash_mla, "flash_mla_with_kvcache", None) - get_mla_metadata = getattr(flash_mla, "get_mla_metadata", None) - missing_symbols = [] - if fn is None: - missing_symbols.append("flash_mla_with_kvcache") - if get_mla_metadata is None: - missing_symbols.append("get_mla_metadata") - if missing_symbols: - raise DeepseekV4MissingOperatorError( - "DeepSeek-V4 requires flash_mla.flash_mla_with_kvcache extra-cache wrapper. " - f"Current module={getattr(flash_mla, '__file__', '')} " - f"is missing symbols {missing_symbols}." 
- ) - - sig = inspect.signature(fn) - required = { - "attn_sink", - "extra_k_cache", - "extra_indices_in_kvcache", - "topk_length", - "extra_topk_length", - } - missing = sorted(required.difference(sig.parameters)) - if missing: - raise DeepseekV4MissingOperatorError( - "DeepSeek-V4 requires flash_mla.flash_mla_with_kvcache with extra-cache arguments. " - f"Current module={getattr(flash_mla, '__file__', '')} is missing {missing}." - ) - return flash_mla - - -def _build_dsv4_repeated_prefill_reqs(infer_state) -> torch.Tensor: - return torch.repeat_interleave(infer_state.b_req_idx, infer_state.b_q_seq_len.long()) - - -def _build_dsv4_prefill_positions(infer_state) -> torch.Tensor: - total = infer_state.total_token_num - infer_state.prefix_total_token_num - token_offsets = torch.arange(total, dtype=torch.int32, device=infer_state.b_q_seq_len.device) - req_ids = torch.repeat_interleave( - torch.arange(infer_state.batch_size, dtype=torch.long, device=infer_state.b_q_seq_len.device), - infer_state.b_q_seq_len.long(), - ) - local_offsets = token_offsets - infer_state.b_q_start_loc[req_ids] - return infer_state.b_ready_cache_len[req_ids] + local_offsets - - -def _build_dsv4_swa_indices( - req_manager, - mem_manager, - req_idx: torch.Tensor, - positions: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - window = int(mem_manager.sliding_window) - offsets = positions[:, None] - torch.arange(window, dtype=positions.dtype, device=positions.device)[None, :] - valid_pos = offsets >= 0 - safe_offsets = offsets.clamp_min(0).long() - full_slots = req_manager.req_to_token_indexs[req_idx.long()[:, None], safe_offsets] - swa_slots = mem_manager.full_to_swa_indexs[full_slots.long()].to(torch.int32) - indices = torch.where(valid_pos, swa_slots, torch.full_like(swa_slots, -1)) - lengths = torch.clamp(positions + 1, min=1, max=window).to(torch.int32) - return _pad_last_dim(indices.to(torch.int32)).unsqueeze(1), lengths.contiguous() - - -def _gather_dsv4_compress_slots( - 
infer_state, - mapping: torch.Tensor, - req_idx: torch.Tensor, - valid: torch.Tensor, - offsets: torch.Tensor, - ratio: int, -) -> torch.Tensor: - """条目 g 的压缩槽 = full_to_c*[req_to_token[req, (g+1)*ratio-1]](组末 token 的 full 槽位)。 - 无效条目(超出因果长度/HOLD 行)用位置 0 安全 gather 后由调用方按 valid 掩掉。""" - end_pos = offsets[None, :] * ratio + (ratio - 1) - safe_pos = torch.where(valid, end_pos, torch.zeros_like(end_pos)) - full_slots = infer_state.req_manager.req_to_token_indexs[req_idx.long()[:, None], safe_pos] - return mapping[full_slots.long()].to(torch.int32) - - -def _build_dsv4_c128_indices( - infer_state, - req_idx: torch.Tensor, - positions: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - raw_lengths = (positions + 1) // 128 - lengths = torch.clamp(raw_lengths, min=1).to(torch.int32) - max_len = max(1, int(infer_state.max_kv_seq_len) // 128) - offsets = torch.arange(max_len, dtype=torch.long, device=positions.device) - valid = offsets[None, :] < raw_lengths[:, None] - slots = _gather_dsv4_compress_slots( - infer_state, infer_state.mem_manager.full_to_c128_indexs, req_idx, valid, offsets, 128 - ) - indices = torch.where(valid, slots, torch.full_like(slots, -1)) - return _pad_last_dim(indices).unsqueeze(1), lengths.contiguous() - - -def _build_dsv4_c4_indices( - infer_state, - layer_index: int, - req_idx: torch.Tensor, - positions: torch.Tensor, - nsa_dict: dict, -) -> Tuple[torch.Tensor, torch.Tensor]: - """c4(CSA) extra indices: causal all-entries when the entry space fits index_topk, - otherwise Lightning-Indexer scored top-k. 
Pure tensor ops (decode runs inside cuda graphs).""" - import torch.distributed as dist - import torch.nn.functional as F - from lightllm.distributed.communication_op import all_reduce - - mem_manager = infer_state.mem_manager - raw_lengths = (positions + 1) // 4 - max_entries = max(1, int(infer_state.max_kv_seq_len) // 4) - index_topk = int(nsa_dict["index_topk"]) - offsets = torch.arange(max_entries, dtype=torch.long, device=positions.device) - valid = offsets[None, :] < raw_lengths[:, None] - slots = _gather_dsv4_compress_slots(infer_state, mem_manager.full_to_c4_indexs, req_idx, valid, offsets, 4) - - if max_entries <= index_topk: - lengths = torch.clamp(raw_lengths, min=1).to(torch.int32) - indices = torch.where(valid, slots, torch.full_like(slots, -1)) - return _pad_last_dim(indices).unsqueeze(1), lengths.contiguous() - - idx_q = nsa_dict["idx_q"] # [T, H, index_head_dim], rope applied - idx_weight = nsa_dict["idx_weight"] # [T, H] fp32, weight scale applied - score_scale = float(nsa_dict["indexer_score_scale"]) - hold_slot = mem_manager.c4_indexer_pool.HOLD_TOKEN_MEMINDEX - safe_slots = torch.where(valid, slots.long(), torch.full_like(slots.long(), hold_slot)) - k = mem_manager.gather_indexer_k(layer_index, safe_slots.reshape(-1)).view(positions.shape[0], max_entries, -1) - - num_tokens, num_heads = idx_q.shape[0], idx_q.shape[1] - score_chunks = [] - chunk = max(1, min(num_tokens, (16 * 1024 * 1024) // max(1, num_heads * max_entries))) - for start in range(0, num_tokens, chunk): - end = min(num_tokens, start + chunk) - scores = torch.einsum("thd,tnd->thn", idx_q[start:end].float(), k[start:end].float()) - scores = F.relu(scores) * score_scale - score_chunks.append((scores * idx_weight[start:end].unsqueeze(-1)).sum(dim=1)) - index_scores = torch.cat(score_chunks, dim=0) - if int(nsa_dict.get("tp_world_size", 1)) > 1: - all_reduce(index_scores, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - index_scores = 
index_scores.masked_fill(~valid, float("-inf")) - top = index_scores.topk(index_topk, dim=-1).indices - top_valid = torch.gather(valid, 1, top) - top_slots = torch.gather(slots.long(), 1, top).to(torch.int32) - indices = torch.where(top_valid, top_slots, torch.full_like(top_slots, -1)) - lengths = torch.clamp(torch.minimum(raw_lengths, torch.full_like(raw_lengths, index_topk)), min=1) - return _pad_last_dim(indices).unsqueeze(1), lengths.to(torch.int32).contiguous() - - -def _build_dsv4_extra_metadata( - infer_state, - layer_index: int, - compress_ratio: int, - req_idx: torch.Tensor, - positions: torch.Tensor, - swa_indices: torch.Tensor, - swa_lengths: torch.Tensor, - nsa_dict: dict, -) -> "_Dsv4Metadata": - from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_C128_PAGE_SIZE, DSV4_C4_PAGE_SIZE - - if compress_ratio == 0: - return _Dsv4Metadata(swa_indices, swa_lengths) - if compress_ratio == 4: - extra_indices, extra_lengths = _build_dsv4_c4_indices(infer_state, layer_index, req_idx, positions, nsa_dict) - extra_buffer = infer_state.mem_manager.get_compressed_kv_buffer(layer_index) - extra_cache = _view_dsv4_flashmla_cache(extra_buffer, DSV4_C4_PAGE_SIZE) - return _Dsv4Metadata(swa_indices, swa_lengths, extra_cache, extra_indices, extra_lengths) - if compress_ratio == 128: - extra_indices, extra_lengths = _build_dsv4_c128_indices(infer_state, req_idx, positions) - extra_buffer = infer_state.mem_manager.get_compressed_kv_buffer(layer_index) - extra_cache = _view_dsv4_flashmla_cache(extra_buffer, DSV4_C128_PAGE_SIZE) - return _Dsv4Metadata(swa_indices, swa_lengths, extra_cache, extra_indices, extra_lengths) - raise AssertionError(f"invalid DeepSeek-V4 compress ratio {compress_ratio}") - - @dataclasses.dataclass class _Dsv4Metadata: swa_indices: torch.Tensor @@ -249,6 +53,28 @@ class _Dsv4Metadata: extra_lengths: torch.Tensor = None +def _metadata_from_dict(infer_state, nsa_dict: dict) -> "_Dsv4Metadata": + """Bundle the model-built FINAL index 
tensors (carried in nsa_dict by DeepseekV4IndexInfer) with + the layer-keyed fp8 extra-cache byte view. The cache view is data-independent (a fixed per-layer + buffer slice), so it is built here -- a genuine flash_mla ABI concern -- rather than on the model + side; only the index/length tensors cross the att_control boundary.""" + from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_C128_PAGE_SIZE, DSV4_C4_PAGE_SIZE + + ratio = nsa_dict["compress_ratio"] + extra_cache = None + if ratio: + page = DSV4_C4_PAGE_SIZE if ratio == 4 else DSV4_C128_PAGE_SIZE + extra_buffer = infer_state.mem_manager.get_compressed_kv_buffer(nsa_dict["layer_index"]) + extra_cache = _view_dsv4_flashmla_cache(extra_buffer, page) + return _Dsv4Metadata( + swa_indices=nsa_dict["swa_indices"], + swa_lengths=nsa_dict["swa_lengths"], + extra_cache=extra_cache, + extra_indices=nsa_dict.get("extra_indices"), + extra_lengths=nsa_dict.get("extra_lengths"), + ) + + class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend): def __init__(self, model): super().__init__(model=model) @@ -257,7 +83,6 @@ def __init__(self, model): torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device) for _ in range(2) ] - self._flash_mla = None def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState": return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state) @@ -265,11 +90,6 @@ def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMl def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState": return NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state) - def flash_mla(self): - if self._flash_mla is None: - self._flash_mla = _load_flash_mla_with_extra() - return self._flash_mla - @dataclasses.dataclass class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState): @@ -360,30 +180,9 @@ def 
_nsa_prefill_att( ) return mla_out - def _build_flashmla_kvcache_prefill_metadata(self, nsa_dict: dict) -> _Dsv4Metadata: - infer_state = self.infer_state - req_idx = _build_dsv4_repeated_prefill_reqs(infer_state) - positions = _build_dsv4_prefill_positions(infer_state) - swa_indices, swa_lengths = _build_dsv4_swa_indices( - infer_state.req_manager, - infer_state.mem_manager, - req_idx, - positions, - ) - return _build_dsv4_extra_metadata( - infer_state, - nsa_dict["layer_index"], - nsa_dict["compress_ratio"], - req_idx, - positions, - swa_indices, - swa_lengths, - nsa_dict, - ) - def _flashmla_kvcache_prefill_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor: attn_sink = nsa_dict["attn_sink"].to(torch.float32).contiguous() - metadata = self._build_flashmla_kvcache_prefill_metadata(nsa_dict) + metadata = _metadata_from_dict(self.infer_state, nsa_dict) return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict) def _flashmla_kvcache_att( @@ -394,7 +193,7 @@ def _flashmla_kvcache_att( attn_sink: torch.Tensor, nsa_dict: dict, ) -> torch.Tensor: - flash_mla = self.backend.flash_mla() + import flash_mla from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE q_4d = q.unsqueeze(1).contiguous() @@ -458,7 +257,8 @@ def init_state(self): ragged_mem_index=self.ragged_mem_index, hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, ) - flash_mla = self.backend.flash_mla() + import flash_mla + # one sched_meta per layer type: the lazy config locks extra-cache geometry (page size, # presence) on first invocation, so swa-only/c4/c128 layers must not share one object. 
self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)} @@ -468,7 +268,8 @@ def reset_sched_meta_for_capture(self): # cuda-graph capture hook: the warmup pass already locked/stored sched meta on this # (shared) state object; reset so the capture pass re-plans INSIDE the graph and every # replay re-plans from the live tensors instead of binding warmup leftovers. - flash_mla = self.backend.flash_mla() + import flash_mla + self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)} return @@ -539,29 +340,9 @@ def _nsa_decode_att( ) return o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d] - def _build_flashmla_kvcache_decode_metadata(self, nsa_dict: dict) -> _Dsv4Metadata: - infer_state = self.infer_state - positions = infer_state.b_seq_len.to(torch.int32) - 1 - swa_indices, swa_lengths = _build_dsv4_swa_indices( - infer_state.req_manager, - infer_state.mem_manager, - infer_state.b_req_idx, - positions, - ) - return _build_dsv4_extra_metadata( - infer_state, - nsa_dict["layer_index"], - nsa_dict["compress_ratio"], - infer_state.b_req_idx, - positions, - swa_indices, - swa_lengths, - nsa_dict, - ) - def _flashmla_kvcache_decode_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor: attn_sink = nsa_dict["attn_sink"].to(torch.float32).contiguous() - metadata = self._build_flashmla_kvcache_decode_metadata(nsa_dict) + metadata = _metadata_from_dict(self.infer_state, nsa_dict) return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict) def _flashmla_kvcache_att( @@ -572,7 +353,7 @@ def _flashmla_kvcache_att( attn_sink: torch.Tensor, nsa_dict: dict, ) -> torch.Tensor: - flash_mla = self.backend.flash_mla() + import flash_mla from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE q_4d = q.unsqueeze(1).contiguous() diff --git a/lightllm/models/deepseek_v4/infer_struct.py b/lightllm/models/deepseek_v4/infer_struct.py index 
caf8ca6fa8..ca2ac83b03 100644 --- a/lightllm/models/deepseek_v4/infer_struct.py +++ b/lightllm/models/deepseek_v4/infer_struct.py @@ -18,6 +18,11 @@ def __init__(self): self.position_sin_sliding = None self.position_cos_compress = None self.position_sin_compress = None + # layer-independent sparse-index metadata, built once per forward in init_some_extra_state + # (None until then so copy_for_cuda_graph's tensor-attr loop skips them). + self.dsv4_sparse_req_idx = None + self.dsv4_swa_indices = None + self.dsv4_swa_lengths = None def init_some_extra_state(self, model): super().init_some_extra_state(model) # sets position_ids, b_q_seq_len, b_q_start_loc (prefill) @@ -26,6 +31,24 @@ def init_some_extra_state(self, model): self.position_sin_sliding = torch.index_select(model._sin_cached_sliding, 0, pos) self.position_cos_compress = torch.index_select(model._cos_cached_compress, 0, pos) self.position_sin_compress = torch.index_select(model._sin_cached_compress, 0, pos) + # Per-token request id (decode: one token per req; prefill: ragged -> repeat by q-len). + # Layer-independent; the swa kernel + build_metadata's c4/c128 readers all reuse it. + if self.is_prefill: + self.dsv4_sparse_req_idx = torch.repeat_interleave(self.b_req_idx, self.b_q_seq_len.long()) + else: + self.dsv4_sparse_req_idx = self.b_req_idx + # Sliding-window indices: layer-independent (full_to_swa is global, window is const), so build + # once here via one fused kernel instead of recomputing per layer. const [T, window] shape is + # cuda-graph-safe (no max_kv_seq_len dependence) and auto-staged by copy_for_cuda_graph. 
+ from lightllm.models.deepseek_v4.triton_kernel.build_swa_index_dsv4 import build_swa_index + + self.dsv4_swa_indices, self.dsv4_swa_lengths = build_swa_index( + req_idx=self.dsv4_sparse_req_idx, + positions=self.position_ids, + req_to_token_indexs=self.req_manager.req_to_token_indexs, + full_to_swa_indexs=self.mem_manager.full_to_swa_indexs, + window=int(self.mem_manager.sliding_window), + ) # prefill-cudagraph 桶填充的 HOLD 尾请求的 q 行数。其注意力读 HOLD 槽位(内容被并发写 # 竞争,每轮不同),输出必须清零,否则 pad 行 hidden 不确定 -> MoE 路由抖动 -> 共享 expert # 批次组成变化 -> 真实行 GEMM 归约顺序变化(ulp 级),44 层放大后翻转低置信 token。 diff --git a/lightllm/models/deepseek_v4/layer_infer/compressor.py b/lightllm/models/deepseek_v4/layer_infer/compressor.py index 4857740992..c66fd22de5 100644 --- a/lightllm/models/deepseek_v4/layer_infer/compressor.py +++ b/lightllm/models/deepseek_v4/layer_infer/compressor.py @@ -6,6 +6,7 @@ import triton.language as tl from triton.language.extra import libdevice +from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import ( DSV4_C4_STATE_RING, DSV4_C128_STATE_RING, @@ -13,12 +14,6 @@ ) -_SGLANG_COMPRESS_ERR = None -_SGLANG_COMPRESS_MOD = None -_SGLANG_LINEAR_BF16_FP32 = None -_FREQ_CIS_CACHE = {} - - @dataclass class CoreCompressorMetadata: layer_idx: int @@ -159,6 +154,7 @@ def _fused_compress_norm_rope_insert_kernel( PAGE_SIZE: tl.constexpr, BYTES_PER_PAGE: tl.constexpr, BLOCK: tl.constexpr, + OUTPUT_BF16: tl.constexpr, ): token_idx = tl.program_id(0) out_slot = tl.load(out_slots + token_idx).to(tl.int64) @@ -260,6 +256,13 @@ def _fused_compress_norm_rope_insert_kernel( new_odd = odd * cos_v + even * sin_v rotated = tl.interleave(new_even, new_odd) + if OUTPUT_BF16: + # indexer-K path: emit the post-rope full HEAD_DIM vector as dense bf16 (token-indexed), + # leaving the fp8 single-amax pack to destindex_copy_indexer_k_dsv4 (the c4_indexer_pool + # ABI differs from the latent slab: whole-vector fp8 + one fp32 
scale, no bf16 rope tail). + tl.store(out_buffer + token_idx * HEAD_DIM + offs, rotated.to(tl.bfloat16), mask=dim_mask) + return + page = out_slot // PAGE_SIZE token_in_page = out_slot % PAGE_SIZE data_base = page * BYTES_PER_PAGE + token_in_page * (NOPE_DIM + ROPE_HEAD_DIM * 2) @@ -288,21 +291,38 @@ def _fused_compress_norm_rope_insert_kernel( return -def prepare_compress_states(*, infer_state, layer_idx: int, compress_ratio: int): +def prepare_compress_states(*, infer_state, layer_idx: int, compress_ratio: int, is_in_indexer: bool = False): if compress_ratio == 0: return None - mem_manager = infer_state.mem_manager - if compress_ratio == 4: + mem_manager: DeepseekV4MemoryManager = infer_state.mem_manager + if is_in_indexer: + # c4 Lightning-Indexer key compression: same window/state machinery as the c4 latent + # compressor but with index_head_dim, a separate state pool, and a DENSE bf16 scratch + # out_buffer (the kernel's OUTPUT_BF16 path); the fp8 pack into c4_indexer_pool is done + # afterwards by pack_indexer_k_to_cache. 
+ assert compress_ratio == 4, "只有 c4(CSA) 层有 indexer-K" out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)] - state_buffer = mem_manager.get_c4_state_buffer(layer_idx) - out_pool = mem_manager.c4_pool - elif compress_ratio == 128: - out_slots = mem_manager.full_to_c128_indexs[infer_state.mem_index.long().reshape(-1)] - state_buffer = mem_manager.get_c128_state_buffer(layer_idx) - out_pool = mem_manager.c128_pool + state_buffer = mem_manager.get_c4_indexer_state_buffer(layer_idx) + out_buffer = torch.empty( + (infer_state.mem_index.numel(), mem_manager.indexer_head_dim), + dtype=torch.bfloat16, + device=infer_state.mem_index.device, + ) + out_page_size = 1 # unused under OUTPUT_BF16 (token-indexed dense scratch, not paged) else: - raise AssertionError(f"invalid DeepSeek-V4 compress ratio {compress_ratio}") + if compress_ratio == 4: + out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)] + state_buffer = mem_manager.get_c4_state_buffer(layer_idx) + out_pool = mem_manager.c4_pool + elif compress_ratio == 128: + out_slots = mem_manager.full_to_c128_indexs[infer_state.mem_index.long().reshape(-1)] + state_buffer = mem_manager.get_c128_state_buffer(layer_idx) + out_pool = mem_manager.c128_pool + else: + raise AssertionError(f"invalid DeepSeek-V4 compress ratio {compress_ratio}") + out_buffer = mem_manager.get_compressed_kv_buffer(layer_idx) + out_page_size = out_pool.page_size token_to_batch_idx = infer_state.b_req_idx if infer_state.is_prefill: @@ -319,8 +339,8 @@ def prepare_compress_states(*, infer_state, layer_idx: int, compress_ratio: int) out_slots=out_slots, mem_index=infer_state.mem_index, state_buffer=state_buffer, - out_buffer=mem_manager.get_compressed_kv_buffer(layer_idx), - out_page_size=out_pool.page_size, + out_buffer=out_buffer, + out_page_size=out_page_size, position_ids=infer_state.position_ids, b_req_idx=infer_state.b_req_idx, b_seq_len=infer_state.b_seq_len, @@ -371,6 +391,7 @@ def 
fused_compress( compress_ratio: int, cos_table: torch.Tensor, sin_table: torch.Tensor, + output_bf16: bool = False, ): if metadata is None or kv_score.shape[0] == 0: return @@ -422,6 +443,7 @@ def fused_compress( PAGE_SIZE=metadata.out_page_size, BYTES_PER_PAGE=metadata.out_buffer.shape[-1], BLOCK=block_head, + OUTPUT_BF16=output_bf16, num_warps=4, ) @@ -447,374 +469,3 @@ def fused_compress( num_warps=4, ) return - - -def _load_sglang_compressor(): - global _SGLANG_COMPRESS_ERR, _SGLANG_COMPRESS_MOD, _SGLANG_LINEAR_BF16_FP32 - if _SGLANG_COMPRESS_MOD is not None: - return _SGLANG_COMPRESS_MOD, _SGLANG_LINEAR_BF16_FP32 - if _SGLANG_COMPRESS_ERR is not None: - raise _SGLANG_COMPRESS_ERR - try: - from sglang.jit_kernel.dsv4 import linear_bf16_fp32 - from sglang.jit_kernel.dsv4 import compress_old as compress_mod - except Exception as exc: - _SGLANG_COMPRESS_ERR = RuntimeError( - "DeepSeek-V4 fused compressor requires sglang.jit_kernel.dsv4 " - "(linear_bf16_fp32 + compress_old). Install/export the SGLang package " - "or vendor the DSv4 compressor JIT into LightLLM." 
- ) - raise _SGLANG_COMPRESS_ERR from exc - _SGLANG_COMPRESS_MOD = compress_mod - _SGLANG_LINEAR_BF16_FP32 = linear_bf16_fp32 - return compress_mod, linear_bf16_fp32 - - -def _load_paged_compress_data_fn(): - from sglang.jit_kernel.dsv4 import triton_create_paged_compress_data - - return triton_create_paged_compress_data - - -def _freq_cis(cos_table, sin_table): - key = ( - cos_table.data_ptr(), - sin_table.data_ptr(), - cos_table.device, - tuple(cos_table.shape), - tuple(sin_table.shape), - ) - cached = _FREQ_CIS_CACHE.get(key) - if cached is None: - cached = torch.complex(cos_table.float(), sin_table.float()) - _FREQ_CIS_CACHE[key] = cached - return cached - - -def _sglang_ape(ape, ratio, head_dim): - if ratio == 4: - return torch.cat([ape[:, :head_dim], ape[:, head_dim:]], dim=0).contiguous() - return ape.contiguous() - - -def _compressor_weight(wkv_w, wgate_w): - return torch.cat([wkv_w, wgate_w], dim=0).contiguous() - - -def _project_kv_score(x, wkv_w, wgate_w): - _, linear_bf16_fp32 = _load_sglang_compressor() - return linear_bf16_fp32(x, _compressor_weight(wkv_w, wgate_w)) - - -def _state_pool_view(state_pool): - if state_pool is None: - raise RuntimeError("DeepSeek-V4 fused compressor requires a persistent state_pool") - if state_pool.dim() == 4 and state_pool.shape[1] == 1: - return state_pool.squeeze(1) - return state_pool - - -def compressor_prefill_state( - x, - wkv_w, - wgate_w, - norm_w, - ape, - ratio, - head_dim, - cos_table, - sin_table, - eps, - state_pool, -): - """start_pos==0 prefill for ONE request: x [s, dim] -> compressed entries [s//ratio, head_dim] - (rope applied). 
state_pool is the request's persistent jit state slice [1, slots, coff*2*head_dim]; - it is rebuilt in place so the decode path can continue from the trailing partial window.""" - mod, _ = _load_sglang_compressor() - kv_score = _project_kv_score(x, wkv_w, wgate_w) - pool = _state_pool_view(state_pool) - pool.zero_() - seq_len = x.shape[0] - plan = mod.CompressorPrefillPlan.generate( - ratio, - seq_len, - torch.tensor([seq_len], dtype=torch.int64), - torch.tensor([seq_len], dtype=torch.int64), - x.device, - ) - indices = torch.zeros((1,), device=x.device, dtype=torch.int32) - out = mod.compress_forward( - pool, - kv_score, - _sglang_ape(ape.float(), ratio, head_dim), - indices, - plan, - head_dim=head_dim, - compress_ratio=ratio, - ) - ncomp = seq_len // ratio - if ncomp == 0: - return x.new_zeros(0, head_dim) - mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) - ragged_ids = plan.compress_plan.view(torch.int32)[:ncomp, 0].long() - return out.index_select(0, ragged_ids).to(x.dtype) - - -def compressor_decode_step_single( - x_new, - wkv_w, - wgate_w, - norm_w, - ape, - ratio, - head_dim, - cos_table, - sin_table, - eps, - state_pool, - start_pos, -): - """One token for ONE request (chunked-prefill extend path). Returns the finished compressed - entry [head_dim] when (start_pos+1) % ratio == 0, else None. 
Mutates state_pool in place.""" - mod, _ = _load_sglang_compressor() - kv_score = _project_kv_score(x_new.view(1, -1), wkv_w, wgate_w) - pool = _state_pool_view(state_pool) - seq_len = start_pos + 1 - plan = mod.CompressorDecodePlan( - ratio, - torch.tensor([seq_len], device=x_new.device, dtype=torch.int32), - ) - indices = torch.zeros((1,), device=x_new.device, dtype=torch.int32) - out = mod.compress_forward( - pool, - kv_score, - _sglang_ape(ape.float(), ratio, head_dim), - indices, - plan, - head_dim=head_dim, - compress_ratio=ratio, - ) - if seq_len % ratio != 0: - return None - mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) - return out[0].to(x_new.dtype) - - -def compressor_decode_step_batch( - x_new, - wkv_w, - wgate_w, - norm_w, - ape, - ratio, - head_dim, - rope_dim, - cos_table, - sin_table, - eps, - state_pool, - b_req_idx, - start_pos, -): - mod, _ = _load_sglang_compressor() - kv_score = _project_kv_score(x_new, wkv_w, wgate_w) - pool = _state_pool_view(state_pool) - seq_lens = (start_pos + 1).to(torch.int32).contiguous() - plan = mod.CompressorDecodePlan(ratio, seq_lens) - out = mod.compress_forward( - pool, - kv_score, - _sglang_ape(ape.float(), ratio, head_dim), - b_req_idx.to(torch.int32).contiguous(), - plan, - head_dim=head_dim, - compress_ratio=ratio, - ) - should_compress = (seq_lens % ratio) == 0 - mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) - return out.to(x_new.dtype), should_compress - - -# ---------------------------------------------------------------------------- paged state -# 与 sglang srt compressor 的 paged 路径同构(compress_old 内核 + 分组槽 indices + overlap -# extra_data): state 槽位由 swa 槽位算术派生(翻译③ state_loc = page*ring + swa_loc%ring, -# 分组槽 = state_loc//ratio),state 随 swa 页生灭,radix 命中零拷贝续算。 - - -def paged_state_rows(num_swa_pages: int, ring: int, ratio: int) -> int: - """state 池行数 = 页数*ring + ring(HOLD 页) + 1(哨兵行),向上取整到 ratio 整除 
- (分组视图 [-1, ratio, last_dim] 需要)。与 sglang CompressStatePool 的 _size 公式一致。""" - rows = num_swa_pages * ring + ring + 1 - return (rows + ratio - 1) // ratio * ratio - - -def init_paged_state_pool(buffer: torch.Tensor) -> None: - """末行为哨兵: kv 半边置 0、score 半边置 -inf(KVAndScore.clear 语义)。其余行无需初始化 - (内核在组起点覆写)。buffer: [rows, 2*coff*head_dim] fp32。""" - half = buffer.shape[-1] // 2 - buffer[-1, :half].zero_() - buffer[-1, half:].fill_(float("-inf")) - return - - -def _paged_state_group_slot(req_to_token, full_to_swa, b_req_idx, positions, page_size, ring, ratio): - """位置 -> state 分组槽(= sglang create_paged_compressor_data.get_raw_loc): - state_loc = (swa_loc//page)*ring + swa_loc%ring; 分组槽 = state_loc//ratio。 - 负位置按 sglang 语义 mask 到 0;已出窗(swa_loc<0)的位置落到 -1(哨兵行,score=-inf)。""" - positions = positions.masked_fill(positions < 0, 0) - full = req_to_token[b_req_idx.long(), positions] - swa_loc = full_to_swa[full.long()].long() - state_loc = torch.div(swa_loc, page_size, rounding_mode="floor") * ring + swa_loc % ring - state_loc = torch.where(swa_loc < 0, torch.full_like(state_loc, -1), state_loc) - return torch.div(state_loc, ratio, rounding_mode="floor").to(torch.int32) - - -def paged_decode_state_slots( - req_to_token, - full_to_swa, - b_req_idx, - b_seq_len, - page_size: int, - ring: int, - ratio: int, - hold_req_id: int, - num_swa_pages: int, - overlap: bool = True, -): - """decode 步的 state 分组槽(写槽 = 当前组 clip_down(seq-1) 的槽,可选 overlap 前一组)。 - 纯张量算术(prep 已写本步 req_to_token),图安全。padding(HOLD)行重定向到 HOLD 页的 - state 槽,隔离其垃圾累加。""" - seq = b_seq_len.long() - write_positions = torch.div(seq - 1, ratio, rounding_mode="floor") * ratio - write_slot = _paged_state_group_slot(req_to_token, full_to_swa, b_req_idx, write_positions, page_size, ring, ratio) - overlap_slot = None - if overlap: - overlap_slot = _paged_state_group_slot( - req_to_token, full_to_swa, b_req_idx, write_positions - ratio, page_size, ring, ratio - ) - hold_slot = num_swa_pages * ring // ratio # HOLD 页区域([pages*ring, 
pages*ring+ring))的首个分组槽 - is_hold = b_req_idx.long() == hold_req_id - write_slot = torch.where(is_hold, torch.full_like(write_slot, hold_slot), write_slot) - if overlap_slot is not None: - overlap_slot = torch.where(is_hold, torch.full_like(overlap_slot, hold_slot), overlap_slot) - return write_slot, overlap_slot - - -def paged_prefill_compress_data( - req_to_token, - full_to_swa, - req_idx: int, - ready_len: int, - seq_len: int, - ring: int, - ratio: int = 4, - page_size: int = DSV4_SWA_PAGE_SIZE, - overlap: bool = True, -): - """单请求 prefill chunk 的 (indices, extra_data, plan): 与 sglang 同走 - triton_create_paged_compress_data(按请求产出,内核经 plan 逐 token 步进)。 - 三者都与层无关,同一 forward 内可跨全部 c4 层复用。""" - mod, _ = _load_sglang_compressor() - fn = _load_paged_compress_data_fn() - device = req_to_token.device - n_new = seq_len - ready_len - write_loc, extra_data = fn( - compress_ratio=ratio, - is_overlap=overlap, - swa_page_size=page_size, - ring_size=ring, - req_pool_indices=torch.tensor([req_idx], device=device, dtype=torch.int64), - seq_lens=torch.tensor([seq_len], device=device, dtype=torch.int64), - extend_seq_lens=torch.tensor([n_new], device=device, dtype=torch.int64), - req_to_token=req_to_token, - full_to_swa_index_mapping=full_to_swa, - ) - plan = mod.CompressorPrefillPlan.generate( - ratio, - n_new, - torch.tensor([seq_len], dtype=torch.int64), - torch.tensor([n_new], dtype=torch.int64), - device, - ) - return write_loc, extra_data, plan - - -def compressor_paged_prefill( - x, - wkv_w, - wgate_w, - norm_w, - ape, - head_dim, - cos_table, - sin_table, - eps, - state_buffer, - compress_data, - ready_len, - seq_len, - ratio: int = 4, -): - """单请求 prefill/extend chunk(paged): x [n_new, dim] 为位置 [ready, seq) 的 hidden, - state 写到 swa 派生的分组槽(compress_data 来自 paged_prefill_compress_data,跨层复用)。 - 返回本 chunk 完结组的压缩条目 [seq//ratio - ready//ratio, head_dim](rope 已施加)。""" - mod, _ = _load_sglang_compressor() - kv_score = _project_kv_score(x, wkv_w, wgate_w) - pool = 
state_buffer.view(-1, ratio, state_buffer.shape[-1]) - write_loc, extra_data, plan = compress_data - kwargs = {"extra_data": extra_data} if extra_data is not None else {} - out = mod.compress_forward( - pool, - kv_score, - _sglang_ape(ape.float(), ratio, head_dim), - write_loc, - plan, - head_dim=head_dim, - compress_ratio=ratio, - **kwargs, - ) - ncomp = seq_len // ratio - ready_len // ratio - if ncomp == 0: - return x.new_zeros(0, head_dim) - mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) - ragged_ids = plan.compress_plan.view(torch.int32)[:ncomp, 0].long() - return out.index_select(0, ragged_ids).to(x.dtype) - - -def compressor_paged_decode_batch( - x_new, - wkv_w, - wgate_w, - norm_w, - ape, - head_dim, - cos_table, - sin_table, - eps, - state_buffer, - write_slot, - overlap_slot, - b_seq_len, - ratio: int = 4, -): - """批量 decode 一步(paged): state 槽位为 swa 派生分组槽(paged_decode_state_slots, - 可跨层复用)。返回 (entries [bs, head_dim], should_compress [bs])。""" - mod, _ = _load_sglang_compressor() - kv_score = _project_kv_score(x_new, wkv_w, wgate_w) - pool = state_buffer.view(-1, ratio, state_buffer.shape[-1]) - seq_lens = b_seq_len.to(torch.int32).contiguous() - plan = mod.CompressorDecodePlan(ratio, seq_lens) - kwargs = {"extra_data": overlap_slot.view(-1, 1)} if overlap_slot is not None else {} - out = mod.compress_forward( - pool, - kv_score, - _sglang_ape(ape.float(), ratio, head_dim), - write_slot, - plan, - head_dim=head_dim, - compress_ratio=ratio, - **kwargs, - ) - should_compress = (seq_lens % ratio) == 0 - mod.compress_fused_norm_rope_inplace(out, norm_w.float(), eps, _freq_cis(cos_table, sin_table), plan) - return out.to(x_new.dtype), should_compress diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 1b5c5f4c3f..f99df4b00d 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ 
b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -8,6 +8,7 @@ from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import DeepseekV4TransformerLayerWeight from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor +from lightllm.utils.vllm_utils import vllm_ops from .hyper_connection import hc_pre, hc_fused_post_pre, hc_post from .compressor import fused_compress as fused_compress_op from .compressor import prepare_partial_states @@ -26,9 +27,6 @@ def __init__(self, layer_num, network_config): self.qk_rope_head_dim = network_config["qk_rope_head_dim"] self.qk_nope_head_dim = self.head_dim_ - self.qk_rope_head_dim self.v_head_dim = self.head_dim_ - self.index_n_heads = network_config["index_n_heads"] - self.index_head_dim = network_config["index_head_dim"] - self.index_topk = network_config["index_topk"] self.o_groups = network_config["o_groups"] self.hc_mult = network_config["hc_mult"] self.sinkhorn_iters = network_config["hc_sinkhorn_iters"] @@ -48,14 +46,14 @@ def __init__(self, layer_num, network_config): self.swiglu_limit = network_config["swiglu_limit"] self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) self.tp_q_head_num_ = self.num_heads // self.tp_world_size_ - self.tp_index_n_heads = self.index_n_heads // self.tp_world_size_ self.tp_groups = self.o_groups // self.tp_world_size_ self.enable_ep_moe = get_env_start_args().enable_ep_moe - self.indexer_score_scale = self.index_head_dim ** -0.5 - self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5 self.compressor = CompressorInfer( layer_idx=self.layer_num_, network_config=self.network_config_, tp_world_size=self.tp_world_size_ ) + self.index_infer = DeepseekV4IndexInfer( + layer_idx=self.layer_num_, network_config=self.network_config_, tp_world_size=self.tp_world_size_ + ) # ------------------------------------------------------------------ forward 
(HC-threaded) def _hc_attn_in(self, input_embdings, layer_weight: DeepseekV4TransformerLayerWeight): @@ -180,19 +178,6 @@ def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, layer_weight: Deepsee o = layer_weight.wo_b_.mm(o) return self._tpsp_reduce(input=o, infer_state=infer_state) - # ------------------------------------------------------------------ compressor / indexer - def _indexer_q_weight( - self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight - ): - if self.compress_ratio != 4: - return None, None - cos_tok = infer_state.position_cos_compress - sin_tok = infer_state.position_sin_compress - idx_q = layer_weight.idx_wq_b_.mm(q_lora).view(x.shape[0], self.tp_index_n_heads, self.index_head_dim) - rotary_emb_fwd(idx_q[..., -self.qk_rope_head_dim :], None, cos_tok, sin_tok) - idx_weight = layer_weight.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale - return idx_q, idx_weight - # ------------------------------------------------------------------ attention (prefill) def context_attention_forward( self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight @@ -241,7 +226,14 @@ def _context_attention_kernel( ): self.compressor.prepare_states(x, infer_state, layer_weight) self.compressor.fused_compress(infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table) - idx_q, idx_weight = self._indexer_q_weight(x, q_lora, infer_state, layer_weight) + # Write this step's c4 Lightning-Indexer keys (no-op off c4) BEFORE build_metadata so the + # scorer's gather_indexer_k reads fresh+accumulated entries. + self.index_infer.write_indexer_k(x, infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table) + # Build the FINAL flash_mla index tensors here (model side), so att_control is a thin + # transport of ready-to-forward tensors -- not indexer raw material. 
Must stay after + # fused_compress (c4 reads the indexer-K pool it writes) and before prefill_att (keeps the + # c4 einsum/topk/all_reduce at the same cuda-graph capture position). + meta = self.index_infer.build_metadata(x, q_lora, infer_state, layer_weight) att_control = AttControl( nsa_prefill=True, nsa_prefill_dict={ @@ -250,14 +242,8 @@ def _context_attention_kernel( "compress_ratio": self.compress_ratio, "head_dim_v": self.v_head_dim, "softmax_scale": self.softmax_scale, - "q_lora": q_lora, - "hidden_states": x, "attn_sink": layer_weight.attn_sink_.weight, - "idx_q": idx_q, - "idx_weight": idx_weight, - "index_topk": self.index_topk, - "indexer_score_scale": self.indexer_score_scale, - "tp_world_size": self.tp_world_size_, + **meta, }, ) out = infer_state.prefill_att_state.prefill_att( @@ -285,7 +271,8 @@ def _token_attention_kernel( ): self.compressor.prepare_states(x, infer_state, layer_weight) self.compressor.fused_compress(infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table) - idx_q, idx_weight = self._indexer_q_weight(x, q_lora, infer_state, layer_weight) + self.index_infer.write_indexer_k(x, infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table) + meta = self.index_infer.build_metadata(x, q_lora, infer_state, layer_weight) att_control = AttControl( nsa_decode=True, nsa_decode_dict={ @@ -294,14 +281,8 @@ def _token_attention_kernel( "compress_ratio": self.compress_ratio, "head_dim_v": self.v_head_dim, "softmax_scale": self.softmax_scale, - "q_lora": q_lora, - "hidden_states": x, "attn_sink": layer_weight.attn_sink_.weight, - "idx_q": idx_q, - "idx_weight": idx_weight, - "index_topk": self.index_topk, - "indexer_score_scale": self.indexer_score_scale, - "tp_world_size": self.tp_world_size_, + **meta, }, ) return infer_state.decode_att_state.decode_att( @@ -354,8 +335,6 @@ def _select_experts( def _select_experts_vllm( self, logits, infer_state: DeepseekV4InferStateInfo, layer_weight: 
DeepseekV4TransformerLayerWeight ): - from vllm import _custom_ops as ops - M = logits.shape[0] bias = None input_tokens = None @@ -373,7 +352,7 @@ def _select_experts_vllm( weights = self.alloc_tensor((M, self.num_experts_per_tok), dtype=torch.float32, device=logits.device) indices = self.alloc_tensor((M, self.num_experts_per_tok), dtype=indices_dtype, device=logits.device) token_expert_indices = self.alloc_tensor((M, self.num_experts_per_tok), dtype=torch.int32, device=logits.device) - ops.topk_hash_softplus_sqrt( + vllm_ops.topk_hash_softplus_sqrt( weights, indices, token_expert_indices, @@ -388,11 +367,18 @@ def _select_experts_vllm( class CompressorInfer: - def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int): + """Window-softmax compressor. is_in_indexer=False compresses the c4/c128 latent KV into the + paged fp8 slab (attention extra_k); is_in_indexer=True reuses the SAME machinery (mirroring + sglang's Compressor(is_in_indexer=...)) with the indexer weights/dims/state pool to produce the + per-c4-entry Lightning-Indexer keys, emitted as dense bf16 (OUTPUT_BF16) then fp8-packed into + c4_indexer_pool by the caller. 
Indexer mode is c4-only.""" + + def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int, is_in_indexer: bool = False): super().__init__() self.layer_idx_ = layer_idx self.network_config_ = network_config self.tp_world_size_ = tp_world_size + self.is_in_indexer = is_in_indexer self.compress_ratio = network_config["compress_ratios"][layer_idx] self.head_dim = network_config["head_dim"] self.index_head_dim = network_config["index_head_dim"] @@ -410,13 +396,23 @@ def prepare_states( infer_state=infer_state, layer_idx=self.layer_idx_, compress_ratio=self.compress_ratio, + is_in_indexer=self.is_in_indexer, ) if self._metadata is not None: - self._metadata.kv_score = layer_weight.compressor_wkv_gate_.mm(x).float() + if self.is_in_indexer: + # indexer wkv/wgate are two separate replicated weights; cat -> [T, 2*coff*idx_hd] + # (same [kv | score] layout the fused compressor_wkv_gate_ produces for attention). + kv = layer_weight.idx_cmp_wkv_.mm(x) + gate = layer_weight.idx_cmp_wgate_.mm(x) + self._metadata.kv_score = torch.cat([kv, gate], dim=-1).float() + ape = layer_weight.idx_cmp_ape_.weight + else: + self._metadata.kv_score = layer_weight.compressor_wkv_gate_.mm(x).float() + ape = layer_weight.compressor_ape_.weight prepare_partial_states( kv_score=self._metadata.kv_score, metadata=self._metadata, - ape=layer_weight.compressor_ape_.weight, + ape=ape, compress_ratio=self.compress_ratio, ) return self._metadata @@ -433,15 +429,178 @@ def fused_compress( metadata = self._metadata if metadata is None: raise RuntimeError("DeepSeek-V4 compressor.prepare_states must run before fused_compress") + if self.is_in_indexer: + norm_weight = layer_weight.idx_cmp_norm_.weight + ape = layer_weight.idx_cmp_ape_.weight + head_dim = self.index_head_dim + else: + norm_weight = layer_weight.compressor_norm_.weight + ape = layer_weight.compressor_ape_.weight + head_dim = self.head_dim return fused_compress_op( kv_score=metadata.kv_score, metadata=metadata, - 
norm_weight=layer_weight.compressor_norm_.weight, - ape=layer_weight.compressor_ape_.weight, + norm_weight=norm_weight, + ape=ape, eps=self.eps, - head_dim=self.head_dim, + head_dim=head_dim, qk_rope_head_dim=self.qk_rope_head_dim, compress_ratio=self.compress_ratio, cos_table=cos_table, sin_table=sin_table, + output_bf16=self.is_in_indexer, + ) + + +FLASHMLA_INDEX_ALIGN = 64 + + +def _pad_last_dim(x: torch.Tensor, multiple: int = FLASHMLA_INDEX_ALIGN, value: int = -1) -> torch.Tensor: + pad = (-x.shape[-1]) % multiple + if pad == 0: + return x.contiguous() + out = torch.full((*x.shape[:-1], x.shape[-1] + pad), value, dtype=x.dtype, device=x.device) + out[..., : x.shape[-1]] = x + return out.contiguous() + + +class DeepseekV4IndexInfer: + """Model-side builder for the FlashMLA sparse-index metadata. Mirrors deepseek3_2's NsaInfer + *boundary* (the model owns ALL index construction; the attention backend only forwards final + tensors to flash_mla.flash_mla_with_kvcache) but NOT its implementation -- the two share ~no + concrete operators (ds3_2: fp8_mqa_logits over the full ragged kv; dsv4: bf16 einsum over the + compressed c4 entries), hence no inheritance. Owns swa/c128 slot bookkeeping AND the c4 + Lightning-Indexer scoring. Holds only static per-layer config; all per-request data flows in via + args. 
Invoke from _context/_token_attention_kernel (after compressor.fused_compress, before + *_att) so the c4 einsum/topk/all_reduce keep the same cuda-graph capture position they had when + this lived in the backend.""" + + def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int): + self.layer_idx_ = layer_idx + self.compress_ratio = network_config["compress_ratios"][layer_idx] + self.index_topk = network_config["index_topk"] + self.index_head_dim = network_config["index_head_dim"] + self.qk_rope_head_dim = network_config["qk_rope_head_dim"] + self.index_n_heads = network_config["index_n_heads"] + self.tp_world_size_ = tp_world_size + self.tp_index_n_heads = self.index_n_heads // tp_world_size + self.indexer_score_scale = self.index_head_dim ** -0.5 + self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5 + # c4 layers own a second compressor (is_in_indexer) that writes the Lightning-Indexer key + # pool every step; the scorer in _c4_indices reads it back via gather_indexer_k. + self.indexer_compressor = ( + CompressorInfer(layer_idx, network_config, tp_world_size, is_in_indexer=True) + if self.compress_ratio == 4 + else None + ) + + def write_indexer_k(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight, cos_table, sin_table): + """c4-only: compress this step's tokens into per-c4-entry indexer keys and pack them into + c4_indexer_pool. MUST run before build_metadata so the scorer's gather_indexer_k reads the + finished entries; runs every step (incl. in the decode graph) so keys accumulate for later + long-context scoring. 
No-op on c128 / dense layers.""" + if self.compress_ratio != 4: + return + self.indexer_compressor.prepare_states(x, infer_state, layer_weight) + self.indexer_compressor.fused_compress(infer_state, layer_weight, cos_table, sin_table) + scratch = self.indexer_compressor._metadata.out_buffer # [T, index_head_dim] bf16 (group-end rows valid) + mem_manager = infer_state.mem_manager + positions = infer_state.position_ids + out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)] + # only group-end tokens finish a c4 entry; mask the rest to -1 so the packer skips them + # (mid-group tokens share the group's c4 slot -> avoids racing a finished slot). + completed = ((positions + 1) % 4 == 0) & (out_slots >= 0) + masked_slots = torch.where(completed, out_slots, torch.full_like(out_slots, -1)).to(torch.int32) + mem_manager.pack_indexer_k_to_cache(self.layer_idx_, masked_slots, scratch) + + def build_metadata(self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight): + """Return the final flash_mla index tensors for this layer's compress variant. swa indices and + the per-token req_idx are layer-independent and precomputed once in init_some_extra_state + (read here); only the c4 scorer / c128 gather is per-layer. 
The backend pairs these with the + (data-independent, layer-keyed) fp8 cache-byte views it owns.""" + swa_indices = infer_state.dsv4_swa_indices.unsqueeze(1) + swa_lengths = infer_state.dsv4_swa_lengths + req_idx = infer_state.dsv4_sparse_req_idx + positions = infer_state.position_ids + extra_indices = extra_lengths = None + if self.compress_ratio == 4: + idx_q, idx_weight = self._indexer_q_weight(x, q_lora, infer_state, layer_weight) + extra_indices, extra_lengths = self._c4_indices(infer_state, idx_q, idx_weight, req_idx, positions) + elif self.compress_ratio == 128: + extra_indices, extra_lengths = self._c128_indices(infer_state, req_idx, positions) + return { + "swa_indices": swa_indices, + "swa_lengths": swa_lengths, + "extra_indices": extra_indices, + "extra_lengths": extra_lengths, + } + + @staticmethod + def _gather_compress_slots(infer_state, mapping, req_idx, valid, offsets, ratio): + """条目 g 的压缩槽 = full_to_c*[req_to_token[req, (g+1)*ratio-1]](组末 token 的 full 槽位)。 + 无效条目(超出因果长度/HOLD 行)用位置 0 安全 gather 后由调用方按 valid 掩掉。""" + end_pos = offsets[None, :] * ratio + (ratio - 1) + safe_pos = torch.where(valid, end_pos, torch.zeros_like(end_pos)) + full_slots = infer_state.req_manager.req_to_token_indexs[req_idx.long()[:, None], safe_pos] + return mapping[full_slots.long()].to(torch.int32) + + def _c128_indices(self, infer_state: DeepseekV4InferStateInfo, req_idx, positions): + raw_lengths = (positions + 1) // 128 + lengths = torch.clamp(raw_lengths, min=1).to(torch.int32) + max_len = max(1, int(infer_state.max_kv_seq_len) // 128) + offsets = torch.arange(max_len, dtype=torch.long, device=positions.device) + valid = offsets[None, :] < raw_lengths[:, None] + slots = self._gather_compress_slots( + infer_state, infer_state.mem_manager.full_to_c128_indexs, req_idx, valid, offsets, 128 + ) + indices = torch.where(valid, slots, torch.full_like(slots, -1)) + return _pad_last_dim(indices).unsqueeze(1), lengths.contiguous() + + def _indexer_q_weight(self, x, q_lora, 
infer_state: DeepseekV4InferStateInfo, layer_weight): + cos_tok = infer_state.position_cos_compress + sin_tok = infer_state.position_sin_compress + idx_q = layer_weight.idx_wq_b_.mm(q_lora).view(x.shape[0], self.tp_index_n_heads, self.index_head_dim) + rotary_emb_fwd(idx_q[..., -self.qk_rope_head_dim :], None, cos_tok, sin_tok) + idx_weight = layer_weight.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale + return idx_q, idx_weight + + def _c4_indices(self, infer_state: DeepseekV4InferStateInfo, idx_q, idx_weight, req_idx, positions): + """c4(CSA) extra indices: causal all-entries when the entry space fits index_topk, otherwise + Lightning-Indexer scored top-k. Pure tensor ops (decode runs inside cuda graphs).""" + mem_manager = infer_state.mem_manager + raw_lengths = (positions + 1) // 4 + max_entries = max(1, int(infer_state.max_kv_seq_len) // 4) + index_topk = self.index_topk + offsets = torch.arange(max_entries, dtype=torch.long, device=positions.device) + valid = offsets[None, :] < raw_lengths[:, None] + slots = self._gather_compress_slots(infer_state, mem_manager.full_to_c4_indexs, req_idx, valid, offsets, 4) + + if max_entries <= index_topk: + lengths = torch.clamp(raw_lengths, min=1).to(torch.int32) + indices = torch.where(valid, slots, torch.full_like(slots, -1)) + return _pad_last_dim(indices).unsqueeze(1), lengths.contiguous() + + score_scale = float(self.indexer_score_scale) + hold_slot = mem_manager.c4_indexer_pool.HOLD_TOKEN_MEMINDEX + safe_slots = torch.where(valid, slots.long(), torch.full_like(slots.long(), hold_slot)) + k = mem_manager.gather_indexer_k(self.layer_idx_, safe_slots.reshape(-1)).view( + positions.shape[0], max_entries, -1 ) + num_tokens, num_heads = idx_q.shape[0], idx_q.shape[1] + score_chunks = [] + chunk = max(1, min(num_tokens, (16 * 1024 * 1024) // max(1, num_heads * max_entries))) + for start in range(0, num_tokens, chunk): + end = min(num_tokens, start + chunk) + scores = torch.einsum("thd,tnd->thn", 
idx_q[start:end].float(), k[start:end].float()) + scores = F.relu(scores) * score_scale + score_chunks.append((scores * idx_weight[start:end].unsqueeze(-1)).sum(dim=1)) + index_scores = torch.cat(score_chunks, dim=0) + if self.tp_world_size_ > 1: + all_reduce(index_scores, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + index_scores = index_scores.masked_fill(~valid, float("-inf")) + top = index_scores.topk(index_topk, dim=-1).indices + top_valid = torch.gather(valid, 1, top) + top_slots = torch.gather(slots.long(), 1, top).to(torch.int32) + indices = torch.where(top_valid, top_slots, torch.full_like(top_slots, -1)) + lengths = torch.clamp(torch.minimum(raw_lengths, torch.full_like(raw_lengths, index_topk)), min=1) + return _pad_last_dim(indices).unsqueeze(1), lengths.to(torch.int32).contiguous() diff --git a/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py new file mode 100644 index 0000000000..e5b7b80bb4 --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py @@ -0,0 +1,75 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _build_swa_index_kernel( + req_idx_ptr, + pos_ptr, + req_to_token_ptr, + req_to_token_stride0, + full_to_swa_ptr, + swa_index_ptr, + swa_length_ptr, + WINDOW: tl.constexpr, + BLOCK_W: tl.constexpr, +): + token_idx = tl.program_id(0) + req = tl.load(req_idx_ptr + token_idx).to(tl.int64) + pos = tl.load(pos_ptr + token_idx).to(tl.int64) + + w = tl.arange(0, BLOCK_W) + w_mask = w < WINDOW + # most-recent-first window, identical to the eager _swa_indices (offset = position - arange). 
+ offset = pos - w + valid = (offset >= 0) & w_mask + safe_offset = tl.where(valid, offset, 0) + full_slot = tl.load(req_to_token_ptr + req * req_to_token_stride0 + safe_offset, mask=valid, other=0).to(tl.int64) + swa_slot = tl.load(full_to_swa_ptr + full_slot, mask=valid, other=-1) + out = tl.where(valid, swa_slot, -1).to(tl.int32) + tl.store(swa_index_ptr + token_idx * WINDOW + w, out, mask=w_mask) + + length = tl.minimum(tl.maximum(pos + 1, 1), WINDOW).to(tl.int32) + tl.store(swa_length_ptr + token_idx, length) + + +def build_swa_index( + req_idx: torch.Tensor, + positions: torch.Tensor, + req_to_token_indexs: torch.Tensor, + full_to_swa_indexs: torch.Tensor, + window: int, +): + """Per-token sliding-window FlashMLA index table, built ONCE per forward (layer-independent: + full_to_swa is a single global map and the window is a model constant, so every layer's swa + indices are identical). Replaces DeepseekV4IndexInfer._swa_indices: for token t at + (req_idx, position) gather the last `window` tokens' full slots via req_to_token, then map + full -> swa; out-of-range positions store -1. + + Returns (swa_index [T, window] int32, swa_length [T] int32). `window` is 128 (a multiple of the + FlashMLA 64 alignment) so no extra pad is needed; the reader adds the s_q axis via unsqueeze(1). + Const output shape (no max_kv_seq_len dependence) makes this cuda-graph-safe to stage from + init_some_extra_state via copy_for_cuda_graph. + """ + # window must stay 64-aligned: the output is the FlashMLA `indices` tensor directly (no separate + # _pad_last_dim), and the extra-cache fork requires the topk dim to be a multiple of 64. 
+ assert window % 64 == 0, f"DeepSeek-V4 sliding_window must be a multiple of 64 for FlashMLA, got {window}" + T = positions.shape[0] + swa_index = torch.empty((T, window), dtype=torch.int32, device=positions.device) + swa_length = torch.empty((T,), dtype=torch.int32, device=positions.device) + if T == 0: + return swa_index, swa_length + _build_swa_index_kernel[(T,)]( + req_idx, + positions, + req_to_token_indexs, + req_to_token_indexs.stride(0), + full_to_swa_indexs, + swa_index, + swa_length, + WINDOW=window, + BLOCK_W=triton.next_power_of_2(window), + num_warps=4, + ) + return swa_index, swa_length From d4dcd8a8924e7218025763447db3abad72d9c520 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 15 Jun 2026 13:00:25 +0000 Subject: [PATCH 20/30] opt --- lightllm/common/basemodel/basemodel.py | 53 +---- .../deepseek4_mem_manager.py | 10 +- lightllm/common/req_manager.py | 40 ++-- .../layer_infer/transformer_layer_infer.py | 192 ++++++++++-------- .../layer_weights/transformer_layer_weight.py | 10 +- .../build_compress_index_dsv4.py | 78 +++++++ .../triton_kernel/gather_c4_indexer_k_dsv4.py | 106 ++++++++++ 7 files changed, 338 insertions(+), 151 deletions(-) create mode 100644 lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py create mode 100644 lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 986802e760..30248d6a21 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -297,6 +297,11 @@ def _init_custom(self): @torch.no_grad() def forward(self, model_input: ModelInput): + # decode 槽位 prep: 放在 to_cuda 前 (b_req_idx/b_seq_len/mem_indexes_cpu 还是原生 CPU 张量), + # 且此刻已在 forward 的 CUDA stream 上 -> 与后续 attention 同流, 无跨流竞态、无 D2H。 + # mem_indexes_cpu is None 时跳过: cudagraph warmup 的输入全在 CUDA 且 b_req_idx 全为 HOLD, prep 本就是 no-op。 + if not model_input.is_prefill and 
model_input.mem_indexes_cpu is not None: + self.req_manager.prepare_decode(model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes_cpu) model_input.to_cuda() assert model_input.mem_indexes.is_cuda @@ -579,14 +584,6 @@ def _decode( model_input = self._create_padded_decode_model_input( model_input=model_input, new_batch_size=infer_batch_size ) - if hasattr(self.req_manager, "prepare_decode_swa"): - self.req_manager.prepare_decode_swa( - model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes - ) - if hasattr(self.req_manager, "prepare_decode_compress_slots"): - self.req_manager.prepare_decode_compress_slots( - model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes - ) infer_state = self._create_inferstate(model_input) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -608,14 +605,6 @@ def _decode( model_input = self._create_padded_decode_model_input( model_input=model_input, new_batch_size=infer_batch_size ) - if hasattr(self.req_manager, "prepare_decode_swa"): - self.req_manager.prepare_decode_swa( - model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes - ) - if hasattr(self.req_manager, "prepare_decode_compress_slots"): - self.req_manager.prepare_decode_compress_slots( - model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes - ) infer_state = self._create_inferstate(model_input) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -845,6 +834,10 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod @torch.no_grad() def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput): + # decode 槽位 prep: 在 to_cuda 前 (原生 CPU 张量)、且已在 forward 的 CUDA stream 上 (见 forward 注释)。 + for mi in (model_input0, model_input1): + if mi.mem_indexes_cpu is not None: + self.req_manager.prepare_decode(mi.b_req_idx, mi.b_seq_len, mi.mem_indexes_cpu) model_input0.to_cuda() model_input1.to_cuda() assert self.args.enable_tpsp_mix_mode @@ -876,20 
+869,6 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode # 一致,需要按照较高 batch size 进行graph的寻找,同时,进行有效的恢复。 padded_model_input0 = self._create_padded_decode_model_input(model_input0, infer_batch_size) padded_model_input1 = self._create_padded_decode_model_input(model_input1, infer_batch_size) - if hasattr(self.req_manager, "prepare_decode_swa"): - self.req_manager.prepare_decode_swa( - padded_model_input0.b_req_idx, padded_model_input0.b_seq_len, padded_model_input0.mem_indexes - ) - self.req_manager.prepare_decode_swa( - padded_model_input1.b_req_idx, padded_model_input1.b_seq_len, padded_model_input1.mem_indexes - ) - if hasattr(self.req_manager, "prepare_decode_compress_slots"): - self.req_manager.prepare_decode_compress_slots( - padded_model_input0.b_req_idx, padded_model_input0.b_seq_len, padded_model_input0.mem_indexes - ) - self.req_manager.prepare_decode_compress_slots( - padded_model_input1.b_req_idx, padded_model_input1.b_seq_len, padded_model_input1.mem_indexes - ) infer_state0 = self._create_inferstate(padded_model_input0, 0) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -931,20 +910,6 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode else: model_input0 = self._create_padded_decode_model_input(model_input0, infer_batch_size) model_input1 = self._create_padded_decode_model_input(model_input1, infer_batch_size) - if hasattr(self.req_manager, "prepare_decode_swa"): - self.req_manager.prepare_decode_swa( - model_input0.b_req_idx, model_input0.b_seq_len, model_input0.mem_indexes - ) - self.req_manager.prepare_decode_swa( - model_input1.b_req_idx, model_input1.b_seq_len, model_input1.mem_indexes - ) - if hasattr(self.req_manager, "prepare_decode_compress_slots"): - self.req_manager.prepare_decode_compress_slots( - model_input0.b_req_idx, model_input0.b_seq_len, model_input0.mem_indexes - ) - self.req_manager.prepare_decode_compress_slots( - model_input1.b_req_idx, 
model_input1.b_seq_len, model_input1.mem_indexes - ) infer_state0 = self._create_inferstate(model_input0, 0) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py index 8d172ec758..32561aa6f7 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py @@ -511,8 +511,8 @@ def alloc_swa_prefill( def alloc_swa_decode( self, - b_req_idx: torch.Tensor, - b_seq_len: torch.Tensor, + b_req_idx_cpu: torch.Tensor, + b_seq_len_cpu: torch.Tensor, mem_indexes: torch.Tensor, req_to_token_indexs: torch.Tensor, ) -> None: @@ -523,8 +523,8 @@ def alloc_swa_decode( (DSV4 启动参数已拒绝 MTP;支持需按步内顺序分段派生)。""" page = DSV4_SWA_PAGE_SIZE hold_req_id = self.max_request_num - req_list = b_req_idx.detach().cpu().tolist() - seq_list = b_seq_len.detach().cpu().tolist() + req_list = b_req_idx_cpu.tolist() + seq_list = b_seq_len_cpu.tolist() cont_rows, cont_prev_pos, new_rows = [], [], [] for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)): req_idx, seq_len = int(req_idx), int(seq_len) @@ -537,7 +537,7 @@ def alloc_swa_decode( cont_prev_pos.append(seq_len - 2) mem_indexes = mem_indexes.cuda().long().reshape(-1) if cont_rows: - req_rows = b_req_idx[cont_rows].long() + req_rows = torch.tensor([req_list[i] for i in cont_rows], dtype=torch.long, device="cuda") prev_full = req_to_token_indexs[req_rows, torch.tensor(cont_prev_pos, device="cuda")].long() prev_slots = self.full_to_swa_indexs[prev_full] # 续槽不变式哨兵: 上一位置必驻留(retain 覆盖)。prep 阶段本就有同步,代价可忽略。 diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index a0faccb00c..9ac1babe23 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -156,6 +156,11 @@ def free_all(self): self.req_list = _ReqLinkedList(self.max_request_num) return + def prepare_decode(self, 
b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu): + """每个 decode step 在 to_cuda 之前调用的钩子 (数据为原生 CPU 张量, 且已在 forward 的 + CUDA stream 上)。基类 no-op; 需要 per-step KV 槽位 prep 的模型 (DeepSeek-V4) override。""" + return + class ReqSamplingParamsManager: """ @@ -479,26 +484,32 @@ def prepare_prefill_swa( self.mem_manager.alloc_swa_prefill(b_req_idx, b_ready_cache_len, b_seq_len, self.req_to_token_indexs) return + def prepare_decode(self, b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu): + """decode 每步槽位 prep: 先 swa 再 compress。由 BaseModel.forward / microbatch_overlap_decode + 在 to_cuda 之前调用 (CPU 数据 + forward 的 CUDA stream); 不再放在 _decode 里。""" + self.prepare_decode_swa(b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu) + self.prepare_decode_compress_slots(b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu) + return + def prepare_decode_swa( self, - b_req_idx: torch.Tensor, - b_seq_len: torch.Tensor, + b_req_idx_cpu: torch.Tensor, + b_seq_len_cpu: torch.Tensor, mem_indexes: torch.Tensor, ) -> None: """decode prep: 回收出窗槽并为本步新 token 分配位置对齐的 swa 槽。当前 query 位置 seq_len-1 的窗口是 [seq_len-W, seq_len-1];回收边界额外保留一个 radix 页 - (_swa_retain_len),即位置 < seq_len-retain。先回收再分配。""" + (_swa_retain_len),即位置 < seq_len-retain。先回收再分配。 + seq_len/req_idx 从 CPU 镜像读(host 算术,无 D2H);水位线 _swa_evict_marks 仍是 host 状态。""" assert self.mem_manager is not None if self.sliding_window is not None: retain = self._swa_retain_len() evict_slots = [] - req_list = b_req_idx.detach().cpu().tolist() - seq_list = b_seq_len.detach().cpu().tolist() + req_list = b_req_idx_cpu.tolist() + seq_list = b_seq_len_cpu.tolist() for req_idx, seq_len in zip(req_list, seq_list): - req_idx = int(req_idx) if req_idx == self.HOLD_REQUEST_ID: continue - seq_len = int(seq_len) mark = self._swa_evict_marks[req_idx] if mark < 0: # 未经过 prefill prep 的保守路径: 不回收旧位置,仅推进水位线。 @@ -510,7 +521,7 @@ def prepare_decode_swa( self._swa_evict_marks[req_idx] = evict_end if evict_slots: self.mem_manager.evict_swa(torch.cat(evict_slots)) - 
self.mem_manager.alloc_swa_decode(b_req_idx, b_seq_len, mem_indexes, self.req_to_token_indexs) + self.mem_manager.alloc_swa_decode(b_req_idx_cpu, b_seq_len_cpu, mem_indexes, self.req_to_token_indexs) return def init_compress_state(self, req_idx: int): @@ -575,23 +586,24 @@ def prepare_prefill_compress_slots( def prepare_decode_compress_slots( self, - b_req_idx: torch.Tensor, - b_seq_len: torch.Tensor, + b_req_idx_cpu: torch.Tensor, + b_seq_len_cpu: torch.Tensor, mem_indexes: torch.Tensor, ) -> None: """decode prep: 本步 token 关闭一个组(seq_len % ratio == 0)时为其分配压缩槽并 scatter。 - 组末 full 槽即本步的 mem_index(此刻 req_to_token_indexs 尚未写入本步槽位)。""" + 组末 full 槽即本步的 mem_index(此刻 req_to_token_indexs 尚未写入本步槽位)。 + 从 CPU 镜像读 seq_len/req_idx(host 算术,无 D2H);非关组步 rows 为空 => 不调 _scatter,零同步。""" if self.n_c4 == 0 and self.n_c128 == 0: return - req_list = b_req_idx.detach().cpu().tolist() - seq_list = b_seq_len.detach().cpu().tolist() + req_list = b_req_idx_cpu.tolist() + seq_list = b_seq_len_cpu.tolist() for ratio, n_layers in ((4, self.n_c4), (128, self.n_c128)): if n_layers == 0: continue rows = [ i for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)) - if int(req_idx) != self.HOLD_REQUEST_ID and int(seq_len) > 0 and int(seq_len) % ratio == 0 + if req_idx != self.HOLD_REQUEST_ID and seq_len > 0 and seq_len % ratio == 0 ] if rows: self._scatter_compress_slots(ratio, mem_indexes.reshape(-1)[rows]) diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index f99df4b00d..ef898c31bc 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -227,12 +227,12 @@ def _context_attention_kernel( self.compressor.prepare_states(x, infer_state, layer_weight) self.compressor.fused_compress(infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table) # Write this step's c4 Lightning-Indexer 
keys (no-op off c4) BEFORE build_metadata so the - # scorer's gather_indexer_k reads fresh+accumulated entries. + # scorer (gather + deep_gemm.fp8_mqa_logits) reads fresh+accumulated entries from the indexer pool. self.index_infer.write_indexer_k(x, infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table) # Build the FINAL flash_mla index tensors here (model side), so att_control is a thin # transport of ready-to-forward tensors -- not indexer raw material. Must stay after # fused_compress (c4 reads the indexer-K pool it writes) and before prefill_att (keeps the - # c4 einsum/topk/all_reduce at the same cuda-graph capture position). + # c4 scorer/topk at the same cuda-graph capture position). meta = self.index_infer.build_metadata(x, q_lora, infer_state, layer_weight) att_control = AttControl( nsa_prefill=True, @@ -452,28 +452,18 @@ def fused_compress( ) -FLASHMLA_INDEX_ALIGN = 64 - - -def _pad_last_dim(x: torch.Tensor, multiple: int = FLASHMLA_INDEX_ALIGN, value: int = -1) -> torch.Tensor: - pad = (-x.shape[-1]) % multiple - if pad == 0: - return x.contiguous() - out = torch.full((*x.shape[:-1], x.shape[-1] + pad), value, dtype=x.dtype, device=x.device) - out[..., : x.shape[-1]] = x - return out.contiguous() - - class DeepseekV4IndexInfer: """Model-side builder for the FlashMLA sparse-index metadata. Mirrors deepseek3_2's NsaInfer - *boundary* (the model owns ALL index construction; the attention backend only forwards final - tensors to flash_mla.flash_mla_with_kvcache) but NOT its implementation -- the two share ~no - concrete operators (ds3_2: fp8_mqa_logits over the full ragged kv; dsv4: bf16 einsum over the - compressed c4 entries), hence no inheritance. Owns swa/c128 slot bookkeeping AND the c4 - Lightning-Indexer scoring. Holds only static per-layer config; all per-request data flows in via - args. 
Invoke from _context/_token_attention_kernel (after compressor.fused_compress, before - *_att) so the c4 einsum/topk/all_reduce keep the same cuda-graph capture position they had when - this lived in the backend.""" + boundary (the model owns ALL index construction; the attention backend only forwards final + tensors to flash_mla.flash_mla_with_kvcache) AND its c4 implementation: hadamard'd fp8 q/K, a + ragged gather of the compressed c4 keys, deep_gemm.fp8_mqa_logits, then topk -- adapted for the + replicated indexer (no gather-q/all_reduce), the c4-compressed entry space, and topk-512 (no + inheritance only because of those data-shape differences). swa metadata is precomputed in + init_some_extra_state; this class owns the c4/c128 entry gather (build_compress_index) AND the c4 + Lightning-Indexer scoring (gather + deep_gemm.fp8_mqa_logits + topk). Holds only static per-layer + config; all per-request data flows in via args. Invoke from _context/_token_attention_kernel + (after compressor.fused_compress, before *_att) so the c4 scorer/topk keep the same cuda-graph + capture position they had when this lived in the backend. The indexer is replicated (no TP collective).""" def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int): self.layer_idx_ = layer_idx @@ -483,11 +473,10 @@ def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int): self.qk_rope_head_dim = network_config["qk_rope_head_dim"] self.index_n_heads = network_config["index_n_heads"] self.tp_world_size_ = tp_world_size - self.tp_index_n_heads = self.index_n_heads // tp_world_size self.indexer_score_scale = self.index_head_dim ** -0.5 self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5 # c4 layers own a second compressor (is_in_indexer) that writes the Lightning-Indexer key - # pool every step; the scorer in _c4_indices reads it back via gather_indexer_k. 
+ # pool every step; _c4_indices gathers it back + scores via deep_gemm.fp8_mqa_logits. self.indexer_compressor = ( CompressorInfer(layer_idx, network_config, tp_world_size, is_in_indexer=True) if self.compress_ratio == 4 @@ -496,14 +485,19 @@ def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int): def write_indexer_k(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight, cos_table, sin_table): """c4-only: compress this step's tokens into per-c4-entry indexer keys and pack them into - c4_indexer_pool. MUST run before build_metadata so the scorer's gather_indexer_k reads the - finished entries; runs every step (incl. in the decode graph) so keys accumulate for later - long-context scoring. No-op on c128 / dense layers.""" + c4_indexer_pool. MUST run before build_metadata so the scorer (gather + deep_gemm.fp8_mqa_logits) + reads the finished entries; runs every step (incl. in the decode graph) so keys accumulate for + later long-context scoring. No-op on c128 / dense layers.""" if self.compress_ratio != 4: return self.indexer_compressor.prepare_states(x, infer_state, layer_weight) self.indexer_compressor.fused_compress(infer_state, layer_weight, cos_table, sin_table) scratch = self.indexer_compressor._metadata.out_buffer # [T, index_head_dim] bf16 (group-end rows valid) + # Rotate K (post norm+rope) by the SAME 1/sqrt(d) Hadamard the q kernel applies, so + # (Hq)·(Hk)=q·k (H orthogonal) and the fp8 quant of K stays accurate. 
+ from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform + + scratch = hadamard_transform(scratch, scale=self.index_head_dim ** -0.5) mem_manager = infer_state.mem_manager positions = infer_state.position_ids out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)] @@ -524,8 +518,8 @@ def build_metadata(self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer positions = infer_state.position_ids extra_indices = extra_lengths = None if self.compress_ratio == 4: - idx_q, idx_weight = self._indexer_q_weight(x, q_lora, infer_state, layer_weight) - extra_indices, extra_lengths = self._c4_indices(infer_state, idx_q, idx_weight, req_idx, positions) + idx_q_fp8, weights = self._indexer_q_weight(x, q_lora, infer_state, layer_weight) + extra_indices, extra_lengths = self._c4_indices(infer_state, idx_q_fp8, weights, positions) elif self.compress_ratio == 128: extra_indices, extra_lengths = self._c128_indices(infer_state, req_idx, positions) return { @@ -535,72 +529,98 @@ def build_metadata(self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer "extra_lengths": extra_lengths, } - @staticmethod - def _gather_compress_slots(infer_state, mapping, req_idx, valid, offsets, ratio): - """条目 g 的压缩槽 = full_to_c*[req_to_token[req, (g+1)*ratio-1]](组末 token 的 full 槽位)。 - 无效条目(超出因果长度/HOLD 行)用位置 0 安全 gather 后由调用方按 valid 掩掉。""" - end_pos = offsets[None, :] * ratio + (ratio - 1) - safe_pos = torch.where(valid, end_pos, torch.zeros_like(end_pos)) - full_slots = infer_state.req_manager.req_to_token_indexs[req_idx.long()[:, None], safe_pos] - return mapping[full_slots.long()].to(torch.int32) - def _c128_indices(self, infer_state: DeepseekV4InferStateInfo, req_idx, positions): - raw_lengths = (positions + 1) // 128 - lengths = torch.clamp(raw_lengths, min=1).to(torch.int32) - max_len = max(1, int(infer_state.max_kv_seq_len) // 128) - offsets = torch.arange(max_len, dtype=torch.long, device=positions.device) - valid = 
offsets[None, :] < raw_lengths[:, None] - slots = self._gather_compress_slots( - infer_state, infer_state.mem_manager.full_to_c128_indexs, req_idx, valid, offsets, 128 + from ..triton_kernel.build_compress_index_dsv4 import build_compress_index + + cap = ((max(1, int(infer_state.max_kv_seq_len) // 128) + 63) // 64) * 64 + indices, lengths = build_compress_index( + req_idx, + positions, + infer_state.req_manager.req_to_token_indexs, + infer_state.mem_manager.full_to_c128_indexs, + ratio=128, + cap=cap, ) - indices = torch.where(valid, slots, torch.full_like(slots, -1)) - return _pad_last_dim(indices).unsqueeze(1), lengths.contiguous() + return indices.unsqueeze(1), lengths def _indexer_q_weight(self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight): + """fp8 indexer q (mirrors deepseek3_2 NsaInfer): wq_b -> rope(last rope dims) -> 1/sqrt(d) + Hadamard -> per-token fp8 quant. Returns (idx_q_fp8 [T,H,d], weights [T,H]); the per-token q + fp8 scale and the head_dim^-0.5 * n_heads^-0.5 score scale are folded into weights -- the + deep_gemm.fp8_mqa_logits contract (fp8 q carries no companion scale). 
Replicated -> full heads.""" + from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant + from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform + cos_tok = infer_state.position_cos_compress sin_tok = infer_state.position_sin_compress - idx_q = layer_weight.idx_wq_b_.mm(q_lora).view(x.shape[0], self.tp_index_n_heads, self.index_head_dim) + idx_q = layer_weight.idx_wq_b_.mm(q_lora).view(x.shape[0], self.index_n_heads, self.index_head_dim) rotary_emb_fwd(idx_q[..., -self.qk_rope_head_dim :], None, cos_tok, sin_tok) - idx_weight = layer_weight.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale - return idx_q, idx_weight - - def _c4_indices(self, infer_state: DeepseekV4InferStateInfo, idx_q, idx_weight, req_idx, positions): - """c4(CSA) extra indices: causal all-entries when the entry space fits index_topk, otherwise - Lightning-Indexer scored top-k. Pure tensor ops (decode runs inside cuda graphs).""" + idx_q = hadamard_transform(idx_q, scale=self.index_head_dim ** -0.5) + idx_q_fp8, q_scale = act_quant(idx_q, self.index_head_dim, None) # fp8 [T,H,d], scale [T,H,1] + weights = layer_weight.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale # [T, H] + weights = weights.unsqueeze(-1) * q_scale # fold per-token q scale + return idx_q_fp8, weights.squeeze(-1).contiguous() + + def _c4_indices(self, infer_state: DeepseekV4InferStateInfo, idx_q_fp8, weights, positions): + """c4 scorer via ds3.2-style gather + deep_gemm.fp8_mqa_logits. Gather each request's causal c4 + keys into a padded-per-request ragged fp8 buffer (k row r*c4_cap+e), score every query token + over its absolute [ks, ke) range, then masked topk-512 -> c4 slots. 
Fixed shapes (c4_cap pinned + per graph bucket) keep the decode cuda graph capturable.""" mem_manager = infer_state.mem_manager - raw_lengths = (positions + 1) // 4 - max_entries = max(1, int(infer_state.max_kv_seq_len) // 4) index_topk = self.index_topk - offsets = torch.arange(max_entries, dtype=torch.long, device=positions.device) - valid = offsets[None, :] < raw_lengths[:, None] - slots = self._gather_compress_slots(infer_state, mem_manager.full_to_c4_indexs, req_idx, valid, offsets, 4) + max_entries = max(1, int(infer_state.max_kv_seq_len) // 4) + c4_cap = ((max_entries + 63) // 64) * 64 + # entry space fits the budget -> every causal entry is selected; no scoring needed. The + # captured decode graph (graph_max_len -> max_entries > topk) always takes the scorer branch + # below, so this only shortcuts tiny eager contexts. if max_entries <= index_topk: - lengths = torch.clamp(raw_lengths, min=1).to(torch.int32) - indices = torch.where(valid, slots, torch.full_like(slots, -1)) - return _pad_last_dim(indices).unsqueeze(1), lengths.contiguous() - - score_scale = float(self.indexer_score_scale) - hold_slot = mem_manager.c4_indexer_pool.HOLD_TOKEN_MEMINDEX - safe_slots = torch.where(valid, slots.long(), torch.full_like(slots.long(), hold_slot)) - k = mem_manager.gather_indexer_k(self.layer_idx_, safe_slots.reshape(-1)).view( - positions.shape[0], max_entries, -1 + from ..triton_kernel.build_compress_index_dsv4 import build_compress_index + + slots, lengths = build_compress_index( + infer_state.dsv4_sparse_req_idx, + positions, + infer_state.req_manager.req_to_token_indexs, + mem_manager.full_to_c4_indexs, + ratio=4, + cap=c4_cap, + ) + return slots.unsqueeze(1), lengths + + import deep_gemm + from ..triton_kernel.gather_c4_indexer_k_dsv4 import gather_c4_indexer_k_ragged + + b_req_idx = infer_state.b_req_idx + batch = b_req_idx.shape[0] + device = positions.device + c4_len = torch.div(infer_state.b_seq_len, 4, rounding_mode="floor").to(torch.int32) # entries/req + 
k_fp8, k_scale, ragged_slots = gather_c4_indexer_k_ragged( + mem_manager, + self.layer_idx_, + b_req_idx, + c4_len, + c4_cap, + infer_state.req_manager.req_to_token_indexs, ) - num_tokens, num_heads = idx_q.shape[0], idx_q.shape[1] - score_chunks = [] - chunk = max(1, min(num_tokens, (16 * 1024 * 1024) // max(1, num_heads * max_entries))) - for start in range(0, num_tokens, chunk): - end = min(num_tokens, start + chunk) - scores = torch.einsum("thd,tnd->thn", idx_q[start:end].float(), k[start:end].float()) - scores = F.relu(scores) * score_scale - score_chunks.append((scores * idx_weight[start:end].unsqueeze(-1)).sum(dim=1)) - index_scores = torch.cat(score_chunks, dim=0) - if self.tp_world_size_ > 1: - all_reduce(index_scores, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - index_scores = index_scores.masked_fill(~valid, float("-inf")) - top = index_scores.topk(index_topk, dim=-1).indices - top_valid = torch.gather(valid, 1, top) - top_slots = torch.gather(slots.long(), 1, top).to(torch.int32) - indices = torch.where(top_valid, top_slots, torch.full_like(top_slots, -1)) - lengths = torch.clamp(torch.minimum(raw_lengths, torch.full_like(raw_lengths, index_topk)), min=1) - return _pad_last_dim(indices).unsqueeze(1), lengths.to(torch.int32).contiguous() + # batch position of each query token -> absolute [ks, ke) into the padded buffer. + if infer_state.is_prefill: + token_batch_pos = torch.repeat_interleave( + torch.arange(batch, device=device, dtype=torch.int32), infer_state.b_q_seq_len + ) + else: + token_batch_pos = torch.arange(batch, device=device, dtype=torch.int32) + valid_len = ((positions + 1) // 4).to(torch.int32) # causal candidate count per query + ks = token_batch_pos * c4_cap + ke = ks + valid_len + logits = deep_gemm.fp8_mqa_logits( + idx_q_fp8, (k_fp8, k_scale), weights, ks, ke, clean_logits=False, max_seqlen_k=c4_cap + ) # [T, c4_cap] f32, left-aligned: logits[t, j] = q_t . 
k[ks[t]+j] + col = torch.arange(c4_cap, device=device) + logits = logits.masked_fill(col.unsqueeze(0) >= valid_len.unsqueeze(1), float("-inf")) + top = logits.topk(index_topk, dim=-1).indices.to(torch.int32) # relative positions in [0, valid_len) + abs_idx = top + ks.unsqueeze(1) # absolute compact row + top_slots = ragged_slots[abs_idx.long()] # compact row -> c4 pool slot + invalid = top >= valid_len.unsqueeze(1) # topk over -inf padding when valid_len < index_topk + top_slots = torch.where(invalid, torch.full_like(top_slots, -1), top_slots) + topk_lengths = torch.clamp(torch.minimum(valid_len, torch.full_like(valid_len, index_topk)), min=1) + return top_slots.unsqueeze(1), topk_lengths.contiguous() diff --git a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py index d42b20f6e5..581b2b1d96 100644 --- a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py @@ -47,7 +47,6 @@ def _parse_config(self): self.is_hash = self.layer_num_ < self.num_hash_layers assert self.n_heads % self.tp_world_size_ == 0 assert self.o_groups % self.tp_world_size_ == 0 - assert self.index_n_heads % self.tp_world_size_ == 0 self.prefix = f"layers.{self.layer_num_}" def _init_weight(self): @@ -139,13 +138,18 @@ def _init_compressor(self): def _init_indexer(self): p = f"{self.prefix}.attn.indexer" - # wq_b is FP8 in the checkpoint -> de-quantized to bf16 at load; column-parallel over index heads. + # The Lightning-Indexer is REPLICATED across TP ranks (like sglang/vllm), not head-sharded: + # q_lora and the attn input are already full on every rank, so each rank scores all + # index_n_heads locally and the c4 top-k is identical everywhere -- no gather/all_reduce. + # wq_b is FP8 in the checkpoint -> de-quantized to bf16 at load. 
self.idx_wq_b_ = ROWMMWeight( in_dim=self.q_lora_rank, out_dims=[self.index_n_heads * self.index_head_dim], weight_names=f"{p}.wq_b.weight", data_type=self.data_type_, quant_method=self.get_quant_method("idx_wq_b"), + tp_rank=0, + tp_world_size=1, ) self.idx_weights_proj_ = ROWMMWeight( in_dim=self.hidden, @@ -153,6 +157,8 @@ def _init_indexer(self): weight_names=f"{p}.weights_proj.weight", data_type=self.data_type_, quant_method=None, + tp_rank=0, + tp_world_size=1, ) coff = 2 # indexer compressor always uses ratio 4 (overlap) self.idx_cmp_wkv_ = ROWMMWeight( diff --git a/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py new file mode 100644 index 0000000000..b09192498d --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py @@ -0,0 +1,78 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _build_compress_index_kernel( + req_idx_ptr, + pos_ptr, + req_to_token_ptr, + req_to_token_stride0, + full_to_c_ptr, + index_ptr, + length_ptr, + cap, + RATIO: tl.constexpr, + BLOCK_E: tl.constexpr, +): + t = tl.program_id(0) + eb = tl.program_id(1) + req = tl.load(req_idx_ptr + t).to(tl.int64) + pos = tl.load(pos_ptr + t).to(tl.int64) + raw_len = (pos + 1) // RATIO + + e = eb * BLOCK_E + tl.arange(0, BLOCK_E) + e_mask = e < cap + valid = (e < raw_len) & e_mask + # group-end token of compressed entry e: position e*RATIO + (RATIO-1). 
+ end_pos = e * RATIO + (RATIO - 1) + safe_pos = tl.where(valid, end_pos, 0) + full_slot = tl.load(req_to_token_ptr + req * req_to_token_stride0 + safe_pos, mask=valid, other=0).to(tl.int64) + c_slot = tl.load(full_to_c_ptr + full_slot, mask=valid, other=-1).to(tl.int32) + tl.store(index_ptr + t * cap + e, c_slot, mask=e_mask) + + if eb == 0: + tl.store(length_ptr + t, tl.maximum(raw_len, 1).to(tl.int32)) + + +def build_compress_index( + req_idx: torch.Tensor, + positions: torch.Tensor, + req_to_token_indexs: torch.Tensor, + full_to_c_indexs: torch.Tensor, + ratio: int, + cap: int, +): + """Fused two-level group-end gather for the c4/c128 compressed-entry index tables. + + For token t (at request `req_idx[t]`, absolute `positions[t]`) and compressed entry e: + slot[t, e] = full_to_c[ req_to_token[req, e*ratio + (ratio-1)] ] (the group-end token's full slot) + with slot = -1 where e >= (pos+1)//ratio (beyond the causal compressed length) or where the + full->c map is unset. Returns (index [T, cap] int32, length [T] int32 = clamp((pos+1)//ratio, 1)). + + Replaces the eager _gather_compress_slots/_c128/c4-causal torch chain. `cap` must be a multiple of + 64 (FlashMLA topk alignment); the tiled grid (T, ceil(cap/BLOCK_E)) scales to 1M-context caps. + cuda-graph-safe: cap is fixed per graph bucket, shapes static. 
+ """ + T = positions.shape[0] + index = torch.empty((T, cap), dtype=torch.int32, device=positions.device) + length = torch.empty((T,), dtype=torch.int32, device=positions.device) + if T == 0: + return index, length + BLOCK_E = 256 + grid = (T, triton.cdiv(cap, BLOCK_E)) + _build_compress_index_kernel[grid]( + req_idx, + positions, + req_to_token_indexs, + req_to_token_indexs.stride(0), + full_to_c_indexs, + index, + length, + cap, + RATIO=ratio, + BLOCK_E=BLOCK_E, + num_warps=4, + ) + return index, length diff --git a/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py new file mode 100644 index 0000000000..6adac4fdd0 --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py @@ -0,0 +1,106 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _gather_c4_indexer_k_kernel( + req_idx_ptr, # [batch] int — req_manager slot per batch position + c4_len_ptr, # [batch] int — number of causal c4 entries per request (= seq_len // ratio) + req_to_token_ptr, + req_to_token_stride0, + full_to_c4_ptr, + SlabFp8_ptr, # c4 indexer pool, viewed as fp8 (flat) + SlabF32_ptr, # same pool, viewed as f32 (flat) + Kout_fp8_ptr, # [batch*c4_cap, HEAD_DIM] fp8 + Kout_scale_ptr, # [batch*c4_cap] f32 + Slots_out_ptr, # [batch*c4_cap] int32 (compact->c4-slot map; -1 for padding) + c4_cap, + RATIO: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BYTES_PER_PAGE: tl.constexpr, + SCALE_OFFSET: tl.constexpr, # page_size * head_dim (byte offset of the scale tail) +): + r = tl.program_id(0) + e = tl.program_id(1) + out_pos = r * c4_cap + e + c4_len = tl.load(c4_len_ptr + r) + if e >= c4_len: + # padding entry: mark slot invalid; K is never read (ke bounds the scorer range). + tl.store(Slots_out_ptr + out_pos, -1) + return + + # group-end token of compressed entry e lives at position e*RATIO + (RATIO-1). 
+ req = tl.load(req_idx_ptr + r).to(tl.int64) + end_tok = e * RATIO + (RATIO - 1) + full_slot = tl.load(req_to_token_ptr + req * req_to_token_stride0 + end_tok).to(tl.int64) + c4_slot = tl.load(full_to_c4_ptr + full_slot).to(tl.int64) + valid = c4_slot >= 0 + + # inline PackedPagePool byte addressing (matches destindex_copy_indexer_k_dsv4 / gather_indexer_k): + # fp8 K at page*bytes_per_page + tok*head_dim; fp32 scale at (page*bytes_per_page + scale_off)//4 + tok. + page = c4_slot // PAGE_SIZE + tok = c4_slot % PAGE_SIZE + data_base = page * BYTES_PER_PAGE + tok * HEAD_DIM + scale_base = (page * BYTES_PER_PAGE + SCALE_OFFSET) // 4 + tok + + offs_d = tl.arange(0, HEAD_DIM) + k_fp8 = tl.load(SlabFp8_ptr + data_base + offs_d, mask=valid, other=0.0) + k_scale = tl.load(SlabF32_ptr + scale_base, mask=valid, other=0.0) + tl.store(Kout_fp8_ptr + out_pos * HEAD_DIM + offs_d, k_fp8) + tl.store(Kout_scale_ptr + out_pos, k_scale) + tl.store(Slots_out_ptr + out_pos, tl.where(valid, c4_slot, -1).to(tl.int32)) + + +@torch.no_grad() +def gather_c4_indexer_k_ragged( + mem_manager, + layer_index: int, + b_req_idx: torch.Tensor, + c4_len: torch.Tensor, + c4_cap: int, + req_to_token_indexs: torch.Tensor, +): + """Gather each request's causal c4 indexer keys into a padded-per-request ragged buffer for the + deep_gemm fp8_mqa_logits scorer (mirrors deepseek3_2's extract_indexer_ks, but reads our + PackedPagePool by c4 slot instead of a token-indexed [N,1,132] buffer). + + For batch position r and compressed entry e in [0, c4_len[r]): + c4_slot = full_to_c4[req_to_token[b_req_idx[r], e*ratio + (ratio-1)]] + The raw fp8 key + f32 scale at that slot land at row r*c4_cap + e of the output (so query token t + of request r reads keys [r*c4_cap, r*c4_cap + (pos+1)//ratio) -- absolute ks/ke offsets the caller + builds). Returns (k_fp8 [batch*c4_cap, HEAD_DIM] fp8, k_scale [batch*c4_cap] f32, slots + [batch*c4_cap] int32 = compact-row -> c4 pool slot, -1 for padding). 
Fixed shapes -> cuda-graph + safe (c4_cap is pinned per graph bucket); the padding region is never read by the scorer. + """ + pool = mem_manager.c4_indexer_pool + head_dim = mem_manager.indexer_head_dim + buf = pool.get_layer_buffer(mem_manager.layer_to_c4_idx[layer_index]).view(-1) + slab_fp8 = buf.view(torch.float8_e4m3fn) + slab_f32 = buf.view(torch.float32) + batch = b_req_idx.shape[0] + n = batch * c4_cap + k_fp8 = torch.empty((n, head_dim), dtype=torch.float8_e4m3fn, device=buf.device) + k_scale = torch.empty((n,), dtype=torch.float32, device=buf.device) + slots = torch.empty((n,), dtype=torch.int32, device=buf.device) + _gather_c4_indexer_k_kernel[(batch, c4_cap)]( + b_req_idx, + c4_len, + req_to_token_indexs, + req_to_token_indexs.stride(0), + mem_manager.full_to_c4_indexs, + slab_fp8, + slab_f32, + k_fp8, + k_scale, + slots, + c4_cap, + RATIO=4, + HEAD_DIM=head_dim, + PAGE_SIZE=pool.page_size, + BYTES_PER_PAGE=pool.bytes_per_page, + SCALE_OFFSET=pool.scale_offset_in_page, + num_warps=1, + ) + return k_fp8, k_scale, slots From 62c16d5675787b308395724ca1fea5689665f622 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 15 Jun 2026 13:27:49 +0000 Subject: [PATCH 21/30] opt --- lightllm/common/basemodel/basemodel.py | 8 +----- .../layer_infer/transformer_layer_infer.py | 10 +++---- .../layer_weights/transformer_layer_weight.py | 26 +++++++++---------- lightllm/models/deepseek_v4/model.py | 5 ---- .../triton_kernel/gather_c4_indexer_k_dsv4.py | 11 +++++--- 5 files changed, 24 insertions(+), 36 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 30248d6a21..b2782d3ed2 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -640,13 +640,7 @@ def prefill_func(input_tensors, infer_state): handle_token_num = infer_state.input_ids.shape[0] - can_run_prefill_graph = self.prefill_graph is not None and self.prefill_graph.can_run( - 
handle_token_num=handle_token_num - ) - if can_run_prefill_graph and hasattr(self, "_can_run_prefill_cudagraph"): - can_run_prefill_graph = self._can_run_prefill_cudagraph(infer_state, handle_token_num) - - if can_run_prefill_graph: + if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num): finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num( handle_token_num=handle_token_num ) diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index ef898c31bc..6ebc0b2856 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F import torch.distributed as dist from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.common.basemodel.attention.base_att import AttControl @@ -306,14 +305,13 @@ def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV if not self.enable_ep_moe: x = self._tpsp_allgather(input=x, infer_state=infer_state) - gw = layer_weight.gate_weight_.mm_param.weight - logits = F.linear(x.float(), gw.float()).contiguous() + logits = layer_weight.gate_weight_.mm(x.float()).contiguous() weights, indices = self._select_experts(logits, infer_state, layer_weight) # shared expert 必须先于 routed 计算: fp8 路径 (FuseMoeTriton) 的 fused_experts # 是 inplace 的,_routed_experts 返回后 x 已被覆盖为 routed 输出。 - g = layer_weight.shared_gate_.mm(x).float().clamp(max=self.swiglu_limit) - u = layer_weight.shared_up_.mm(x).float().clamp(min=-self.swiglu_limit, max=self.swiglu_limit) - shared = layer_weight.shared_down_.mm((F.silu(g) * u).to(x.dtype)) + # 复用 Llama 的 _ffn_tp: fused gate_up matmul + silu_and_mul triton kernel,无 swiglu clamp, + # 对齐参考 DeepseekV4MLP(=LlamaMLP)。swiglu_limit clamp 只属于 routed 专家 (见 _routed_experts)。 + shared = 
self._ffn_tp(input=x, infer_state=infer_state, layer_weight=layer_weight) routed = self._routed_experts(x, weights, indices, layer_weight) if self.enable_ep_moe: if self.tp_world_size_ > 1: diff --git a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py index 581b2b1d96..5896027b38 100644 --- a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py @@ -189,12 +189,14 @@ def _init_indexer(self): # ------------------------------------------------------------------ moe def _init_moe(self): p = f"{self.prefix}.ffn" - # router gate (replicated) + # router gate (replicated). Stored as fp32: the topk_hash_softplus_sqrt router wants fp32 logits, + # so keep the gate matmul in fp32 — but store the (constant) weight as fp32 once here instead of + # re-casting it to fp32 on every forward in _ffn. self.gate_weight_ = ROWMMWeight( in_dim=self.hidden, out_dims=[self.n_routed_experts], weight_names=f"{p}.gate.weight", - data_type=self.data_type_, + data_type=torch.float32, quant_method=None, tp_rank=0, tp_world_size=1, @@ -209,23 +211,19 @@ def _init_moe(self): self.gate_bias_ = ParameterWeight( weight_name=f"{p}.gate.bias", data_type=torch.float32, weight_shape=(self.n_routed_experts,) ) - # shared expert (dense, bf16 after de-quant): w1=gate, w3=up (row), w2=down (col) + # shared expert (dense, bf16 after de-quant): w1=gate, w3=up fused (row), w2=down (col). + # Named gate_up_proj/down_proj so the inherited Llama `_ffn_tp` (fused gate_up matmul + + # silu_and_mul triton kernel, no swiglu clamp) drives it directly. Order [w1, w3] = [gate, up] + # matches silu_and_mul_fwd's blocked layout (first half gate, second half up). 
sp = f"{p}.shared_experts" - self.shared_gate_ = ROWMMWeight( + self.gate_up_proj = ROWMMWeight( in_dim=self.hidden, - out_dims=[self.moe_inter], - weight_names=f"{sp}.w1.weight", + out_dims=[self.moe_inter, self.moe_inter], + weight_names=[f"{sp}.w1.weight", f"{sp}.w3.weight"], data_type=self.data_type_, quant_method=self.get_quant_method("shared_gate"), ) - self.shared_up_ = ROWMMWeight( - in_dim=self.hidden, - out_dims=[self.moe_inter], - weight_names=f"{sp}.w3.weight", - data_type=self.data_type_, - quant_method=self.get_quant_method("shared_up"), - ) - self.shared_down_ = COLMMWeight( + self.down_proj = COLMMWeight( in_dim=self.moe_inter, out_dims=[self.hidden], weight_names=f"{sp}.w2.weight", diff --git a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py index 63430e548b..f9d341e395 100644 --- a/lightllm/models/deepseek_v4/model.py +++ b/lightllm/models/deepseek_v4/model.py @@ -116,11 +116,6 @@ def _init_cudagraph(self): self.graph_max_len_in_batch = DSV4_DECODE_CUDAGRAPH_MAX_LEN return super()._init_cudagraph() - def _can_run_prefill_cudagraph(self, infer_state: DeepseekV4InferStateInfo, handle_token_num): - if infer_state.prefix_total_token_num == 0: - return True - return False - def _init_att_backend(self): args = get_env_start_args() if args.llm_kv_type == "None": diff --git a/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py index 6adac4fdd0..b6e6ba751a 100644 --- a/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py +++ b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py @@ -22,9 +22,12 @@ def _gather_c4_indexer_k_kernel( BYTES_PER_PAGE: tl.constexpr, SCALE_OFFSET: tl.constexpr, # page_size * head_dim (byte offset of the scale tail) ): - r = tl.program_id(0) - e = tl.program_id(1) - out_pos = r * c4_cap + e + # entry index on grid-X (limit ~2^31), batch on grid-Y (<= running_max_req_size): c4_cap 
reaches + # 65536 at 256K context, which would blow the 65535 grid-Y cap if entries were the grid-Y axis. + e = tl.program_id(0) + r = tl.program_id(1) + # int64: out_pos*HEAD_DIM can exceed int32 at high batch + long context (read side is already int64). + out_pos = r.to(tl.int64) * c4_cap + e c4_len = tl.load(c4_len_ptr + r) if e >= c4_len: # padding entry: mark slot invalid; K is never read (ke bounds the scorer range). @@ -84,7 +87,7 @@ def gather_c4_indexer_k_ragged( k_fp8 = torch.empty((n, head_dim), dtype=torch.float8_e4m3fn, device=buf.device) k_scale = torch.empty((n,), dtype=torch.float32, device=buf.device) slots = torch.empty((n,), dtype=torch.int32, device=buf.device) - _gather_c4_indexer_k_kernel[(batch, c4_cap)]( + _gather_c4_indexer_k_kernel[(c4_cap, batch)]( b_req_idx, c4_len, req_to_token_indexs, From 69824d0059bd0dc9c355dbcac9b79ba715d0d932 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 15 Jun 2026 13:34:37 +0000 Subject: [PATCH 22/30] delete launch.sh --- launch.sh | 53 ----------------------------------------------------- 1 file changed, 53 deletions(-) delete mode 100644 launch.sh diff --git a/launch.sh b/launch.sh deleted file mode 100644 index b9c10d3f0a..0000000000 --- a/launch.sh +++ /dev/null @@ -1,53 +0,0 @@ -# DeepSeek-V4-Flash serving (run inside the lightllm container, repo mounted at /data/wanzihao/lightllm-ds4). -# Verified 2026-06-11: smoke + gsm8k pass with this configuration (prompt cache ENABLED, decode -# cudagraph ENABLED; gsm8k 100q/128: cold 0.960/112s, warm 0.970/23.5s with 100% cache hits — -# vs eager cold 0.970/141s, warm 0.960/50s; batch-1 decode 20.4ms/token vs 142ms eager). -# -# Required env/flags and why: -# LOADWORKER=16 - parallel weight loading (~5x faster startup). -# PYTHONPATH sglang - _get_qkv / compressor reuse sglang.jit_kernel.dsv4 (fused_q_norm_rope, compress_old). 
-# --batch_max_tokens 8192 - FlashMLA get_decoding_sched_meta rejects >8192 rows per call (probed: 8192 OK, 12288 fails). -# kv pool sizing: auto-profiled from mem_fraction. The fp4 marlin MoE weights materialize their -# CUDA marlin-layout buffers at construction (MXFP4MoEQuantizationMethod._create_weight), so the -# profile sees the true weight footprint on any GPU/config. --max_total_token_num overrides. -# decode cudagraph ENABLED - the v5 decode path is graph-safe: slot alloc/scatter in prep (outside -# graph), forward is pure gathers, HOLD padding rows redirect to HOLD slots. CORRECTNESS NOTE: -# FlashMLASchedMeta is lazily planned at first kernel call and written back onto the (shared) -# decode att state; the capture warmup pass would bake a dummy-content plan into the graph -# (gsm8k dropped to 0.74 with coherent-but-runaway generations). reset_sched_meta_for_capture() -# in cuda_graph._capture_decode re-plans INSIDE the captured region so every replay re-plans. -# DSV4 caps graph max_len_in_batch at 8192; longer decode batches fall back to eager. -# --enable_prefill_cudagraph + --prefill_cudagraph_max_handle_token 2048 - graph-sandwich prefill: -# graphs capture only the per-token dense ops; attention/compressor/indexer run eagerly between -# graph segments (att_func), so host-side planning and .tolist() prep never enter capture. Only -# cold prefills (prefix_total_token_num == 0, model gate) of <= 2048 new tokens replay; cache-hit -# and large batched prefills stay eager. Buckets are padded with a HOLD tail request whose -# attention output MUST be zeroed (infer_struct._dsv4_prefill_pad_q_len): pad rows read the -# racing HOLD slot, and nondeterministic pad hiddens perturb real rows via MoE expert batching -# (ulp-level, chaotically amplified ~1.9x/layer to O(1) by layer ~16 -> greedy token flips). 
-# Residual caveat: padded-vs-unpadded expert-batch composition still shifts reductions by ulps, -# same class as decode bucket padding; run-to-run determinism is anyway bounded by the fp4 -# marlin MoE kernel itself (probabilistic 1-ulp reduction-order noise measured eager-vs-eager). -# Acceptance is therefore statistical (gsm8k parity), not bitwise. -# --disable_flashinfer_allreduce - flashinfer cuda_ipc resolves libcudart to tilelang's stub (undefined cudaDeviceReset); symm-mem allreduce is used instead. -# -# One-time container setup already applied (survives until container rebuild): -# pip install ipython (sglang import dependency) -# site-packages/vllm: layers/mhc.py + kernels/mhc/ + _tilelang_ops.py overlaid from /data/wanzihao/vllm (mhc_pre_tilelang ops; original kept at layers/mhc.py.bak) -# -# original: python -m lightllm.server.api_server --model_dir /data/models/DeepSeek-V4-Flash --tp 4 --enable_prefill_cudagraph - -# repo root = this script's directory, so the same file works in the main tree and in worktrees -# (a hardcoded tree path here once made a worktree launch silently serve main-tree code). 
-REPO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -LOADWORKER=16 \ -PYTHONPATH="${REPO_DIR}":/data/wanzihao/sglang/python \ -python -m lightllm.server.api_server \ - --model_dir /data/models/DeepSeek-V4-Flash \ - --tp 4 \ - --batch_max_tokens 8192 \ - --disable_flashinfer_allreduce \ - --enable_prefill_cudagraph \ - --prefill_cudagraph_max_handle_token 2048 \ - --port 8000 From df70ecbf6a09bde7be6fccd9e0467794828a9c5b Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 15 Jun 2026 14:15:11 +0000 Subject: [PATCH 23/30] fix --- lightllm/models/deepseek_v4/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py index f9d341e395..6b5d5f9e38 100644 --- a/lightllm/models/deepseek_v4/model.py +++ b/lightllm/models/deepseek_v4/model.py @@ -148,8 +148,11 @@ def _init_to_get_rotary(self): cfg = self.config rs = cfg.get("rope_scaling", {}) or {} dim = cfg["qk_rope_head_dim"] + # The rope tables MUST span every absolute position any request can produce (the served + # max_req_total_len / max_position_embeddings, up to 1M). Capping them shorter makes + # init_some_extra_state's index_select(cos/sin, position_ids) read OOB past the table at + # contexts beyond the cap (device-side assert / crash). ~268MB total at 1M, fp32x32 x4 views. 
max_seq = max(int(self.max_seq_length), int(cfg.get("max_position_embeddings", 8192))) - max_seq = min(max_seq, 1 << 18) # cap table size (256K) for correctness-first freq_exponents = torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim positions = torch.arange(max_seq, dtype=torch.float32, device="cuda") From 1ad981d06809631c1cd91fa8780440522e0894d1 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Tue, 16 Jun 2026 01:01:55 +0000 Subject: [PATCH 24/30] restore --- lightllm/__init__.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/lightllm/__init__.py b/lightllm/__init__.py index 8e515afb70..e9ba6f3041 100644 --- a/lightllm/__init__.py +++ b/lightllm/__init__.py @@ -1,31 +1,4 @@ from lightllm.utils.device_utils import is_musa - -def _patch_mp_resource_tracker_for_semaphore(): - from multiprocessing import resource_tracker - - if getattr(resource_tracker, "_lightllm_ignore_semaphore", False): - return - - orig_register = resource_tracker.register - orig_unregister = resource_tracker.unregister - - def register(name, rtype): - if rtype == "semaphore": - return - return orig_register(name, rtype) - - def unregister(name, rtype): - if rtype == "semaphore": - return - return orig_unregister(name, rtype) - - resource_tracker.register = register - resource_tracker.unregister = unregister - resource_tracker._lightllm_ignore_semaphore = True - - -_patch_mp_resource_tracker_for_semaphore() - if is_musa(): import torchada # noqa: F401 From 7b17bb554d2ac324a7796132c78b270675398b33 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Tue, 16 Jun 2026 01:47:41 +0000 Subject: [PATCH 25/30] support parser --- lightllm/models/deepseek_v4/model.py | 5 ++++ lightllm/server/api_cli.py | 1 + lightllm/server/build_prompt.py | 7 +++++ lightllm/server/function_call_parser.py | 37 +++++++++++++++++++++++-- lightllm/utils/config_utils.py | 8 ++++-- 5 files changed, 53 insertions(+), 5 deletions(-) diff --git 
a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py index 6b5d5f9e38..887e72f433 100644 --- a/lightllm/models/deepseek_v4/model.py +++ b/lightllm/models/deepseek_v4/model.py @@ -191,6 +191,11 @@ def _init_to_get_rotary(self): class DeepSeekV4Tokenizer: """Tokenizer wrapper for DeepSeek-V4's Python prompt encoding.""" + # DeepSeek-V4 has a per-request thinking mode (<think>...</think>) toggled via + # chat_template_kwargs={"thinking": true}. It has no Jinja chat_template string, + # so advertise thinking support explicitly for tokenizer_supports_force_thinking(). + supports_thinking = True + def __init__(self, tokenizer, model_dir): self.tokenizer = tokenizer self.model_dir = model_dir diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 70d5c72ac3..7a9ddb0968 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -169,6 +169,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "qwen", "deepseekv31", "deepseekv32", + "deepseekv4", "glm47", "kimi_k2", "qwen3_coder", diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 96b4d040bb..be6042fa56 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -53,6 +53,13 @@ def tokenizer_supports_force_thinking() -> bool: assert tokenizer is not None + # Tokenizers that encode prompts in Python (e.g. DeepSeek-V4) have no Jinja + # chat_template string to inspect, so advertise thinking support via an + # explicit attribute instead.
+ if getattr(tokenizer, "supports_thinking", False): + logger.info("tokenizer_supports_force_thinking : True (explicit attribute)") + return True + try: ans = "thinking" in tokenizer.chat_template or "enable_thinking" in tokenizer.chat_template logger.debug(f"chat_template: {tokenizer.chat_template}") diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index f204c154ed..9213f4c7d4 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -40,6 +40,7 @@ "[TOOL_CALLS]", "<|tool▁calls▁begin|>", "<|DSML|function_calls>", + "<|DSML|tool_calls>", ] @@ -1480,11 +1481,14 @@ class DeepSeekV32Detector(BaseFormatDetector): Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3.2 """ - def __init__(self): + def __init__(self, block_name: str = "function_calls"): super().__init__() self.dsml_token = "|DSML|" - self.bot_token = f"<{self.dsml_token}function_calls>" - self.eot_token = f"" + # DeepSeek V3.2 wraps tool calls in a `function_calls` block; V4 uses + # `tool_calls`. Only the outer block name differs — the invoke/parameter + # grammar is identical — so subclasses just override block_name. + self.bot_token = f"<{self.dsml_token}{block_name}>" + self.eot_token = f"" self.invoke_start_prefix = f"<{self.dsml_token}invoke" self.invoke_end_token = f"" self.param_end_token = f"" @@ -1962,6 +1966,32 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip() +class DeepSeekV4Detector(DeepSeekV32Detector): + """ + Detector for DeepSeek V4 model function call format using DSML. 
+ + Identical grammar to V3.2 (``<|DSML|invoke name="...">`` blocks with + ``<|DSML|parameter name="k" string="true|false">v`` + tags), except the outer block is named ``tool_calls`` instead of + ``function_calls`` — matching the model's own encoding (encoding_dsv4.py: + ``tool_calls_block_name = "tool_calls"``) and system prompt. + + Format Structure: + ``` + <|DSML|tool_calls> + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">Hangzhou + + + ``` + + Reference: https://huggingface.co/deepseek-ai/DeepSeek-V4 + """ + + def __init__(self): + super().__init__(block_name="tool_calls") + + class FunctionCallParser: """ Parser for function/tool calls in model outputs. @@ -1975,6 +2005,7 @@ class FunctionCallParser: "deepseekv3": DeepSeekV3Detector, "deepseekv31": DeepSeekV31Detector, "deepseekv32": DeepSeekV32Detector, + "deepseekv4": DeepSeekV4Detector, "glm47": Glm47Detector, "kimi_k2": KimiK2Detector, "llama3": Llama32Detector, diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index 21df2130e0..dcaf7315dd 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -444,6 +444,10 @@ def get_tool_call_parser_for_model(model_path: str) -> Optional[str]: if model_type == "deepseek_v32": return "deepseekv32" + # DeepSeek V4 + if model_type == "deepseek_v4": + return "deepseekv4" + return None @@ -468,8 +472,8 @@ def get_reasoning_parser_for_model(model_path: str) -> Optional[str]: ]: return "qwen3" - # DeepSeek V3 - if model_type in ["deepseek_v3", "deepseek_v31", "deepseek_v32"]: + # DeepSeek V3 / V4 (share the ... 
reasoning format, request-gated) + if model_type in ["deepseek_v3", "deepseek_v31", "deepseek_v32", "deepseek_v4"]: return "deepseek-v3" # DeepSeek R1 From 6837abd33ff5f3ecf105a465320fb0fb7ff48f9f Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Tue, 16 Jun 2026 04:05:42 +0000 Subject: [PATCH 26/30] fix --- .../basemodel/attention/nsa/fp8_flashmla_sparse.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py index 03595c1bce..dc18ecf4ba 100644 --- a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py @@ -25,18 +25,6 @@ def _pad_q_heads(q_4d: torch.Tensor, attn_sink: torch.Tensor): return q_pad, sink_pad, h_q -class DeepseekV4MissingOperatorError(RuntimeError): - pass - - -def _missing_attention_op(feature: str) -> None: - raise DeepseekV4MissingOperatorError( - f"DeepSeek-V4 {feature} has no production batch operator. The flashmla_kvcache path " - f"(packed swa/c4/c128 pools + paged compressor + indexer top-k) is the supported route; " - f"this legacy/non-flashmla entry point was never wired and is fenced on purpose." 
- ) - - def _view_dsv4_flashmla_cache(layer_buffer: torch.Tensor, page_size: int) -> torch.Tensor: from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_MLA_BYTES_PER_TOKEN From 02a24ce306bfa31350dd2117f66715e67e8dbde9 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 18 Jun 2026 01:49:41 +0000 Subject: [PATCH 27/30] add c4 paged indexes --- .../deepseek4_mem_manager.py | 55 ++++-- lightllm/common/req_manager.py | 166 ++++++++++++++++-- .../layer_infer/transformer_layer_infer.py | 89 +++++++++- .../triton_kernel/gather_c4_indexer_k_dsv4.py | 91 ++++++++++ .../server/router/model_infer/infer_batch.py | 3 +- 5 files changed, 373 insertions(+), 31 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py index 32561aa6f7..cfb149dcec 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py @@ -31,6 +31,7 @@ DSV4_SWA_PAGE_SIZE = 128 DSV4_C4_PAGE_SIZE = 64 DSV4_C128_PAGE_SIZE = 2 +DSV4_PROMPT_CACHE_PAGE_SIZE = DSV4_C4_PAGE_SIZE * 4 # compressor state ring: c4 overlap 对为每页 2 个分组槽 × ratio 4 行;c128 离线聚合为每页 1 组。 DSV4_C4_STATE_RING = 8 DSV4_C128_STATE_RING = 128 @@ -194,11 +195,12 @@ def __init__( # ------------------------------------------------------------------ sizing def _swa_per_req_budget(self) -> int: # 活跃请求保留 window + 一个 radix 页(req_manager._swa_retain_len: 让最近完成的 - # 128 边界的结尾页恒驻留,prompt cache 插入门才能放行),即 v5 §2 的「活跃窗口跨页 ≤2」。 - return int(self.sliding_window) + DSV4_SWA_PAGE_SIZE + # prompt-cache 边界的结尾页恒驻留)。V4 的 prompt-cache 边界取 256 token, + # 避免 radix 共享前缀落在 c4 物理页(64 c4 entry = 256 token)中间。 + return int(self.sliding_window) + DSV4_PROMPT_CACHE_PAGE_SIZE def _planned_swa_size(self, full_size: int) -> int: - # swa 池按页分配(页 = 128 = sliding_window = radix 页),容量向上取整到整页。 + # swa 池按页分配(页 = 128 = sliding_window),容量向上取整到整页。 if self.max_request_num is None 
or self.sliding_window is None: return _ceil_div(full_size, DSV4_SWA_PAGE_SIZE) * DSV4_SWA_PAGE_SIZE cap = int(self.max_request_num) * self._swa_per_req_budget() + self.swa_extra_token_num @@ -335,6 +337,8 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.c4_pool: Optional[PackedPagePool] = None self.c4_indexer_pool: Optional[PackedPagePool] = None self.c4_allocator: Optional[KvCacheAllocator] = None + self.c4_page_allocator: Optional[KvCacheAllocator] = None + self.c4_page_live_count: Optional[torch.Tensor] = None self.c128_pool: Optional[PackedPagePool] = None self.c128_allocator: Optional[KvCacheAllocator] = None self.c4_state_buffer: Optional[torch.Tensor] = None @@ -360,9 +364,12 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): data_bytes=self.indexer_head_dim, scale_bytes=DSV4_INDEXER_SCALE_BYTES, ) - self.c4_allocator = KvCacheAllocator( - self.c4_size, shared_name=f"{server}_dsv4_c4_can_use_token_num_{rank_in_node}" + self.c4_num_pages = self.c4_size // DSV4_C4_PAGE_SIZE + assert self.c4_num_pages > 0, "DeepSeek-V4 c4 pool must have at least one usable full page" + self.c4_page_allocator = KvCacheAllocator( + self.c4_num_pages, shared_name=f"{server}_dsv4_c4_can_use_page_num_{rank_in_node}" ) + self.c4_page_live_count = torch.zeros((self.c4_pool.num_pages,), dtype=torch.int32, device="cuda") self.full_to_c4_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda") self.full_to_c4_indexs[size] = self.c4_pool.HOLD_TOKEN_MEMINDEX # c4 compressor 在途状态(attention + indexer): swa 页派生寻址(翻译③),随 swa 页 @@ -588,11 +595,36 @@ def _evict_compress(self, full_slots: torch.Tensor, mapping: torch.Tensor, alloc mapping[full_slots[valid]] = -1 return + def alloc_c4_pages(self, need_pages: int) -> torch.Tensor: + assert self.c4_page_allocator is not None, "DeepSeek-V4 c4 page allocator is not initialized" + return self.c4_page_allocator.alloc(need_pages) + + def count_c4_slots(self, c4_slots: torch.Tensor, delta: 
int) -> torch.Tensor: + """按 c4 slot 所在页更新存活计数,返回触达的页(去重)。""" + assert self.c4_page_live_count is not None, "DeepSeek-V4 c4 page live count is not initialized" + pages = torch.div(c4_slots.long(), DSV4_C4_PAGE_SIZE, rounding_mode="floor") + ones = torch.full(pages.shape, delta, dtype=torch.int32, device=pages.device) + self.c4_page_live_count.index_add_(0, pages, ones) + return torch.unique(pages) + def evict_c4(self, full_slots: torch.Tensor) -> None: """回收 full 槽位(组末 token)映射的 c4 槽。非组末/未映射(-1)的槽位跳过。""" - if self.c4_allocator is None or full_slots.numel() == 0: + if self.c4_page_allocator is None or full_slots.numel() == 0: return - self._evict_compress(full_slots, self.full_to_c4_indexs, self.c4_allocator) + full_slots = full_slots.cuda().long().reshape(-1) + full_slots = torch.unique(full_slots[full_slots != self.HOLD_TOKEN_MEMINDEX]) + if full_slots.numel() == 0: + return + slots = self.full_to_c4_indexs[full_slots] + valid = slots >= 0 + valid_slots = slots[valid] + if valid_slots.numel() == 0: + return + self.full_to_c4_indexs[full_slots[valid]] = -1 + touched = self.count_c4_slots(valid_slots, -1) + empty = touched[self.c4_page_live_count[touched] == 0] + if empty.numel() > 0: + self.c4_page_allocator.free(empty.to(torch.int32)) return def evict_c128(self, full_slots: torch.Tensor) -> None: @@ -623,8 +655,9 @@ def free_all(self): self.swa_page_live_count.zero_() self.full_to_swa_indexs.fill_(-1) self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX - if self.c4_allocator is not None: - self.c4_allocator.free_all() + if self.c4_page_allocator is not None: + self.c4_page_allocator.free_all() + self.c4_page_live_count.zero_() self.full_to_c4_indexs.fill_(-1) self.full_to_c4_indexs[self.HOLD_TOKEN_MEMINDEX] = self.c4_pool.HOLD_TOKEN_MEMINDEX if self.c128_allocator is not None: @@ -634,13 +667,13 @@ def free_all(self): return def alloc_c4(self, need_size) -> torch.Tensor: - return self.c4_allocator.alloc(need_size) + raise 
AssertionError("DeepSeek-V4 c4 uses page-safe allocation; call alloc_c4_pages instead") def alloc_c128(self, need_size) -> torch.Tensor: return self.c128_allocator.alloc(need_size) def free_c4(self, free_index) -> None: - self.c4_allocator.free(free_index) + raise AssertionError("DeepSeek-V4 c4 uses page live-count release; call evict_c4 instead") def free_c128(self, free_index) -> None: self.c128_allocator.free(free_index) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 9ac1babe23..5e7c4f96dd 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -17,6 +17,10 @@ from lightllm.common.linear_att_cache_manager.linear_att_buffer_manager import ( LinearAttCacheManager, ) +from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import ( + DSV4_C4_PAGE_SIZE, + DSV4_PROMPT_CACHE_PAGE_SIZE, +) if TYPE_CHECKING: from lightllm.server.router.model_infer.infer_batch import InferReq @@ -30,10 +34,11 @@ class DeepseekV4PromptCachePayload: 槽位与 compressor 状态都不进载荷: full_to_swa/full_to_c4/full_to_c128 以 full token 槽位 为键(radix 持有 full 槽 ⇒ 映射行存活,free 级联回收);c4/c128 compressor 状态以 swa - 页派生寻址(随 swa 页生灭,命中零拷贝续算)。c128 partial state 不跨 radix 的 128 边界保存。 + 页派生寻址(随 swa 页生灭,命中零拷贝续算)。prompt cache 对齐到 256 token, + 避免共享前缀停在 c4 物理页中间。 * ``swa_page_valid``: cpu bool [cache_len // page],插入时按当下 full_to_swa 映射写定 - (页内 128 个映射全有效才为 True)。匹配层据此把命中裁剪到"结尾页有效"的 128 边界, + (页内 token 映射全有效才为 True)。匹配层据此把命中裁剪到"结尾页有效"的 page 边界, swa 压力阀回收节点页时清零。""" cache_len: int @@ -61,7 +66,7 @@ def invalidate_swa_pages(self, payload: DeepseekV4PromptCachePayload) -> None: return def valid_match_length(self, payload: Optional[DeepseekV4PromptCachePayload], natural_len: int) -> int: - """radix 匹配裁剪: 返回 <= natural_len 的最大 128 边界 L',使结尾页(bitmap[L'/128-1])有效。 + """radix 匹配裁剪: 返回 <= natural_len 的最大 prompt-cache 边界 L',使结尾页有效。 有效性可能非单调(owner 生前从左驱逐、后续阀从尾回收),按候选边界回查 bitmap; 中段 invalid 页不挡更靠后的有效命中(注意力只回看最后一个窗口)。""" @@ -441,10 +446,9 @@ def bind_mem_manager(self, 
mem_manager: DeepseekV4MemoryManager): def _swa_retain_len(self) -> int: """出窗回收的保留长度 = window + 一个 radix 页。 - 多留一页使「最近一个完成的 128 边界」的结尾页恒驻留: prompt cache 只能在 floor(cur/128) - 边界入树(radix page=128),若回收只留 window,则任何非对齐时刻该边界的结尾页都已被 - 部分回收,插入门会把所有插入裁到 0(prompt cache 形同虚设)。预算即 v5 §2 的每请求 - 「活跃窗口跨页 ≤2」。驻留证明要求 window >= page-1(DSV4 实际 window == page == 128)。""" + 多留一页使「最近一个完成的 prompt-cache 边界」的结尾页恒驻留: 若回收只留 window, + 则任何非对齐时刻该边界的结尾页都已被部分回收,插入门会把所有插入裁到 0。 + V4 prompt-cache 页取 256 token,正好覆盖一个 c4 物理页对应的 token 范围。""" return int(self.sliding_window) + self.get_prompt_cache_page_size() def prepare_prefill_swa( @@ -535,11 +539,133 @@ def init_compress_state(self, req_idx: int): def _compress_mapping_alloc(self, ratio: int): assert self.mem_manager is not None, "DeepSeek-V4 mem manager is not bound yet" if ratio == 4: - return self.mem_manager.full_to_c4_indexs, self.mem_manager.alloc_c4 + raise AssertionError("DeepSeek-V4 c4 uses page-safe allocation") if ratio == 128: return self.mem_manager.full_to_c128_indexs, self.mem_manager.alloc_c128 raise AssertionError(f"invalid DeepSeek-V4 compress ratio {ratio}") + def _scatter_c4_prefill_slots_slow(self, req_idx: int, first: int, last: int) -> None: + """Idempotence fallback for overlapped/repeated c4 prep.""" + page = DSV4_C4_PAGE_SIZE + mapping = self.mem_manager.full_to_c4_indexs + for page_base in range((first // page) * page, last, page): + e0 = max(first, page_base) + e1 = min(last, page_base + page) + entries = torch.arange(e0, e1, dtype=torch.long, device="cuda") + full_slots = self.req_to_token_indexs[req_idx, entries * 4 + 3].long() + existing = mapping[full_slots] + missing = existing < 0 + if not bool(missing.any()): + continue + + mapped = torch.nonzero(existing >= 0, as_tuple=False) + if mapped.numel() > 0: + j = int(mapped[0].item()) + base = int(existing[j].item()) - ((e0 + j) % page) + elif e0 > page_base: + prev_full = self.req_to_token_indexs[req_idx, e0 * 4 - 1].long() + prev_slot = int(mapping[prev_full].item()) 
+ assert prev_slot >= 0 and prev_slot % page == (e0 - 1) % page + base = prev_slot - ((e0 - 1) % page) + else: + base = int(self.mem_manager.alloc_c4_pages(1)[0].item()) * page + + slots = (base + entries % page).to(torch.int32) + if mapped.numel() > 0: + assert bool((existing[existing >= 0] == slots[existing >= 0]).all()) + mapping[full_slots[missing]] = slots[missing] + self.mem_manager.count_c4_slots(slots[missing], 1) + return + + def _scatter_c4_prefill_slots(self, req_idx: int, first: int, last: int) -> None: + """为 logical c4 entry [first, last) 分配 page-safe c4 槽。 + + 不变式: logical entry e 映射到 physical_page * 64 + e % 64,同一 logical page + 内 entry 共享 physical_page。这是 DeepGEMM paged MQA logits 直接消费 page table 的前提。 + """ + if last <= first: + return + page = DSV4_C4_PAGE_SIZE + mapping = self.mem_manager.full_to_c4_indexs + entries = torch.arange(first, last, dtype=torch.long, device="cuda") + full_slots = self.req_to_token_indexs[req_idx, entries * 4 + 3].long() + need = mapping[full_slots] < 0 + if not bool(need.any()): + return + if not bool(need.all()): + self._scatter_c4_prefill_slots_slow(req_idx, first, last) + return + + first_page = first // page + last_page = (last - 1) // page + n_pages = last_page - first_page + 1 + bases = torch.empty((n_pages,), dtype=torch.long, device="cuda") + + base_start = 0 + if first % page != 0: + prev_full = self.req_to_token_indexs[req_idx, first * 4 - 1].long() + prev_slot = int(mapping[prev_full].item()) + assert prev_slot >= 0 and prev_slot % page == (first - 1) % page + bases[0] = prev_slot - ((first - 1) % page) + base_start = 1 + + new_page_count = n_pages - base_start + if new_page_count > 0: + new_pages = self.mem_manager.alloc_c4_pages(new_page_count).cuda(non_blocking=True).long() + bases[base_start:] = new_pages * page + + page_local = torch.div(entries, page, rounding_mode="floor") - first_page + slots = (bases[page_local] + entries % page).to(torch.int32) + mapping[full_slots] = slots + 
self.mem_manager.count_c4_slots(slots, 1) + return + + def _scatter_c4_decode_slots( + self, + b_req_idx_cpu: torch.Tensor, + b_seq_len_cpu: torch.Tensor, + mem_indexes: torch.Tensor, + ) -> None: + page = DSV4_C4_PAGE_SIZE + mapping = self.mem_manager.full_to_c4_indexs + req_list = b_req_idx_cpu.tolist() + seq_list = b_seq_len_cpu.tolist() + mem_indexes = mem_indexes.cuda().long().reshape(-1) + + cont_rows, cont_prev_pos, cont_offsets = [], [], [] + new_rows = [] + for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)): + req_idx, seq_len = int(req_idx), int(seq_len) + if req_idx == self.HOLD_REQUEST_ID or seq_len <= 0 or seq_len % 4 != 0: + continue + entry = seq_len // 4 - 1 + offset = entry % page + if offset == 0: + new_rows.append(i) + else: + cont_rows.append(i) + cont_prev_pos.append(entry * 4 - 1) + cont_offsets.append(offset) + + if cont_rows: + req_rows = torch.tensor([req_list[i] for i in cont_rows], dtype=torch.long, device="cuda") + prev_pos = torch.tensor(cont_prev_pos, dtype=torch.long, device="cuda") + prev_full = self.req_to_token_indexs[req_rows, prev_pos].long() + prev_slots = mapping[prev_full] + offsets = torch.tensor(cont_offsets, dtype=torch.int32, device="cuda") + assert bool((prev_slots >= 0).all()) + assert bool(((prev_slots % page) == (offsets - 1)).all()) + slots = (prev_slots + 1).to(torch.int32) + mapping[mem_indexes[cont_rows]] = slots + self.mem_manager.count_c4_slots(slots, 1) + + if new_rows: + pages = self.mem_manager.alloc_c4_pages(len(new_rows)).cuda(non_blocking=True).long() + slots = (pages * page).to(torch.int32) + mapping[mem_indexes[new_rows]] = slots + self.mem_manager.count_c4_slots(slots, 1) + return + def _scatter_compress_slots(self, ratio: int, full_slots: torch.Tensor) -> None: """为组末 full 槽位分配压缩槽并写入映射。已映射(>=0)的行跳过——重复 prep 幂等。""" if full_slots.numel() == 0: @@ -568,9 +694,15 @@ def prepare_prefill_compress_slots( req_list = b_req_idx.detach().cpu().tolist() ready_list = 
b_ready_cache_len.detach().cpu().tolist() seq_list = b_seq_len.detach().cpu().tolist() - for ratio, n_layers in ((4, self.n_c4), (128, self.n_c128)): - if n_layers == 0: - continue + if self.n_c4 > 0: + for req_idx, ready_len, seq_len in zip(req_list, ready_list, seq_list): + req_idx = int(req_idx) + if req_idx == self.HOLD_REQUEST_ID: + continue + self._scatter_c4_prefill_slots(req_idx, int(ready_len) // 4, int(seq_len) // 4) + + if self.n_c128 > 0: + ratio = 128 end_slots = [] for req_idx, ready_len, seq_len in zip(req_list, ready_list, seq_list): req_idx = int(req_idx) @@ -597,9 +729,11 @@ def prepare_decode_compress_slots( return req_list = b_req_idx_cpu.tolist() seq_list = b_seq_len_cpu.tolist() - for ratio, n_layers in ((4, self.n_c4), (128, self.n_c128)): - if n_layers == 0: - continue + if self.n_c4 > 0: + self._scatter_c4_decode_slots(b_req_idx_cpu, b_seq_len_cpu, mem_indexes) + + if self.n_c128 > 0: + ratio = 128 rows = [ i for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)) @@ -624,7 +758,7 @@ def get_prompt_cache_value_ops(self): return DeepseekV4PromptCacheValueOps(self) def get_prompt_cache_page_size(self): - return 128 + return DSV4_PROMPT_CACHE_PAGE_SIZE def compute_swa_page_valid(self, full_slots: torch.Tensor) -> torch.Tensor: """按当下 full_to_swa 映射给出按页有效性: full_slots [L](L 为 page 整数倍) -> @@ -654,7 +788,7 @@ def slice_prompt_cache_payload(self, payload: DeepseekV4PromptCachePayload, star start = int(start) end = int(end) page = self.get_prompt_cache_page_size() - # radix page=128 保证分裂点页对齐,bitmap 可整页切分。 + # radix page 保证分裂点页对齐,bitmap 可整页切分。 return DeepseekV4PromptCachePayload( cache_len=end - start, swa_page_valid=payload.swa_page_valid[start // page : end // page].clone() diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 6ebc0b2856..2721c23ac4 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ 
b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -1,3 +1,4 @@ +import os import torch import torch.distributed as dist from lightllm.common.basemodel import TransformerLayerInferTpl @@ -585,13 +586,26 @@ def _c4_indices(self, infer_state: DeepseekV4InferStateInfo, idx_q_fp8, weights, ) return slots.unsqueeze(1), lengths - import deep_gemm - from ..triton_kernel.gather_c4_indexer_k_dsv4 import gather_c4_indexer_k_ragged - b_req_idx = infer_state.b_req_idx batch = b_req_idx.shape[0] device = positions.device c4_len = torch.div(infer_state.b_seq_len, 4, rounding_mode="floor").to(torch.int32) # entries/req + + if os.getenv("LIGHTLLM_DSV4_PAGED_INDEXER", "0") == "1": + out = self._c4_indices_paged( + infer_state=infer_state, + idx_q_fp8=idx_q_fp8, + weights=weights, + positions=positions, + c4_len=c4_len, + c4_cap=c4_cap, + ) + if out is not None: + return out + + import deep_gemm + from ..triton_kernel.gather_c4_indexer_k_dsv4 import gather_c4_indexer_k_ragged + k_fp8, k_scale, ragged_slots = gather_c4_indexer_k_ragged( mem_manager, self.layer_idx_, @@ -622,3 +636,72 @@ def _c4_indices(self, infer_state: DeepseekV4InferStateInfo, idx_q_fp8, weights, top_slots = torch.where(invalid, torch.full_like(top_slots, -1), top_slots) topk_lengths = torch.clamp(torch.minimum(valid_len, torch.full_like(valid_len, index_topk)), min=1) return top_slots.unsqueeze(1), topk_lengths.contiguous() + + def _c4_indices_paged(self, infer_state, idx_q_fp8, weights, positions, c4_len, c4_cap): + import deep_gemm + from sglang.jit_kernel.dsv4 import topk_transform_512 + from ..triton_kernel.gather_c4_indexer_k_dsv4 import build_c4_indexer_page_table + + mem_manager = infer_state.mem_manager + index_topk = self.index_topk + device = positions.device + b_req_idx = infer_state.b_req_idx + batch = b_req_idx.shape[0] + validate = os.getenv("LIGHTLLM_DSV4_PAGED_INDEXER_VALIDATE", "0") == "1" + validate_now = validate and not torch.cuda.is_current_stream_capturing() + + 
page_table, valid_flag = build_c4_indexer_page_table( + mem_manager, + b_req_idx, + c4_len, + c4_cap, + infer_state.req_manager.req_to_token_indexs, + infer_state.req_manager.HOLD_REQUEST_ID, + validate=validate_now, + ) + if validate_now and int(valid_flag.item()) == 0: + if os.getenv("LIGHTLLM_DSV4_PAGED_INDEXER_STRICT", "0") == "1": + raise RuntimeError("DeepSeek-V4 paged indexer requires page-aligned c4 slots") + return None + + if infer_state.is_prefill: + token_batch_pos = torch.repeat_interleave( + torch.arange(batch, device=device, dtype=torch.int32), infer_state.b_q_seq_len + ) + row_page_table = page_table[token_batch_pos.long()].contiguous() + else: + row_page_table = page_table + + valid_len = ((positions + 1) // 4).to(torch.int32) + ctx_lens = torch.clamp(valid_len, min=1).reshape(-1, 1).contiguous() + kv_cache = mem_manager.c4_indexer_pool.get_layer_buffer(mem_manager.layer_to_c4_idx[self.layer_idx_]).view( + mem_manager.c4_indexer_pool.num_pages, + mem_manager.c4_indexer_pool.page_size, + 1, + self.index_head_dim + 4, + ) + metadata = deep_gemm.get_paged_mqa_logits_metadata( + ctx_lens, + mem_manager.c4_indexer_pool.page_size, + deep_gemm.get_num_sms(), + ) + logits = deep_gemm.fp8_paged_mqa_logits( + idx_q_fp8.unsqueeze(1), + kv_cache, + weights, + ctx_lens, + row_page_table, + metadata, + c4_cap, + False, + ) + top_slots = torch.empty((idx_q_fp8.shape[0], index_topk), dtype=torch.int32, device=device) + topk_transform_512( + logits, + valid_len, + row_page_table, + top_slots, + mem_manager.c4_indexer_pool.page_size, + ) + topk_lengths = torch.clamp(torch.minimum(valid_len, torch.full_like(valid_len, index_topk)), min=1) + return top_slots.unsqueeze(1), topk_lengths.contiguous() diff --git a/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py index b6e6ba751a..a7a0a4be85 100644 --- a/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py +++ 
b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py @@ -107,3 +107,94 @@ def gather_c4_indexer_k_ragged( num_warps=1, ) return k_fp8, k_scale, slots + + +@triton.jit +def _build_c4_indexer_page_table_kernel( + req_idx_ptr, # [batch] int + c4_len_ptr, # [batch] int + req_to_token_ptr, + req_to_token_stride0, + full_to_c4_ptr, + page_table_ptr, # [batch, page_cap] int32 + valid_flag_ptr, # [1] int32, initialized to 1; set to 0 on layout mismatch + page_cap, + hold_req_id, + RATIO: tl.constexpr, + PAGE_SIZE: tl.constexpr, + VALIDATE: tl.constexpr, +): + p = tl.program_id(0) + r = tl.program_id(1) + req = tl.load(req_idx_ptr + r).to(tl.int64) + c4_len = tl.load(c4_len_ptr + r).to(tl.int64) + page_start = p * PAGE_SIZE + active = (req != hold_req_id) & (page_start < c4_len) + + full_pos0 = page_start * RATIO + (RATIO - 1) + full_slot0 = tl.load( + req_to_token_ptr + req * req_to_token_stride0 + full_pos0, + mask=active, + other=0, + ).to(tl.int64) + c4_slot0 = tl.load(full_to_c4_ptr + full_slot0, mask=active, other=0).to(tl.int64) + phys_page = c4_slot0 // PAGE_SIZE + tl.store(page_table_ptr + r * page_cap + p, tl.where(active, phys_page, 0).to(tl.int32)) + + if VALIDATE: + offs = tl.arange(0, PAGE_SIZE) + e = page_start + offs + valid = active & (e < c4_len) + full_pos = e * RATIO + (RATIO - 1) + full_slot = tl.load( + req_to_token_ptr + req * req_to_token_stride0 + full_pos, + mask=valid, + other=0, + ).to(tl.int64) + c4_slot = tl.load(full_to_c4_ptr + full_slot, mask=valid, other=-1).to(tl.int64) + expected = phys_page * PAGE_SIZE + offs + ok = tl.where(valid, (c4_slot == expected) & (c4_slot >= 0), True) + if tl.min(ok.to(tl.int32), axis=0) == 0: + tl.store(valid_flag_ptr, 0) + + +@torch.no_grad() +def build_c4_indexer_page_table( + mem_manager, + b_req_idx: torch.Tensor, + c4_len: torch.Tensor, + c4_cap: int, + req_to_token_indexs: torch.Tensor, + hold_req_id: int, + validate: bool = False, +): + """Build the logical-c4-page -> physical-c4-page 
table expected by DeepGEMM paged logits. + + This is safe only when each logical c4 page maps to a physical page with matching offsets: + c4_slot(entry p*64 + o) == page_table[p] * 64 + o + The optional validation flag checks that invariant and lets the caller fall back to the + gather path while we keep the current token-slot allocator. + """ + pool = mem_manager.c4_indexer_pool + page_size = pool.page_size + assert c4_cap % page_size == 0 + batch = b_req_idx.shape[0] + page_cap = c4_cap // page_size + page_table = torch.empty((batch, page_cap), dtype=torch.int32, device=b_req_idx.device) + valid_flag = torch.ones((1,), dtype=torch.int32, device=b_req_idx.device) + _build_c4_indexer_page_table_kernel[(page_cap, batch)]( + b_req_idx, + c4_len, + req_to_token_indexs, + req_to_token_indexs.stride(0), + mem_manager.full_to_c4_indexs, + page_table, + valid_flag, + page_cap, + int(hold_req_id), + RATIO=4, + PAGE_SIZE=page_size, + VALIDATE=validate, + num_warps=1, + ) + return page_table, valid_flag diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index a852461df0..b667c4be72 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -8,7 +8,7 @@ from sortedcontainers import SortedDict from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Callable, Any, Union -from lightllm.common.req_manager import ReqManager, ReqManagerForMamba +from lightllm.common.req_manager import DeepseekV4ReqManager, ReqManager, ReqManagerForMamba from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode @@ -177,6 +177,7 @@ def _dsv4_full_att_free_req(self, free_token_index: List, req: "InferReq"): # 载荷只剩按页 bitmap(compressor 状态随 swa 页生灭/边界自然归零,不进载荷), # 任意 128 
对齐前缀皆可插入——含生成段(floor(cur_kv_len) 边界,回收保留尾页保证其驻留)。 cache_len = self.radix_cache.align_len(req.cur_kv_len) + self.req_manager: DeepseekV4ReqManager if cache_len > old_prefix_len: payload = self.req_manager.build_prompt_cache_payload(req.req_idx, cache_len) value = self.req_manager.req_to_token_indexs[req.req_idx][:cache_len].detach().cpu() From 52a1528051a53f0f1291332357d7df5730a5df59 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 18 Jun 2026 03:07:49 +0000 Subject: [PATCH 28/30] fix chunk_size and page_size --- lightllm/server/router/model_infer/infer_batch.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index b667c4be72..2bf2314185 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -916,10 +916,16 @@ def _align_chuncked_end_for_prompt_cache(self, chunked_start: int, chunked_end: page_size = getattr(radix_cache, "page_size", 1) if radix_cache is not None else 1 if page_size <= 1 or self.sampling_param.disable_prompt_cache: return chunked_end - prompt_end = self.shm_req.input_len - next_page_end = ((int(chunked_start) // page_size) + 1) * page_size - if int(chunked_start) < next_page_end < int(chunked_end) and next_page_end <= prompt_end: - return next_page_end + prompt_end = int(self.shm_req.input_len) + chunked_start = int(chunked_start) + chunked_end = int(chunked_end) + if chunked_end >= prompt_end: + return chunked_end + + assert self.args.chunked_prefill_size % page_size == 0, ( + f"chunked_prefill_size={self.args.chunked_prefill_size} must be divisible by " + f"prompt-cache page_size={page_size}" + ) return chunked_end def get_chuncked_input_token_len_for_linear_att(self): From 0dbc90b6a233ffe4dfbea2c4a237b5a18f88a612 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 18 Jun 2026 03:08:41 +0000 Subject: [PATCH 
29/30] add sglang third_party --- .../layer_infer/transformer_layer_infer.py | 4 +- lightllm/third_party/__init__.py | 1 + lightllm/third_party/sglang_jit/LICENSE | 201 ++++ lightllm/third_party/sglang_jit/README.md | 13 + lightllm/third_party/sglang_jit/__init__.py | 1 + .../sglang_jit/csrc/deepseek_v4/c128.cuh | 522 +++++++++++ .../csrc/deepseek_v4/c128_online.cuh | 726 +++++++++++++++ .../csrc/deepseek_v4/c128_online_v2.cuh | 875 ++++++++++++++++++ .../sglang_jit/csrc/deepseek_v4/c128_v2.cuh | 448 +++++++++ .../sglang_jit/csrc/deepseek_v4/c4.cuh | 549 +++++++++++ .../sglang_jit/csrc/deepseek_v4/c4_v2.cuh | 405 ++++++++ .../sglang_jit/csrc/deepseek_v4/c_plan.cuh | 839 +++++++++++++++++ .../sglang_jit/csrc/deepseek_v4/common.cuh | 208 +++++ .../csrc/deepseek_v4/fused_norm_rope.cuh | 254 +++++ .../csrc/deepseek_v4/fused_norm_rope_v2.cuh | 643 +++++++++++++ .../sglang_jit/csrc/deepseek_v4/hash_topk.cuh | 214 +++++ .../csrc/deepseek_v4/hisparse_transfer.cuh | 82 ++ .../csrc/deepseek_v4/main_norm_rope.cuh | 845 +++++++++++++++++ .../deepseek_v4/mega_moe_pre_dispatch.cuh | 219 +++++ .../csrc/deepseek_v4/paged_mqa_metadata.cuh | 119 +++ .../sglang_jit/csrc/deepseek_v4/rope.cuh | 169 ++++ .../silu_and_mul_masked_post_quant.cuh | 540 +++++++++++ .../sglang_jit/csrc/deepseek_v4/store.cuh | 205 ++++ .../sglang_jit/csrc/deepseek_v4/topk_v1.cuh | 340 +++++++ .../sglang_jit/csrc/deepseek_v4/topk_v2.cuh | 493 ++++++++++ .../third_party/sglang_jit/dsv4/__init__.py | 8 + .../sglang_jit/dsv4/elementwise.py | 215 +++++ lightllm/third_party/sglang_jit/dsv4/topk.py | 92 ++ lightllm/third_party/sglang_jit/dsv4/utils.py | 2 + .../sglang_jit/include/sgl_kernel/atomic.cuh | 35 + .../sglang_jit/include/sgl_kernel/cta.cuh | 40 + .../sgl_kernel/deepseek_v4/compress.cuh | 37 + .../sgl_kernel/deepseek_v4/compress_v2.cuh | 99 ++ .../sgl_kernel/deepseek_v4/fp8_utils.cuh | 112 +++ .../sgl_kernel/deepseek_v4/kvcacheio.cuh | 96 ++ .../sgl_kernel/deepseek_v4/topk/cluster.cuh | 257 +++++ 
.../sgl_kernel/deepseek_v4/topk/common.cuh | 176 ++++ .../sgl_kernel/deepseek_v4/topk/ptx.cuh | 54 ++ .../sgl_kernel/deepseek_v4/topk/register.cuh | 302 ++++++ .../sgl_kernel/deepseek_v4/topk/streaming.cuh | 213 +++++ .../include/sgl_kernel/distributed/common.cuh | 120 +++ .../distributed/custom_all_reduce.cuh | 354 +++++++ .../sglang_jit/include/sgl_kernel/ffi.h | 104 +++ .../include/sgl_kernel/impl/norm.cuh | 168 ++++ .../sglang_jit/include/sgl_kernel/math.cuh | 71 ++ .../sglang_jit/include/sgl_kernel/runtime.cuh | 86 ++ .../include/sgl_kernel/scalar_type.hpp | 334 +++++++ .../include/sgl_kernel/source_location.h | 40 + .../sglang_jit/include/sgl_kernel/tensor.h | 605 ++++++++++++ .../sglang_jit/include/sgl_kernel/tile.cuh | 62 ++ .../sglang_jit/include/sgl_kernel/type.cuh | 120 +++ .../sglang_jit/include/sgl_kernel/utils.cuh | 333 +++++++ .../sglang_jit/include/sgl_kernel/utils.h | 186 ++++ .../sglang_jit/include/sgl_kernel/vec.cuh | 118 +++ .../sglang_jit/include/sgl_kernel/warp.cuh | 56 ++ lightllm/third_party/sglang_jit/jit_utils.py | 432 +++++++++ .../third_party/sglang_jit/runtime_utils.py | 5 + 57 files changed, 13845 insertions(+), 2 deletions(-) create mode 100644 lightllm/third_party/__init__.py create mode 100755 lightllm/third_party/sglang_jit/LICENSE create mode 100644 lightllm/third_party/sglang_jit/README.md create mode 100644 lightllm/third_party/sglang_jit/__init__.py create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_v2.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4_v2.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/c_plan.cuh create mode 100644 
lightllm/third_party/sglang_jit/csrc/deepseek_v4/common.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope_v2.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/hash_topk.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/hisparse_transfer.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/main_norm_rope.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/paged_mqa_metadata.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/rope.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/store.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v1.cuh create mode 100644 lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v2.cuh create mode 100644 lightllm/third_party/sglang_jit/dsv4/__init__.py create mode 100644 lightllm/third_party/sglang_jit/dsv4/elementwise.py create mode 100644 lightllm/third_party/sglang_jit/dsv4/topk.py create mode 100644 lightllm/third_party/sglang_jit/dsv4/utils.py create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/atomic.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/cta.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress_v2.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/fp8_utils.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/kvcacheio.cuh create mode 100644 
lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/cluster.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/common.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/ptx.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/register.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/streaming.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/common.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/custom_all_reduce.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/ffi.h create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/impl/norm.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/math.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/runtime.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/scalar_type.hpp create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/source_location.h create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/tensor.h create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/tile.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/type.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/utils.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/utils.h create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/vec.cuh create mode 100644 lightllm/third_party/sglang_jit/include/sgl_kernel/warp.cuh create mode 100644 lightllm/third_party/sglang_jit/jit_utils.py create mode 100644 lightllm/third_party/sglang_jit/runtime_utils.py diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py 
index 2721c23ac4..617d0dcd85 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -144,7 +144,7 @@ def _get_qkv( infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight, ): - from sglang.jit_kernel.dsv4 import fused_q_norm_rope + from lightllm.third_party.sglang_jit.dsv4 import fused_q_norm_rope input = self._tpsp_allgather(input=input, infer_state=infer_state) T = input.shape[0] @@ -639,7 +639,7 @@ def _c4_indices(self, infer_state: DeepseekV4InferStateInfo, idx_q_fp8, weights, def _c4_indices_paged(self, infer_state, idx_q_fp8, weights, positions, c4_len, c4_cap): import deep_gemm - from sglang.jit_kernel.dsv4 import topk_transform_512 + from lightllm.third_party.sglang_jit.dsv4 import topk_transform_512 from ..triton_kernel.gather_c4_indexer_k_dsv4 import build_c4_indexer_page_table mem_manager = infer_state.mem_manager diff --git a/lightllm/third_party/__init__.py b/lightllm/third_party/__init__.py new file mode 100644 index 0000000000..2adb50db25 --- /dev/null +++ b/lightllm/third_party/__init__.py @@ -0,0 +1 @@ +"""Third-party source subsets vendored for LightLLM runtime support.""" diff --git a/lightllm/third_party/sglang_jit/LICENSE b/lightllm/third_party/sglang_jit/LICENSE new file mode 100755 index 0000000000..9c422689c8 --- /dev/null +++ b/lightllm/third_party/sglang_jit/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023-2024 SGLang Team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/lightllm/third_party/sglang_jit/README.md b/lightllm/third_party/sglang_jit/README.md new file mode 100644 index 0000000000..4f68c9cfd8 --- /dev/null +++ b/lightllm/third_party/sglang_jit/README.md @@ -0,0 +1,13 @@ +# Vendored SGLang JIT Subset + +This directory contains the minimal SGLang JIT source subset needed by the +DeepSeek-V4 LightLLM implementation. + +Source: https://github.com/sgl-project/sglang +Commit: 8cea0473ea5299bc04885f8f6ba71269415a39b5 +License: Apache License 2.0, copied in `LICENSE`. + +Local changes: +- The Python imports were moved from `sglang.jit_kernel.*` to + `lightllm.third_party.sglang_jit.*`. +- The package exports only the DSv4 functions used by LightLLM. 
diff --git a/lightllm/third_party/sglang_jit/__init__.py b/lightllm/third_party/sglang_jit/__init__.py new file mode 100644 index 0000000000..164d545b4e --- /dev/null +++ b/lightllm/third_party/sglang_jit/__init__.py @@ -0,0 +1 @@ +"""Vendored SGLang JIT kernels used by DeepSeek-V4.""" diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh new file mode 100644 index 0000000000..3a89e8114c --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh @@ -0,0 +1,522 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using Plan128 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int32_t kTileElements = 2; +/// \brief Each warp will handle this many elements (split along 128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kBlockSize = device::kWarpThreads * kNumWarps; + +/// \brief Need to reduce register usage to increase occupancy +#define C128_KERNEL __global__ __launch_bounds__(kBlockSize, 2) + +struct Compress128DecodeParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + 
+struct Compress128PrefillParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]`*/ + const int32_t* __restrict__ load_indices; + /** \brief The following part is plan info. */ + const Plan128* __restrict__ compress_plan; + const Plan128* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +SGL_DEVICE void c128_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos, + const uint32_t lane_id) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 2; + const auto gmem = tile::Memory{lane_id, kWarpThreads}; + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv | [1] = score | + Storage kv_score[2]; +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c128_forward( + const InFloat* kv_score_buf, + const 
InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t window_len, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + const auto element_size = head_dim * 2; + const auto score_offset = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory{lane_id, kWarpThreads}; + StorageIn kv[kElementsPerWarp]; + StorageIn score[kElementsPerWarp]; + StorageIn bias[kElementsPerWarp]; + const int32_t warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + const int32_t j = i + warp_offset; + bias[i] = gmem_in.load(score_bias + j * head_dim); + } + +#pragma unroll + for (int32_t i = 0; i < kElementsPerWarp; ++i) { + const int32_t j = i + warp_offset; + const InFloat* src; + __builtin_assume(j < 128); + if (j < window_len) { + src = kv_score_buf + j * element_size; + } else { + /// NOTE: k in [-127, 0]. We'll load from the ragged `kv_score_src` + const int32_t k = j - 127; + src = kv_score_src + k * element_size; + } + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; + +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, 
fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kBlockSize; + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. 
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = cast(global_product); + } +} + +template +C128_KERNEL void flash_c128_decode(const __grid_constant__ Compress128DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, batch_size // decode info + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + const uint32_t global_bid = blockIdx.x / kNumSplit; // batch id + const uint32_t global_sid = blockIdx.x % kNumSplit; // split id + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = 
kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: the write must be visible to the subsequent c128_forward, + /// so only the last warp can write to HBM + /// In addition, `position` = `seq_len - 1`. To avoid underflow, we use `seq_len + 127` + if (warp_id == kNumWarps - 1) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 127) % 128, lane_id); + } + if (seq_len % 128 == 0) { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, /*window_len=*/128, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +// compress kernel +template +C128_KERNEL void flash_c128_prefill(const __grid_constant__ Compress128PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, load_indices, compress_plan, write_plan, num_compress, num_write // prefill plan + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + uint32_t global_id; + if constexpr (kWrite) { + // for write kernel, we use global warp_id to dispatch work + global_id = (blockIdx.x * blockDim.x + threadIdx.x) / kWarpThreads; + } else { + // for compress kernel, we use block id to dispatch work + global_id = blockIdx.x; // block id + } + const uint32_t global_pid = global_id / kNumSplit; // plan id + const uint32_t global_sid = global_id % kNumSplit; // split id + + /// 
NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const auto indices_ptr = kWrite ? indices : load_indices; + + const int64_t split_offset = global_sid * kTileDim; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + const int32_t index = indices_ptr[global_bid]; + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + PDLWaitPrimary(); + + // only responsible for the compress part + if constexpr (kWrite) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 128, lane_id); + } else { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, window_len, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress128Kernel { + static constexpr auto decode_kernel = flash_c128_decode; + template + static constexpr auto prefill_kernel = flash_c128_prefill; + static constexpr auto prefill_c_kernel = prefill_kernel; + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWriteBlockSize = 128; + static constexpr uint32_t kWarpsPerWriteBlock = kWriteBlockSize / device::kWarpThreads; + + 
static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional /* UNUSED */) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device) + .verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress128DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + }; + + const uint32_t num_blocks = batch_size * kNumSplit; + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + 
const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + // might be needed for prefill write + const auto load_indices = extra.value_or(indices); + TensorMatcher({B}) // [read_positions] + .with_dtype() + .with_device(device_) + .verify(load_indices); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto params = Compress128PrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .load_indices = static_cast(load_indices.data_ptr()), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = 
static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + + constexpr auto kBlockSize_C = kBlockSize; + constexpr auto kBlockSize_W = kWriteBlockSize; + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kBlockSize_C, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerWriteBlock)) { + LaunchKernel(num_w_blocks, kBlockSize_W, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh new file mode 100644 index 0000000000..b497470606 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh @@ -0,0 +1,726 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace device::compress { + +/// \brief Plan entry for online compress 128 prefill. +/// Each entry describes a contiguous segment of tokens that lies inside a +/// single 128-chunk. Multiple segments can map to the same batch id when the +/// extend tokens span chunk boundaries. +/// +/// **Layout compatibility:** the field order/types match `PrefillPlan` so that +/// downstream kernels (e.g. `fused_norm_rope` in `CompressExtend` mode) can +/// consume the compress_plan tensor as-if it were a `PrefillPlan` tensor -- +/// they only read `ragged_id` and `position`, both of which carry identical +/// semantics here (the LAST token of the segment in q-ragged and global +/// coordinates respectively). 
+/// +/// Note that `window_len` here means "number of real tokens in this segment" +/// (1..128), which differs from `PrefillPlan::window_len`. Downstream kernels +/// that share the tensor MUST NOT read it under that name. +struct alignas(16) OnlinePrefillPlan { + /// \brief Ragged-q position of the LAST token in this segment. + /// Equal to `segment_start_ragged + window_len - 1`. + uint32_t ragged_id; + /// \brief Index into the `indices` / `load_indices` arrays. + uint32_t batch_id; + /// \brief Global position of the LAST token in this segment. + /// For compress plans, `position % 128 == 127` (chunk-closing); for write + /// plans, `position % 128 < 127`. + uint32_t position; + /// \brief Number of real tokens in this segment (1..128). + /// The first segment token sits at `position - window_len + 1` (global) and + /// at `ragged_id - window_len + 1` (ragged). + uint32_t window_len; +}; + +static_assert(alignof(OnlinePrefillPlan) == alignof(PrefillPlan)); +static_assert(sizeof(OnlinePrefillPlan) == sizeof(PrefillPlan)); + +} // namespace device::compress + +namespace host::compress { + +using device::compress::OnlinePrefillPlan; +using OnlinePrefillPlanTensorDtype = uint8_t; +inline constexpr int64_t kOnlinePrefillPlanDim = 16; + +static_assert(alignof(OnlinePrefillPlan) == sizeof(OnlinePrefillPlan)); +static_assert(sizeof(OnlinePrefillPlan) == kOnlinePrefillPlanDim * sizeof(OnlinePrefillPlanTensorDtype)); + +} // namespace host::compress + +namespace { + +using OnlinePlan = device::compress::OnlinePrefillPlan; +using IndiceT = int32_t; + +/// \brief Need to reduce register usage to increase occupancy +struct Compress128OnlineDecodeParams { + /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]` \n */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: 
`[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +/// \brief Need to reduce register usage to increase occupancy +struct Compress128OnlinePrefillParams { + /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]` \n */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[num_q_tokens, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[num_q_tokens, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ load_indices; + /// \brief Plan for segments that close a chunk (write to `kv_compressed_output`). + /// Shape: `[num_compress, 16]` (uint8). + const OnlinePlan* __restrict__ compress_plan; + /// \brief Plan for the trailing partial segment of each batch (write back to + /// `kv_score_buffer`). Shape: `[num_write, 16]` (uint8). 
+ const OnlinePlan* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +// 4 elements per thread, kHeadDim / 4 threads per block +template +__global__ void flash_c128_online_decode(const __grid_constant__ Compress128OnlineDecodeParams params) { + using namespace device; + constexpr uint32_t kVecSize = 4; + constexpr uint32_t kBlockSize = kHeadDim / kVecSize; + using Vec = AlignedVector; + const auto gmem = tile::Memory::cta(kBlockSize); + const auto batch_id = blockIdx.x; + const auto index = params.indices[batch_id]; + const auto seq_len = params.seq_lens[batch_id]; + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kHeadDim * 3); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_src = kv_score_input + batch_id * (kHeadDim * 2); + + /// NOTE: kv_score_buffer layout is [max, sum, kv] (slot 0 / 1 / 2). Reads, + /// writes, and the prefill kernel must all agree on this order. + const auto max_score_vec = gmem.load(kv_buf, 0); + const auto sum_score_vec = gmem.load(kv_buf, 1); + const auto old_kv_vec = gmem.load(kv_buf, 2); + + /// NOTE: kv_score_input layout is | kv | score | (head_dim each), matching + /// the offline c128 kernel and the online prefill kernel. + const auto new_kv_vec = gmem.load(kv_src, 0); + const auto new_score_raw_vec = gmem.load(kv_src, 1); + + /// NOTE: the new token sits at global position `seq_len - 1`, so its + /// position inside the 128-chunk is `(seq_len - 1) % 128`. The previous + /// `seq_len % 128` was off by one (`bias[127]` vs `bias[0]`, etc.). + const auto pos_in_chunk = (seq_len - 1) % 128; + const auto bias_vec = gmem.load(params.score_bias, pos_in_chunk); + + Vec out_kv_vec; + Vec out_max_vec; + Vec out_sum_vec; + if (pos_in_chunk != 0) { + // Mid-chunk: combine prior partial state with the new token via online softmax. 
+#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const auto old_max = max_score_vec[i]; + const auto old_kv = old_kv_vec[i]; + const auto new_score = new_score_raw_vec[i] + bias_vec[i]; + const auto new_kv = new_kv_vec[i]; + const auto new_max = fmax(old_max, new_score); + const auto old_sum = sum_score_vec[i] * expf(old_max - new_max); + const auto new_exp = expf(new_score - new_max); + const auto new_sum = old_sum + new_exp; + out_kv_vec[i] = (old_kv * old_sum + new_kv * new_exp) / new_sum; + out_max_vec[i] = new_max; + out_sum_vec[i] = new_sum; + } + } else { + // First token of a new 128-chunk: initialize state with this token alone. +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + out_kv_vec[i] = new_kv_vec[i]; + out_max_vec[i] = new_score_raw_vec[i] + bias_vec[i]; + out_sum_vec[i] = 1.0f; // exp(score - max) with max == score + } + } + + if (pos_in_chunk == 127) { + // Chunk just closed: emit the compressed kv. No need to update the buffer + // -- the next chunk's first token will overwrite it. + const auto kv_out = static_cast(params.kv_compressed_output) + batch_id * kHeadDim; + gmem.store(kv_out, out_kv_vec); + } else { + // Otherwise persist the running [max, sum, kv] state for the next step. 
+ gmem.store(kv_buf, out_max_vec, 0); + gmem.store(kv_buf, out_sum_vec, 1); + gmem.store(kv_buf, out_kv_vec, 2); + } +} + +constexpr int32_t kTileElements = 2; // split (along head-dim) +/// \brief Each warp will handle this many elements (split along softmax-128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kPrefillBlockSize = device::kWarpThreads * kNumWarps; +using PrefillStorage = device::AlignedVector; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +SGL_DEVICE void c128_prefill_forward( + const PrefillStorage (&kv)[kElementsPerWarp], + const PrefillStorage (&score)[kElementsPerWarp], + float* kv_out, + float* max_out, + float* sum_out, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; + +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = score[j][i]; + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 
0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kPrefillBlockSize; + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kPrefillBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. 
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = global_product; + if constexpr (kNeedData) { + max_out[local_elem_id] = global_val_max; + sum_out[local_elem_id] = global_exp_sum; + } + } + if constexpr (kNeedData) __syncthreads(); +} + +/// \brief Sentinel score for padded positions in a 128-segment. +/// Must be finite so that `score - max` never produces NaN even when an +/// entire warp has only padded positions. +constexpr float kPadScore = -FLT_MAX; + +/// \brief Online compress 128 prefill. Two passes share this body: +/// - `kWrite=false` (compress pass): handles segments that close a chunk. +/// May load prior partial state from the buffer, but never writes to it, +/// so concurrent blocks can read the same slot without racing. +/// - `kWrite=true` (write pass): handles the trailing partial segment of each +/// batch. Each batch contributes at most one such plan, so concurrent blocks +/// touch disjoint buffer slots. +/// +/// The two passes MUST run as separate kernel launches (in stream order) so +/// that all reads in pass 1 finish before any writes in pass 2 start. 
+template +__global__ __launch_bounds__(kPrefillBlockSize, 2) // + void flash_c128_online_prefill(const __grid_constant__ Compress128OnlinePrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + /// NOTE: the compiler folds the if-else at compile time. + const auto num_plans = kWrite ? params.num_write : params.num_compress; + const auto plan_ptr = kWrite ? params.write_plan : params.compress_plan; + const uint32_t global_id = blockIdx.x; + const uint32_t global_pid = global_id / kNumSplit; // plan id + const uint32_t global_sid = global_id % kNumSplit; // split id + if (global_pid >= num_plans) return; + const auto [ragged_id, batch_id, position, window_len] = plan_ptr[global_pid]; + if (ragged_id == 0xFFFFFFFFu) [[unlikely]] + return; + + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + const int32_t split_offset = global_sid * kTileDim; // int32 is enough + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_compressed_output = static_cast(params.kv_compressed_output); + const auto score_bias_base = static_cast(params.score_bias); + + constexpr int64_t kElementSize = kHeadDim * 2; // | kv | score | + const uint32_t chunk_offset = (position % 128u) + 1u - window_len; + const uint32_t window_end = chunk_offset + window_len; // exclusive, in [1, 128] + const int32_t segment_start = ragged_id - (position % 128u); // can be negative, but safe + const int32_t load_index = chunk_offset != 0 ? params.load_indices[batch_id] : -1; + const int32_t store_index = kWrite ? params.indices[batch_id] : -1; + + PDLWaitPrimary(); + + // 2 * 8 = 16 register per elem. 
in theory we should consume 48 register here + PrefillStorage kv[kElementsPerWarp]; + PrefillStorage score[kElementsPerWarp]; + PrefillStorage bias[kElementsPerWarp]; + const auto warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + if (j >= chunk_offset && j < window_end) { + const auto kv_src_ptr = kv_score_input + (segment_start + j) * kElementSize + split_offset; + const auto score_src_ptr = kv_src_ptr + kHeadDim; + const auto bias_src_ptr = score_bias_base + j * kHeadDim + split_offset; + kv[i].load(kv_src_ptr, lane_id); + score[i].load(score_src_ptr, lane_id); + bias[i].load(bias_src_ptr, lane_id); + } + } + +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + const bool is_valid = (j >= chunk_offset && j < window_end); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + score[i][ii] = is_valid ? score[i][ii] + bias[i][ii] : kPadScore; + /// NOTE: must zero out kv on padded slots -- `c128_prefill_forward` + /// computes `kv * exp_score` where `exp_score = expf(-FLT_MAX - max) ??? 0`, + /// and IEEE-754 makes `NaN * 0 = NaN` / `+-inf * 0 = NaN`. An + /// uninitialized register can hold a NaN/inf bit pattern, so without + /// this reset a single padded warp can poison the whole softmax. + kv[i][ii] = is_valid ? kv[i][ii] : 0.0f; + } + } + + __shared__ alignas(16) float seg_kv[kTileDim]; + __shared__ alignas(16) float seg_max[kTileDim]; + __shared__ alignas(16) float seg_sum[kTileDim]; + + c128_prefill_forward(kv, score, seg_kv, seg_max, seg_sum, warp_id, lane_id); + + PDLTriggerSecondary(); + + if (warp_id == 0) { + PrefillStorage out_kv_vec, out_max_vec, out_sum_vec; + out_kv_vec.load(seg_kv, lane_id); + out_max_vec.load(seg_max, lane_id); + out_sum_vec.load(seg_sum, lane_id); + if (chunk_offset != 0) { + /// NOTE: load (max, sum, kv) of the in-progress chunk for this index. 
+ /// `load_indices` may differ from `indices` when the prior partial state + /// lives on a different slot than the slot we ultimately write to. + const auto buf_load = kv_score_buffer + load_index * (kHeadDim * 3) + split_offset; + PrefillStorage buf_max_vec, buf_sum_vec, buf_kv_vec; + buf_max_vec.load(buf_load + 0 * kHeadDim, lane_id); + buf_sum_vec.load(buf_load + 1 * kHeadDim, lane_id); + buf_kv_vec.load(buf_load + 2 * kHeadDim, lane_id); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + const float m1 = buf_max_vec[ii]; + const float s1 = buf_sum_vec[ii]; + const float k1 = buf_kv_vec[ii]; + const float m2 = out_max_vec[ii]; + const float s2 = out_sum_vec[ii]; + const float k2 = out_kv_vec[ii]; + const float new_max = fmaxf(m1, m2); + const float new_s1 = s1 * expf(m1 - new_max); + const float new_s2 = s2 * expf(m2 - new_max); + const float new_sum = new_s1 + new_s2; + const float new_kv = (k1 * new_s1 + k2 * new_s2) / new_sum; + out_max_vec[ii] = new_max; + out_sum_vec[ii] = new_sum; + out_kv_vec[ii] = new_kv; + } + } + + if constexpr (kWrite) { + const auto buf_store = kv_score_buffer + store_index * (kHeadDim * 3) + split_offset; + reinterpret_cast(buf_store + 0 * kHeadDim)[lane_id] = out_max_vec; + reinterpret_cast(buf_store + 1 * kHeadDim)[lane_id] = out_sum_vec; + reinterpret_cast(buf_store + 2 * kHeadDim)[lane_id] = out_kv_vec; + } else { + const auto out_ptr = kv_compressed_output + ragged_id * kHeadDim + split_offset; + reinterpret_cast(out_ptr)[lane_id] = out_kv_vec; + } + } +} + +template +struct FlashCompress128OnlineKernel { + static constexpr auto decode_kernel = flash_c128_online_decode; + template + static constexpr auto prefill_kernel = flash_c128_online_prefill; + static constexpr auto prefill_c_kernel = prefill_kernel; + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + 
static constexpr uint32_t kDecodeBlockSize = kHeadDim / 4; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional /* UNUSED */) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv) + .with_dtype() + .with_device(device) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device) + .verify(ape); + TensorMatcher({B}).with_dtype().with_device(device).verify(indices); + TensorMatcher({B}).with_dtype().with_device(device).verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress128OnlineDecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + }; + LaunchKernel(batch_size, kDecodeBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + 
const tvm::ffi::Optional extra) { + using namespace host; + using host::compress::kOnlinePrefillPlanDim; + using host::compress::OnlinePrefillPlanTensorDtype; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv) ??? 2D + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, kOnlinePrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, kOnlinePrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + /// NOTE: `extra` is `load_indices`. When the previous partial state lives + /// on a slot different from the destination slot (e.g. paged buffers), the + /// caller must supply this; otherwise it defaults to `indices`. 
+ const auto load_indices = extra.value_or(indices); + TensorMatcher({B}).with_dtype().with_device(device_).verify(load_indices); + + const auto device = device_.unwrap(); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto params = Compress128OnlinePrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .load_indices = static_cast(load_indices.data_ptr()), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + + /// NOTE: pass 1 reads the buffer (for the first segment of each batch + /// that started mid-chunk) and writes only to `kv_compressed_output`. + /// Pass 2 then writes the trailing partial state of each batch back to + /// the buffer. Stream serialization between the two launches enforces + /// read-before-write on shared buffer slots. + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = num_w * kNumSplit) { + LaunchKernel(num_w_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace + +namespace host::compress { + +using OnlinePlanResult = tvm::ffi::Tuple; + +struct OnlinePrefillCompressParams { + OnlinePrefillPlan* __restrict__ compress_plan; + OnlinePrefillPlan* __restrict__ write_plan; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ extend_lens; + uint32_t batch_size; + uint32_t num_tokens; +}; + +/// \brief Build the compress + write plans for online compress 128 prefill. 
+/// +/// Each batch's `[prefix_len, prefix_len + extend_len)` range is split at +/// 128-aligned boundaries. Every resulting segment falls into one of: +/// - **compress**: closes a 128-chunk (`chunk_offset + window_len == 128`). +/// These plans only read the buffer (when starting mid-chunk) and write the +/// compressed kv to `kv_compressed_output`. +/// - **write**: trailing partial of the batch (`chunk_offset + window_len < 128`). +/// May read the buffer and always writes the new partial state back to it. +/// Each batch produces at most one such plan. +/// +/// The two plans MUST be dispatched as separate kernel launches in stream +/// order so that pass-1 reads of a buffer slot complete before any pass-2 +/// write of the same slot. +inline OnlinePlanResult plan_online_prefill_host(const OnlinePrefillCompressParams& params, const bool use_cuda_graph) { + const auto& [compress_plan, write_plan, seq_lens, extend_lens, batch_size, num_tokens] = params; + + uint32_t counter = 0; + uint32_t compress_count = 0; + uint32_t write_count = 0; + for (const auto i : irange(batch_size)) { + const uint32_t seq_len = static_cast(seq_lens[i]); + const uint32_t extend_len = static_cast(extend_lens[i]); + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + const uint32_t prefix_len = seq_len - extend_len; + const uint32_t end_pos = prefix_len + extend_len; + /// NOTE: split the extend range into per-128-chunk segments. Each segment + /// stays inside one chunk, so the kernel can decide load/store from + /// `chunk_offset` and `window_len` alone. + uint32_t pos = prefix_len; + while (pos < end_pos) { + const uint32_t chunk_start = (pos / 128u) * 128u; + const uint32_t seg_end = std::min(end_pos, chunk_start + 128u); // exclusive + const uint32_t seg_len = seg_end - pos; + const uint32_t chunk_off = pos - chunk_start; + /// NOTE: store last-token coordinates so that downstream consumers + /// (e.g. 
`fused_norm_rope`) can read `ragged_id` and `position` with the + /// same semantics as `PrefillPlan`. The segment start is recoverable as + /// `ragged_id - window_len + 1` and `position - window_len + 1`. + const uint32_t last_pos = seg_end - 1; + const uint32_t last_ragged = counter + (last_pos - prefix_len); + const auto plan = OnlinePrefillPlan{ + .ragged_id = last_ragged, + .batch_id = i, + .position = last_pos, + .window_len = seg_len, + }; + if (chunk_off + seg_len == 128u) { + // full chunk, must be complete, maybe read the buffer, no write + RuntimeCheck(compress_count < num_tokens); + compress_plan[compress_count++] = plan; + } else { + // last chunk, must be incomplete, maybe read the buffer, must write + RuntimeCheck(write_count < num_tokens); + write_plan[write_count++] = plan; + } + pos = seg_end; + } + counter += extend_len; + } + RuntimeCheck(counter == num_tokens, "input size ", counter, " != num_q_tokens ", num_tokens); + if (!use_cuda_graph) return OnlinePlanResult{compress_count, write_count}; + /// NOTE: pad both plans with sentinel entries so cuda-graph runs always see + /// the same number of blocks. The kernel skips plans whose `ragged_id` is -1. + constexpr auto kInvalid = static_cast(-1); + constexpr auto kInvalidPlan = OnlinePrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + for (const auto i : irange(compress_count, num_tokens)) { + compress_plan[i] = kInvalidPlan; + } + for (const auto i : irange(write_count, num_tokens)) { + write_plan[i] = kInvalidPlan; + } + return OnlinePlanResult{num_tokens, num_tokens}; +} + +inline OnlinePlanResult plan_online_prefill( + const tvm::ffi::TensorView extend_lens, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const bool use_cuda_graph) { + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_tokens"}; + auto device = SymbolicDevice{}; + /// NOTE: only host (CPU/cuda-host) planning is implemented for now. 
+ device.set_options(); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(extend_lens) + .verify(seq_lens); + TensorMatcher({M, kOnlinePrefillPlanDim}) // + .with_dtype() + .with_device(device) + .verify(compress_plan) + .verify(write_plan); + const auto params = OnlinePrefillCompressParams{ + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extend_lens = static_cast(extend_lens.data_ptr()), + .batch_size = static_cast(N.unwrap()), + .num_tokens = static_cast(M.unwrap()), + }; + return plan_online_prefill_host(params, use_cuda_graph); +} + +} // namespace host::compress + +namespace { + +[[maybe_unused]] +constexpr auto& plan_compress_online_prefill = host::compress::plan_online_prefill; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh new file mode 100644 index 0000000000..71e600dc39 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh @@ -0,0 +1,875 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +using PlanD = device::compress::DecodePlan; +using PlanC = device::compress::CompressPlan; + +// --------------------------------------------------------------------------- +// Decode kernel: 1 token / batch. Each block handles one batch. +// 4 elements per thread -> kBlockSize = head_dim / 4. 
+// --------------------------------------------------------------------------- + +struct Compress128OnlineDecodeParams { + void* __restrict__ kv_score_buffer; // [num_slots, 1, head_dim * 3] + const void* __restrict__ kv_score_input; // [batch_size, head_dim * 2] + void* __restrict__ kv_compressed_output; // [batch_size, head_dim] + const void* __restrict__ score_bias; // [128, head_dim] + const PlanD* __restrict__ plan_d; + uint32_t batch_size; +}; + +template +__global__ void flash_c128_online_decode_v2(const __grid_constant__ Compress128OnlineDecodeParams params) { + using namespace device; + constexpr uint32_t kVecSize = 4; + constexpr uint32_t kBlockSize = kHeadDim / kVecSize; + using Vec = AlignedVector; + const auto gmem = tile::Memory::cta(kBlockSize); + const auto batch_id = blockIdx.x; + if (batch_id >= params.batch_size) return; + + // Wait for the plan-finalize kernel to publish `plan.read_page_0 / write_loc` + // before reading the plan. The plan kernel runs on the same stream and does + // NOT issue a PDL trigger, so launching this kernel with PDL means our + // pre-wait global reads can race with the plan kernel's writes. + PDLWaitPrimary(); + + const auto plan = params.plan_d[batch_id]; + const auto pos_in_chunk = (plan.seq_len - 1) % 128; + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_load_buf = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3); + const auto kv_store_buf = kv_score_buffer + plan.write_loc * (kHeadDim * 3); + const auto kv_src = kv_score_input + batch_id * (kHeadDim * 2); + + // Buffer layout: [max | sum | kv] (slot 0 / 1 / 2 of the head_dim*3 row). 
+ const auto new_kv_vec = gmem.load(kv_src, 0); + const auto new_score_raw_vec = gmem.load(kv_src, 1); + const auto bias_vec = gmem.load(params.score_bias, pos_in_chunk); + + Vec out_kv_vec; + Vec out_max_vec; + Vec out_sum_vec; + if (pos_in_chunk != 0) { + // Mid-chunk: combine prior partial state with the new token. + const auto max_score_vec = gmem.load(kv_load_buf, 0); + const auto sum_score_vec = gmem.load(kv_load_buf, 1); + const auto old_kv_vec = gmem.load(kv_load_buf, 2); +#pragma unroll + for (uint32_t i = 0; i < kVecSize; ++i) { + const auto old_max = max_score_vec[i]; + const auto old_kv = old_kv_vec[i]; + const auto new_score = new_score_raw_vec[i] + bias_vec[i]; + const auto new_kv = new_kv_vec[i]; + const auto new_max = fmaxf(old_max, new_score); + const auto old_sum = sum_score_vec[i] * expf(old_max - new_max); + const auto new_exp = expf(new_score - new_max); + const auto new_sum = old_sum + new_exp; + out_kv_vec[i] = (old_kv * old_sum + new_kv * new_exp) / new_sum; + out_max_vec[i] = new_max; + out_sum_vec[i] = new_sum; + } + } else { + // First token of a new chunk: state == this token alone. +#pragma unroll + for (uint32_t i = 0; i < kVecSize; ++i) { + out_kv_vec[i] = new_kv_vec[i]; + out_max_vec[i] = new_score_raw_vec[i] + bias_vec[i]; + out_sum_vec[i] = 1.0f; + } + } + + if (pos_in_chunk == 127) { + // Chunk just closed: emit compressed kv, no buffer update. + const auto kv_out = static_cast(params.kv_compressed_output) + batch_id * kHeadDim; + gmem.store(kv_out, out_kv_vec); + } else { + gmem.store(kv_store_buf, out_max_vec, 0); + gmem.store(kv_store_buf, out_sum_vec, 1); + gmem.store(kv_store_buf, out_kv_vec, 2); + } +} + +// --------------------------------------------------------------------------- +// Prefill kernel: 1 segment / block. Two passes (compress + write) share the +// kernel template, parameterized by `kWrite`. +// 16 warps per block; each warp handles 8 of the 128 chunk positions. 
+// --------------------------------------------------------------------------- + +constexpr int32_t kTileElements = 2; // split along head-dim +constexpr int32_t kElementsPerWarp = 8; // split along the 128-chunk +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kPrefillBlockSize = device::kWarpThreads * kNumWarps; +using PrefillStorage = device::AlignedVector; + +struct Compress128OnlinePrefillParams { + void* __restrict__ kv_score_buffer; // [num_slots, 1, head_dim * 3] + const void* __restrict__ kv_score_input; // [num_q_tokens, head_dim * 2] + void* __restrict__ kv_compressed_output; // [num_compress, head_dim] + const void* __restrict__ score_bias; // [128, head_dim] + const PlanC* __restrict__ plan_c; // close-chunk segments + const PlanC* __restrict__ plan_w; // trailing partial segments + uint32_t num_compress; + uint32_t num_write; +}; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // +1 to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +/// \brief Sentinel score for padded positions in a 128-segment. +constexpr float kPadScore = -FLT_MAX; + +[[maybe_unused]] +SGL_DEVICE void c128_prefill_segment_softmax( + const PrefillStorage (&kv)[kElementsPerWarp], + const PrefillStorage (&score)[kElementsPerWarp], + float* seg_kv, + float* seg_max, + float* seg_sum, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + // Per-warp running state (max, sum, kv) for kTileElements head-dim slots. 
+ using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = score[j][i]; + } + float max_value = score_fp32[0]; +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + max_value = fmaxf(max_value, score_fp32[j]); + } + float sum_exp_value = 0.0f; + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + const auto exp_score = expf(score_fp32[j] - max_value); + sum_product += kv[j][i] * exp_score; + sum_exp_value += exp_score; + } + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // Aligned writes (no bank conflict thanks to `+1` padding). + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + // Cross-warp reduction. Same recipe as c128_online.cuh: each block-thread + // pair reduces a (tile_id, lane_id) slot using a kNumWarps-wide warp shuffle. 
+ constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kPrefillBlockSize; + static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + const uint32_t j = i * kPrefillBlockSize + warp_id * kWarpThreads + lane_id; + const uint32_t local_warp_id = j % kNumWarps; + const uint32_t local_elem_id = j / kNumWarps; + const uint32_t local_tile_id = local_elem_id % kTileElements; + const uint32_t local_lane_id = local_elem_id / kTileElements; + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + seg_kv[local_elem_id] = global_product; + seg_max[local_elem_id] = global_val_max; + seg_sum[local_elem_id] = global_exp_sum; + } + __syncthreads(); +} + +/// \brief Online compress 128 prefill v2. +/// +/// `kWrite=false` (compress pass): handles segments that close a 128-chunk. +/// Reads optional prior state from `read_page_0` (-1 = none), emits compressed +/// kv to `kv_compressed_output[plan_id]` (compact). +/// `kWrite=true` (write pass) : handles trailing partial segments. +/// Reads optional prior state from `read_page_0` (-1 = none), writes new +/// running state back to the same `read_page_0` slot. 
+template +__global__ __launch_bounds__(kPrefillBlockSize, 2) // + void flash_c128_online_prefill_v2(const __grid_constant__ Compress128OnlinePrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static_assert(kHeadDim % kTileDim == 0); + + // Compile-time fold to the right plan list. + const auto num_plans = kWrite ? params.num_write : params.num_compress; + const auto plan_ptr = kWrite ? params.plan_w : params.plan_c; + const uint32_t global_id = blockIdx.x; + const uint32_t global_pid = global_id / kNumSplit; + const uint32_t global_sid = global_id % kNumSplit; + if (global_pid >= num_plans) return; + + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + const int32_t split_offset = global_sid * kTileDim; + + // The previous kernel (plan-finalize stage 1) does NOT issue a PDL trigger, + // so PDLWaitPrimary effectively waits for stage 1 to complete. Read the plan + // AFTER the wait so the freshly-written `read_page_0` (= state-pool slot) is + // visible. Reading it before the wait is a real race -- with PDL enabled the + // kernel can begin executing before stage 1's stores propagate, and we'd see + // the stage-0 batch_id placeholder in `read_page_0` instead of the slot. + PDLWaitPrimary(); + + const auto plan = plan_ptr[global_pid]; + if (plan.is_invalid()) [[unlikely]] + return; + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_compressed_output = static_cast(params.kv_compressed_output); + const auto score_bias_base = static_cast(params.score_bias); + + constexpr int64_t kElementSize = kHeadDim * 2; // | kv | score | + + // The plan stores last-token coordinates; segment start is recoverable as + // ragged_id - window_len + 1. 
+ const uint32_t window_len = plan.buffer_len; + const uint32_t position = plan.seq_len - 1; + const uint32_t pos_in_chunk_end = (position % 128u) + 1u; // exclusive, in [1, 128] + const uint32_t chunk_offset = pos_in_chunk_end - window_len; // in [0, 127] + const int32_t segment_start_ragged = static_cast(plan.ragged_id) - static_cast(position % 128u); + + // --- Stage 1: load kv / score / bias for this warp's 8 chunk positions. + PrefillStorage kv[kElementsPerWarp]; + PrefillStorage score[kElementsPerWarp]; + PrefillStorage bias[kElementsPerWarp]; + const uint32_t warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + if (j >= chunk_offset && j < pos_in_chunk_end) { + const auto kv_src_ptr = kv_score_input + (segment_start_ragged + j) * kElementSize + split_offset; + const auto score_src_ptr = kv_src_ptr + kHeadDim; + const auto bias_src_ptr = score_bias_base + j * kHeadDim + split_offset; + kv[i].load(kv_src_ptr, lane_id); + score[i].load(score_src_ptr, lane_id); + bias[i].load(bias_src_ptr, lane_id); + } + } + + // --- Stage 2: pad invalid positions. score = -FLT_MAX, kv = 0 (so that + // kv * exp(score-max) is exactly 0 on padded slots instead of producing NaN/inf). +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + const bool is_valid = (j >= chunk_offset && j < pos_in_chunk_end); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + score[i][ii] = is_valid ? score[i][ii] + bias[i][ii] : kPadScore; + kv[i][ii] = is_valid ? kv[i][ii] : 0.0f; + } + } + + // --- Stage 3: warp-tile online softmax over the 128-position chunk. 
+ __shared__ alignas(16) float seg_kv[kTileDim]; + __shared__ alignas(16) float seg_max[kTileDim]; + __shared__ alignas(16) float seg_sum[kTileDim]; + c128_prefill_segment_softmax(kv, score, seg_kv, seg_max, seg_sum, warp_id, lane_id); + + PDLTriggerSecondary(); + + // --- Stage 4: warp 0 folds with prior partial state (if any) and writes. + if (warp_id == 0) { + PrefillStorage out_kv_vec, out_max_vec, out_sum_vec; + out_kv_vec.load(seg_kv, lane_id); + out_max_vec.load(seg_max, lane_id); + out_sum_vec.load(seg_sum, lane_id); + + if (chunk_offset != 0 && plan.read_page_0 >= 0) { + // Combine with prior partial state for this slot. + const auto buf_load = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3) + split_offset; + PrefillStorage buf_max_vec, buf_sum_vec, buf_kv_vec; + buf_max_vec.load(buf_load + 0 * kHeadDim, lane_id); + buf_sum_vec.load(buf_load + 1 * kHeadDim, lane_id); + buf_kv_vec.load(buf_load + 2 * kHeadDim, lane_id); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + const float m1 = buf_max_vec[ii]; + const float s1 = buf_sum_vec[ii]; + const float k1 = buf_kv_vec[ii]; + const float m2 = out_max_vec[ii]; + const float s2 = out_sum_vec[ii]; + const float k2 = out_kv_vec[ii]; + const float new_max = fmaxf(m1, m2); + const float new_s1 = s1 * expf(m1 - new_max); + const float new_s2 = s2 * expf(m2 - new_max); + const float new_sum = new_s1 + new_s2; + const float new_kv = (k1 * new_s1 + k2 * new_s2) / new_sum; + out_max_vec[ii] = new_max; + out_sum_vec[ii] = new_sum; + out_kv_vec[ii] = new_kv; + } + } + + if constexpr (kWrite) { + // For trailing-partial segments the load and store slots collapse to the + // segment's own chunk slot (the request keeps a single in-progress + // chunk's running state at any time), so we reuse `read_page_0`. 
+ const auto buf_store = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3) + split_offset; + reinterpret_cast(buf_store + 0 * kHeadDim)[lane_id] = out_max_vec; + reinterpret_cast(buf_store + 1 * kHeadDim)[lane_id] = out_sum_vec; + reinterpret_cast(buf_store + 2 * kHeadDim)[lane_id] = out_kv_vec; + } else { + // Compact output: one row per compress plan, indexed by `global_pid`. + const auto out_ptr = kv_compressed_output + global_pid * kHeadDim + split_offset; + reinterpret_cast(out_ptr)[lane_id] = out_kv_vec; + } + } +} + +// --------------------------------------------------------------------------- +// Host wrapper: matches the c128_v2 / c4_v2 host API style (run_decode / +// run_prefill methods on a kernel-class template). We only expose `kHeadDim` +// + `kUsePDL`; the dtype is fixed to fp32 for the online state pool. +// --------------------------------------------------------------------------- + +template +struct FlashCompress128OnlineKernel { + static constexpr auto decode_kernel = flash_c128_online_decode_v2; + template + static constexpr auto prefill_kernel = flash_c128_online_prefill_v2; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kDecodeBlockSize = kHeadDim / 4; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView plan_d_) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv) + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv 
compressed output (sparse by batch_id) + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + + const auto plan_d = compress::verify_plan_d(plan_d_, B, device_); + const auto batch_size = static_cast(B.unwrap()); + if (batch_size == 0) return; + const auto params = Compress128OnlineDecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .plan_d = plan_d, + .batch_size = batch_size, + }; + LaunchKernel(batch_size, kDecodeBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView plan_c_, + const tvm::ffi::TensorView plan_w_) { + using namespace host; + + auto N = SymbolicSize{"num_q_tokens"}; + auto C = SymbolicSize{"num_c_plans"}; + auto W = SymbolicSize{"num_w_plans"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input (ragged) + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({C, kHeadDim}) // kv compressed output (compact, by plan_c index) + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + + // Both compress and write segments use PlanC layout. plan_c uses + // read_page_1=-1 (unused); plan_w uses read_page_1=store_slot. 
+ const auto plan_c = compress::verify_plan_c(plan_c_, C, device_); + const auto plan_w = compress::verify_plan_c(plan_w_, W, device_); + const auto device = device_.unwrap(); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(C.unwrap()); + const auto num_w = static_cast(W.unwrap()); + RuntimeCheck(num_q_tokens >= num_w, "invalid prefill plan: num_q < num_w"); + const auto params = Compress128OnlinePrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .plan_c = plan_c, + .plan_w = plan_w, + .num_compress = num_c, + .num_write = num_w, + }; + + // The two passes MUST be serialized in stream order: pass 1 reads slots + // that pass 2 may write to; running them in parallel would race. + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_kernel, params); + } + if (const auto num_w_blocks = num_w * kNumSplit) { + LaunchKernel(num_w_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_kernel, params); + } + } +}; + +} // namespace + +// =========================================================================== +// Plan builders. Mirrors the offline v2 pattern (`c_plan.cuh`): +// - Decode: a single GPU kernel reads seq_lens / req_to_token / +// req_pool_indices on device and emits the final PlanD tensor in one go. +// - Prefill: stage 0 (host, on CPU pinned memory) splits each batch's +// extend range into per-chunk segments and emits PlanC entries with the +// batch_id stashed in `read_page_0` as a placeholder. Stage 1 is a tiny +// GPU kernel that finalizes `read_page_0` to `req_to_token[rid][chunk_start]`, +// so the slot tensors never leave GPU memory. 
The online state pool keeps +// a single in-progress chunk per request, so each segment's load and +// store slot collapse to one value (the slot for the segment's own chunk), +// and `read_page_1` is unused. +// =========================================================================== + +namespace host::compress { + +using device::compress::CompressPlan; +using device::compress::DecodePlan; + +// --------------------------------------------------------------------------- +// Decode plan builder. +// --------------------------------------------------------------------------- + +struct OnlineDecodePlanParams { + DecodePlan* __restrict__ plan_d; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ req_pool_indices; + const int32_t* __restrict__ req_to_token; + const int64_t* __restrict__ full_to_swa; // (full_cache_size,) int64 + int64_t stride_r2t; + int32_t swa_page_size; + uint32_t batch_size; +}; + +__global__ void plan_c128_online_decode_kernel(const OnlineDecodePlanParams params) { + const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.batch_size) return; + const auto seq_len = static_cast(params.seq_lens[idx]); + const auto rid = params.req_pool_indices[idx]; + const int32_t chunk_start = static_cast((seq_len - 1u) / 128u * 128u); + const int32_t full_loc = params.req_to_token[rid * params.stride_r2t + chunk_start]; + const int32_t swa_loc = static_cast(params.full_to_swa[full_loc]); + const int32_t slot = swa_loc / params.swa_page_size; + params.plan_d[idx] = DecodePlan{ + .seq_len = seq_len, + .write_loc = slot, + .read_page_0 = slot, + .read_page_1 = -1, + }; +} + +/// \brief Build the decode plan tensor. Caller (Python) pre-allocates +/// `plan_d_dev` as a `(batch_size, 16)` device uint8 tensor; this routine +/// only fills it. 
See `plan_online_prefill` for the rationale (avoid +/// `ffi::empty` + dlpack roundtrip / PyTorch caching-allocator stream +/// tracking issue that surfaces as IMA in unrelated downstream kernels). +inline void plan_online_decode( + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView req_pool_indices, + const tvm::ffi::TensorView req_to_token, + const tvm::ffi::TensorView full_to_swa, + const tvm::ffi::TensorView plan_d_dev_, + const int32_t swa_page_size) { + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + auto seq_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(seq_dtype) + .with_device(device_) + .verify(seq_lens); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device_) + .verify(req_to_token); + TensorMatcher({-1}) // + .with_dtype() + .with_device(device_) + .verify(full_to_swa); + TensorMatcher({B, sizeof(DecodePlan)}) // + .with_dtype() + .with_device(device_) + .verify(plan_d_dev_); + RuntimeCheck(swa_page_size > 0); + + const auto batch_size = static_cast(B.unwrap()); + if (batch_size == 0) return; + + const auto device = device_.unwrap(); + constexpr uint32_t kBlockSize = 256; + const uint32_t num_blocks = host::div_ceil(batch_size, kBlockSize); + const auto stride_r2t = req_to_token.stride(0); + const auto params = OnlineDecodePlanParams{ + .plan_d = static_cast(plan_d_dev_.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .req_pool_indices = static_cast(req_pool_indices.data_ptr()), + .req_to_token = static_cast(req_to_token.data_ptr()), + .full_to_swa = static_cast(full_to_swa.data_ptr()), + .stride_r2t = stride_r2t, + .swa_page_size = swa_page_size, + .batch_size = batch_size, + }; + LaunchKernel(num_blocks, kBlockSize, device)(plan_c128_online_decode_kernel, params); +} + +// --------------------------------------------------------------------------- 
+// Prefill plan builder: host stage 0 + GPU stage 1. +// --------------------------------------------------------------------------- + +struct OnlinePrefillStage0Params { + CompressPlan* __restrict__ plan_c; + CompressPlan* __restrict__ plan_w; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ extend_lens; + uint32_t batch_size; + uint32_t num_q_tokens; +}; + +inline std::tuple _plan_prefill_partial(const OnlinePrefillStage0Params& p) { + uint32_t counter = 0; + uint32_t compress_count = 0; + uint32_t write_count = 0; + for (const auto i : irange(p.batch_size)) { + const uint32_t seq_len = static_cast(p.seq_lens[i]); + const uint32_t extend_len = static_cast(p.extend_lens[i]); + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + const uint32_t prefix_len = seq_len - extend_len; + const uint32_t end_pos = prefix_len + extend_len; + + uint32_t pos = prefix_len; + while (pos < end_pos) { + const uint32_t chunk_start = (pos / 128u) * 128u; + const uint32_t seg_end = std::min(end_pos, chunk_start + 128u); // exclusive + const uint32_t seg_len = seg_end - pos; + const uint32_t chunk_off = pos - chunk_start; + const uint32_t last_pos = seg_end - 1; + const uint32_t last_ragged = counter + (last_pos - prefix_len); + RuntimeCheck(last_ragged < (1u << 16), "PlanC.ragged_id is uint16; ragged ", last_ragged, " overflows"); + RuntimeCheck(seg_len <= 128u); + // Stash batch_id in `read_page_0` for stage 1 to translate. A + // chunk-aligned segment never loads, so we still need stage 1 to fill + // a slot in -- the kernel keys the load on `chunk_offset != 0`. 
+ const auto plan = CompressPlan{ + .seq_len = last_pos + 1u, + .ragged_id = static_cast(last_ragged), + .buffer_len = static_cast(seg_len), + .read_page_0 = static_cast(i), // batch_id placeholder + .read_page_1 = -1, // unused, kept so MSB layout is stable + }; + if (chunk_off + seg_len == 128u) { + // close-chunk segment + RuntimeCheck(compress_count < p.num_q_tokens); + p.plan_c[compress_count++] = plan; + } else { + // trailing partial segment + RuntimeCheck(write_count < p.num_q_tokens); + p.plan_w[write_count++] = plan; + } + pos = seg_end; + } + counter += extend_len; + } + RuntimeCheck(counter == p.num_q_tokens, "input size ", counter, " != num_q_tokens ", p.num_q_tokens); + return std::tuple{compress_count, write_count}; +} + +struct OnlinePrefillStage1Params { + CompressPlan* __restrict__ plan_c; + CompressPlan* __restrict__ plan_w; + const int64_t* __restrict__ req_pool_indices; // (batch_size,) + const int32_t* __restrict__ req_to_token; // (num_reqs, max_tokens) + const int64_t* __restrict__ full_to_swa; // (full_cache_size,) + int64_t stride_r2t; + int32_t swa_page_size; + uint32_t num_c; + uint32_t num_w; +}; + +__global__ void plan_c128_online_prefill_kernel(const OnlinePrefillStage1Params params) { + const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t total = params.num_c + params.num_w; + if (idx >= total) return; + + const bool is_compress = idx < params.num_c; + CompressPlan* const plan_ptr = is_compress ? 
¶ms.plan_c[idx] : ¶ms.plan_w[idx - params.num_c]; + auto plan = *plan_ptr; + const auto batch_id = plan.read_page_0; + const auto rid = params.req_pool_indices[batch_id]; + const int32_t position = static_cast(plan.seq_len - 1u); + const int32_t chunk_start = (position / 128) * 128; + const int32_t full_loc = params.req_to_token[rid * params.stride_r2t + chunk_start]; + const int32_t swa_loc = static_cast(params.full_to_swa[full_loc]); + plan.read_page_0 = swa_loc / params.swa_page_size; + *plan_ptr = plan; +} + +using OnlinePrefillPlan = tvm::ffi::Tuple; + +inline OnlinePrefillPlan plan_online_prefill( + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView extend_lens, + const tvm::ffi::TensorView req_pool_indices, + const tvm::ffi::TensorView req_to_token, + const tvm::ffi::TensorView full_to_swa, + const tvm::ffi::TensorView plan_c_pin, + const tvm::ffi::TensorView plan_w_pin, + const tvm::ffi::TensorView plan_c_dev_, + const tvm::ffi::TensorView plan_w_dev_, + const int32_t swa_page_size) { + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto cpu = SymbolicDevice{}; + auto device_ = SymbolicDevice{}; + cpu.set_options(); + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(cpu) + .verify(seq_lens) + .verify(extend_lens); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device_) + .verify(req_to_token); + TensorMatcher({-1}) // + .with_dtype() + .with_device(device_) + .verify(full_to_swa); + TensorMatcher({N, sizeof(CompressPlan)}) // + .with_dtype() + .with_device(cpu) + .verify(plan_c_pin) + .verify(plan_w_pin); + TensorMatcher({N, sizeof(CompressPlan)}) // + .with_dtype() + .with_device(device_) + .verify(plan_c_dev_) + .verify(plan_w_dev_); + + const auto stage0_params = OnlinePrefillStage0Params{ + .plan_c = static_cast(plan_c_pin.data_ptr()), + .plan_w = 
static_cast(plan_w_pin.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extend_lens = static_cast(extend_lens.data_ptr()), + .batch_size = static_cast(B.unwrap()), + .num_q_tokens = static_cast(N.unwrap()), + }; + + // Debug instrumentation: SGLANG_DEBUG_C128_ONLINE_GUARD=1 wraps stage 0 + // with redzone + post-write magic-check on the pin buffers, plus a strict + // upper-bound check on `batch_size` and `num_q_tokens`. If stage 0 has a + // CPU OOB this trips a clear panic at the offending byte instead of a + // delayed CUDA IMA from corrupted heap memory. + static const bool kGuard = []() { + const char* v = std::getenv("SGLANG_DEBUG_C128_ONLINE_GUARD"); + return v != nullptr && v[0] == '1'; + }(); + if (kGuard) { + RuntimeCheck(stage0_params.batch_size <= 65536u, "batch_size out of bound: ", stage0_params.batch_size); + RuntimeCheck(stage0_params.num_q_tokens <= 65536u, "num_q_tokens out of bound: ", stage0_params.num_q_tokens); + // Stamp the pin buffers with 0xAB so we can detect any byte still 0xAB + // beyond what stage 0 should have written (= OOB never reached, that's fine) + // or any byte BEYOND num_q_tokens*16 written to (= true OOB into + // adjacent allocation). + auto* pc = static_cast(plan_c_pin.data_ptr()); + auto* pw = static_cast(plan_w_pin.data_ptr()); + const auto bytes = static_cast(N.unwrap()) * sizeof(CompressPlan); + std::memset(pc, 0xAB, bytes); + std::memset(pw, 0xAB, bytes); + } + + const auto [num_c, num_w] = _plan_prefill_partial(stage0_params); + + if (kGuard) { + // Verify stage 0 wrote ONLY to the [0, num_c*16) and [0, num_w*16) prefix. 
+ auto* pc = static_cast(plan_c_pin.data_ptr()); + auto* pw = static_cast(plan_w_pin.data_ptr()); + const auto end_c = static_cast(num_c) * sizeof(CompressPlan); + const auto end_w = static_cast(num_w) * sizeof(CompressPlan); + const auto pin_bytes = static_cast(N.unwrap()) * sizeof(CompressPlan); + for (size_t k = end_c; k < pin_bytes; ++k) { + RuntimeCheck( + pc[k] == 0xAB, + "GUARD: plan_c_pin OOB write at byte ", + k, + " (num_c=", + num_c, + ", num_q_tokens=", + N.unwrap(), + ")"); + } + for (size_t k = end_w; k < pin_bytes; ++k) { + RuntimeCheck( + pw[k] == 0xAB, + "GUARD: plan_w_pin OOB write at byte ", + k, + " (num_w=", + num_w, + ", num_q_tokens=", + N.unwrap(), + ")"); + } + } + + const auto device = device_.unwrap(); + // Out-params pre-allocated by Python. Cast to typed pointers for use. + auto* const plan_c_dev_ptr = static_cast(plan_c_dev_.data_ptr()); + auto* const plan_w_dev_ptr = static_cast(plan_w_dev_.data_ptr()); + + if (const auto total = num_c + num_w) { + const auto stream = LaunchKernel::resolve_device(device); + // SGLANG_DEBUG_C128_ONLINE_SYNC_H2D=1 forces a synchronous H2D copy. + static const bool kSyncH2D = []() { + const char* v = std::getenv("SGLANG_DEBUG_C128_ONLINE_SYNC_H2D"); + return v != nullptr && v[0] == '1'; + }(); + // SGLANG_DEBUG_C128_ONLINE_NO_H2D=1 skips the H2D copy entirely (debug only). 
+ static const bool kNoH2D = []() { + const char* v = std::getenv("SGLANG_DEBUG_C128_ONLINE_NO_H2D"); + return v != nullptr && v[0] == '1'; + }(); + const auto copy_to_device = [stream](void* dst, void* src, int64_t count) { + if (kNoH2D) return; + const auto bytes = count * sizeof(CompressPlan); + if (kSyncH2D) { + RuntimeDeviceCheck(::cudaMemcpy(dst, src, bytes, ::cudaMemcpyHostToDevice)); + } else { + RuntimeDeviceCheck(::cudaMemcpyAsync(dst, src, bytes, ::cudaMemcpyHostToDevice, stream)); + } + }; + if (num_c) copy_to_device(plan_c_dev_ptr, plan_c_pin.data_ptr(), num_c); + if (num_w) copy_to_device(plan_w_dev_ptr, plan_w_pin.data_ptr(), num_w); + + const auto stage1_params = OnlinePrefillStage1Params{ + .plan_c = plan_c_dev_ptr, + .plan_w = plan_w_dev_ptr, + .req_pool_indices = static_cast(req_pool_indices.data_ptr()), + .req_to_token = static_cast(req_to_token.data_ptr()), + .full_to_swa = static_cast(full_to_swa.data_ptr()), + .stride_r2t = req_to_token.stride(0), + .swa_page_size = swa_page_size, + .num_c = num_c, + .num_w = num_w, + }; + constexpr uint32_t kBlockSize = 128; + const auto num_blocks = host::div_ceil(total, kBlockSize); + LaunchKernel(num_blocks, kBlockSize, device)(plan_c128_online_prefill_kernel, stage1_params); + } + return OnlinePrefillPlan{num_c, num_w}; +} + +} // namespace host::compress + +namespace { + +[[maybe_unused]] +constexpr auto& plan_compress_128_online_decode = host::compress::plan_online_decode; +[[maybe_unused]] +constexpr auto& plan_compress_128_online_prefill = host::compress::plan_online_prefill; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_v2.cuh new file mode 100644 index 0000000000..31353e6a15 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_v2.cuh @@ -0,0 +1,448 @@ +/** + * \brief Here's some dimension info for the main buffer used in C128 prefill and decode. 
+ * + * kv_buffer: [num_indices, 128, head_dim * 2] + * - last dimension layout: | kv | score | + * kv_input: [batch_size, head_dim * 2] + * kv_output: [batch_size, head_dim] + * score_bias (ape): [128, head_dim] + * plan_c/plan_w: [variable length] + * + * For prefill, batch_size = num_q_tokens + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using PlanD = device::compress::DecodePlan; +using PlanC = device::compress::CompressPlan; +using PlanW = device::compress::WritePlan; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int32_t kTileElements = 2; +/// \brief Each warp will handle this many elements (split along 128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kBlockSize = device::kWarpThreads * kNumWarps; +constexpr uint32_t kWriteBlockSize = 128; // one warp per write + +/// \brief Need to reduce register usage to increase occupancy +#define C128_KERNEL __global__ __launch_bounds__(kBlockSize, 2) +#define WRITE_KERNEL __global__ __launch_bounds__(kWriteBlockSize, 16) + +struct Compress128DecodeParams { + void* __restrict__ kv_buffer; + const void* __restrict__ kv_input; + void* __restrict__ kv_output; + const void* __restrict__ score_bias; + const PlanD* __restrict__ plan_d; + uint32_t batch_size; +}; + +struct Compress128PrefillParams { + void* __restrict__ kv_buffer; + const void* __restrict__ kv_input; + void* __restrict__ kv_output; + const void* __restrict__ score_bias; + const PlanC* __restrict__ plan_c; + const PlanW* __restrict__ plan_w; + uint32_t num_compress; + uint32_t num_write; +}; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return 
data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +struct C128Trait { + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr int64_t kHeadDim = kHeadDim_; + static constexpr int64_t kScoreOffset = kHeadDim; + static constexpr int64_t kElementSize = kHeadDim * 2; + static constexpr int64_t kPageElementSize = 128 * kElementSize; // page size = 128 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static_assert(kHeadDim % kTileDim == 0); +}; + +template +SGL_DEVICE void c128_forward( + const InFloat* kv_buf, // [128n, 128n + 127] + const InFloat* kv_src, // ragged pointer at position = 128n + 127 + OutFloat* kv_out, + const InFloat* score_bias, + const int32_t buffer_len) { + using namespace device; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory{lane_id, kWarpThreads}; + StorageIn kv[kElementsPerWarp]; + StorageIn score[kElementsPerWarp]; + StorageIn bias[kElementsPerWarp]; + const int32_t warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + const int32_t j = i + warp_offset; + bias[i] = gmem_in.load(score_bias + j * Trait::kHeadDim); + } + + const auto kv_start = kv_src - 127 * Trait::kElementSize; // point to start + +#pragma unroll + for (int32_t i = 0; i < kElementsPerWarp; ++i) { + const int32_t j = i + warp_offset; + __builtin_assume(j < 128); + const auto src = j < buffer_len ? 
kv_buf : kv_start; + kv[i] = gmem_in.load(src + j * Trait::kElementSize); + score[i] = gmem_in.load(src + j * Trait::kElementSize + Trait::kScoreOffset); + } + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + + float score_fp32[kTileElements][kElementsPerWarp]; + + // convert to fp32 and apply bias first +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[i][j] = cast(score[j][i]) + cast(bias[j][i]); + } + } + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + const auto& score = score_fp32[i]; + float max_value = score[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t 
kIteration = kReductionCount / kBlockSize; + + PDLTriggerSecondary(); + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. + static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = cast(global_product); + } +} + +template +SGL_DEVICE void c128_write_decode(InFloat* kv_buf, const InFloat* kv_src) { + using namespace device; + + using Storage = AlignedVector; + const auto gmem = tile::Memory::warp(); + + Storage data[2]; +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + data[i] = gmem.load(kv_src + Trait::kHeadDim * i); + } +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + 
gmem.store(kv_buf + Trait::kHeadDim * i, data[i]); + } +} + +template +C128_KERNEL void flash_c128_decode(const __grid_constant__ Compress128DecodeParams params) { + using namespace device; + using Trait = C128Trait; + + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t global_bid = blockIdx.x / Trait::kNumSplit; // batch id + const uint32_t global_sid = blockIdx.x % Trait::kNumSplit; // split id + const int64_t split_offset = global_sid * Trait::kTileDim; + if (global_bid >= params.batch_size) return; + + const auto plan = params.plan_d[global_bid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_output = static_cast(params.kv_output) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + const auto score_bias = static_cast(params.score_bias) + split_offset; + + const auto kv_src = kv_input + global_bid * Trait::kElementSize; + const auto kv_out = kv_output + global_bid * Trait::kHeadDim; + const auto kv_buf = kv_buffer + plan.read_page_1 * Trait::kPageElementSize; + const auto kv_dst = kv_buffer + plan.write_loc * Trait::kElementSize; + + PDLWaitPrimary(); + // the write warp must match the load warp in the following `c128_forward` + if (warp_id == kNumWarps - 1) { + c128_write_decode(kv_dst, kv_src); + } + if (plan.write_loc % 128 == 127) { + c128_forward(kv_buf, kv_src, kv_out, score_bias, 128); + } +} + +// compress kernel +template +C128_KERNEL void flash_c128_prefill(const __grid_constant__ Compress128PrefillParams params) { + using namespace device; + using Trait = C128Trait; + + const uint32_t global_pid = blockIdx.x / Trait::kNumSplit; // plan id + const uint32_t global_sid = blockIdx.x % Trait::kNumSplit; // split id + const int64_t split_offset = global_sid * Trait::kTileDim; + if (global_pid >= params.num_compress) return; + + const auto plan = params.plan_c[global_pid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_output = 
static_cast(params.kv_output) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + const auto score_bias = static_cast(params.score_bias) + split_offset; + if (plan.is_invalid()) return; + + const auto kv_src = kv_input + plan.ragged_id * Trait::kElementSize; + // Compact output: one row per compress plan, indexed by `global_pid`. + const auto kv_out = kv_output + global_pid * Trait::kHeadDim; + const auto kv_buf = kv_buffer + plan.read_page_1 * Trait::kPageElementSize; + PDLWaitPrimary(); + c128_forward(kv_buf, kv_src, kv_out, score_bias, plan.buffer_len); +} + +template +WRITE_KERNEL void write_c128_prefill(const __grid_constant__ Compress128PrefillParams params) { + using namespace device; + using Trait = C128Trait; + using StorageIn = AlignedVector; + + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_pid = global_wid / Trait::kNumSplit; // plan id + const uint32_t global_sid = global_wid % Trait::kNumSplit; // split id + // split the contiguous `kHeadDim * 2` into `kNumSplit` tiles + // each warp handles 1 contiguous tile (in contrast, decode handle the strided head_dim) + const int64_t split_offset = global_sid * (Trait::kTileDim * 2); + if (global_pid >= params.num_write) return; + + const auto plan = params.plan_w[global_pid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + if (plan.is_invalid()) return; + + // each warp will handle a contiguous region + const auto kv_src = kv_input + plan.ragged_id * Trait::kElementSize; + const auto kv_buf = kv_buffer + plan.write_loc * Trait::kElementSize; + const auto gmem = tile::Memory::warp(); + + PDLWaitPrimary(); + StorageIn data[2]; +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + data[i] = gmem.load(kv_src, i); + } + PDLTriggerSecondary(); +#pragma unroll + for (int32_t i 
= 0; i < 2; ++i) { + gmem.store(kv_buf, data[i], i); + } +} + +template +struct FlashCompress128Kernel { + static constexpr auto decode_kernel = flash_c128_decode; + static constexpr auto prefill_c_kernel = flash_c128_prefill; + static constexpr auto prefill_w_kernel = write_c128_prefill; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + using Trait = C128Trait; + + static void run_decode( + const tvm::ffi::TensorView kv_buffer, + const tvm::ffi::TensorView kv_input, + const tvm::ffi::TensorView kv_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView plan_d_) { + using namespace host; + + auto N = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 128, Trait::kElementSize}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_buffer); + TensorMatcher({N, Trait::kElementSize}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + + const auto plan_d = compress::verify_plan_d(plan_d_, N, device_); + const auto batch_size = static_cast(N.unwrap()); + const auto params = Compress128DecodeParams{ + .kv_buffer = kv_buffer.data_ptr(), + .kv_input = kv_input.data_ptr(), + .kv_output = kv_output.data_ptr(), + .score_bias = ape.data_ptr(), + .plan_d = plan_d, + .batch_size = batch_size, + }; + const uint32_t num_blocks = batch_size * kNumSplit; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_buffer, + const tvm::ffi::TensorView kv_input, + const tvm::ffi::TensorView kv_output, + const tvm::ffi::TensorView ape, + const 
tvm::ffi::TensorView plan_c_, + const tvm::ffi::TensorView plan_w_) { + using namespace host; + + auto N = SymbolicSize{"num_q_tokens"}; + auto C = SymbolicSize{"num_c_plans"}; + auto W = SymbolicSize{"num_w_plans"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 128, Trait::kElementSize}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_buffer); + TensorMatcher({N, Trait::kElementSize}) // kv score input (ragged) + .with_dtype() + .with_device(device_) + .verify(kv_input); + TensorMatcher({C, kHeadDim}) // kv compressed output (compact) + .with_dtype() + .with_device(device_) + .verify(kv_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + + const auto plan_c = compress::verify_plan_c(plan_c_, C, device_); + const auto plan_w = compress::verify_plan_w(plan_w_, W, device_); + const auto device = device_.unwrap(); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(C.unwrap()); + const auto num_w = static_cast(W.unwrap()); + const auto params = Compress128PrefillParams{ + .kv_buffer = kv_buffer.data_ptr(), + .kv_input = kv_input.data_ptr(), + .kv_output = kv_output.data_ptr(), + .score_bias = ape.data_ptr(), + .plan_c = plan_c, + .plan_w = plan_w, + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= num_w, "invalid prefill plan: num_q < num_w"); + if (const auto num_c_blocks = num_c * kNumSplit) { + constexpr auto kBlockSize_C = kBlockSize; + LaunchKernel(num_c_blocks, kBlockSize_C, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + constexpr uint32_t kWarpsPerWriteBlock = kWriteBlockSize / device::kWarpThreads; + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerWriteBlock)) { + constexpr auto kBlockSize_W = kWriteBlockSize; + LaunchKernel(num_w_blocks, kBlockSize_W, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace diff 
--git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4.cuh new file mode 100644 index 0000000000..145ab1fb08 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4.cuh @@ -0,0 +1,549 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using Plan4 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int kTileElements = 4; + +/// \brief Need to improve register usage to reduce latency +#define C4_KERNEL __global__ __launch_bounds__(128, 4) + +enum class PageMode { + RingBuffer = 8, + Page4Align = 4, +}; + +struct alignas(16) C4IndexBundle { + int32_t load_first_page; + int32_t load_second_page; + int32_t write_first_page; + int32_t last_position; +}; + +struct Compress4DecodeParams { + /** + * \brief Shape: `[num_indices, 8, head_dim * 4]` \n + * last dimension layout: + * | kv overlap | kv | score overlap | score | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 4]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[8, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \brief Shape: `[batch_size, 1]` */ + const int32_t* __restrict__ extra; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +struct Compress4PrefillParams { + /** + * \brief Shape: `[num_indices, 8, head_dim * 4]` \n + * last dimension layout: + * | kv overlap | kv | score overlap | score | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[num_q_tokens, head_dim * 4]` */ + 
const void* __restrict__ kv_score_input; + /** \brief Shape: `[num_q_tokens, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[8, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, 4]` */ + const C4IndexBundle* __restrict__ extra; + /** \brief The following part is plan info. */ + + const Plan4* __restrict__ compress_plan; + const Plan4* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +template +SGL_DEVICE void c4_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 4; + const auto gmem = tile::Memory::warp(); + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv overlap | [1] = kv | [2] = score overlap | [3] = score | + Storage kv_score[4]; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c4_forward( + const InFloat* kv_score_buf, + const InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t seq_len, + const int32_t window_len, + [[maybe_unused]] const InFloat* kv_score_overlap_buf = nullptr) { + using namespace device; + + const auto element_size = head_dim * 4; + const auto score_offset = head_dim * 2; + const auto overlap_stride = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory::warp(); + StorageIn kv[8]; + StorageIn score[8]; + StorageIn bias[8]; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + bias[i] = gmem_in.load(score_bias + i * head_dim); + } + +#pragma 
unroll + for (int32_t i = 0; i < 8; ++i) { + const bool is_overlap = i < 4; + const InFloat* src; + if (i < window_len) { + /// NOTE: `seq_len` must be a multiple of 4 here + if constexpr (kPaged) { + const auto kv_score_ptr = is_overlap ? kv_score_overlap_buf : kv_score_buf; + const int32_t k = i % 4; + src = kv_score_ptr + k * element_size; + } else { + const int32_t k = (seq_len + i) % 8; + src = kv_score_buf + k * element_size; + } + } else { + /// NOTE: k in [-7, 0]. We'll load from the ragged `kv_score_src` + const int32_t k = i - 7; + src = kv_score_src + k * element_size; + } + src += (is_overlap ? 0 : overlap_stride); + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + if (seq_len == 4) { + [[unlikely]]; + constexpr float kFloatNegInf = -1e9f; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv[i].fill(cast(0.0f)); + score[i].fill(cast(kFloatNegInf)); + } + } + + /// NOTE: part 2: safe online softmax + weighted sum + using StorageOut = AlignedVector; + const auto gmem_out = tile::Memory::warp(); + StorageOut result; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[8]; + +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + result[i] = cast(sum_product / sum_exp_value); + } + + gmem_out.store(kv_out, result); +} + +template +C4_KERNEL void flash_c4_decode(const __grid_constant__ Compress4DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim 
= kTileElements * kWarpThreads; // 128 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 4; // `* 4` due to overlap transform + score + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, extra, batch_size // decode info + ] = params; + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_bid = global_wid / kNumSplit; // batch id + const uint32_t global_sid = global_wid % kNumSplit; // split id + + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: `position` = `seq_len - 1`. 
To avoid underflow, we use `seq_len + page_size - 1` + if constexpr (kMode == PageMode::Page4Align) { + const auto index_prev = extra[global_bid]; + const auto kv_buf = kv_score_buffer + index * (kElementSize * 4) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 3) % 4); + if (seq_len % 4 == 0) { + const auto kv_overlap = kv_buf + (index_prev - index) * (kElementSize * 4); /* same split offset, rebased from page `index` onto page `index_prev` */ + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, /*window_len=*/8, kv_overlap); + } + } else { + static_assert(kMode == PageMode::RingBuffer, "Unsupported PageMode"); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 8) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 7) % 8); + if (seq_len % 4 == 0) { + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, /*window_len=*/8); + } + } + + PDLTriggerSecondary(); +} + +template +C4_KERNEL void flash_c4_prefill(const __grid_constant__ Compress4PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 128 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 4; // `* 4` due to overlap transform + score + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, extra, compress_plan, write_plan, num_compress, num_write // prefill plan + ] = params; + + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_pid = global_wid / kNumSplit; // plan id + const uint32_t global_sid = global_wid % kNumSplit; // split id + + /// NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? 
write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + const auto seq_len = position + 1; + const int32_t index = indices[global_bid]; + + PDLWaitPrimary(); + + if constexpr (kMode == PageMode::Page4Align) { + const auto write_second_page = index; + const auto [load_first_page, load_second_page, write_first_page, last_pos] = extra[global_bid]; + if constexpr (kWrite) { + int32_t index; + if (position < static_cast(last_pos)) { + index = write_first_page; + } else { + index = write_second_page; + } + const auto kv_buf = kv_score_buffer + index * (kElementSize * 4) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 4); + } else { + int32_t index_overlap, index_normal; + if (window_len <= 4) { + index_overlap = load_second_page; + index_normal = load_second_page; // not used + } else { + index_overlap = load_first_page; + index_normal = load_second_page; + } + const auto kv_buf = kv_score_buffer + index_normal * (kElementSize * 4) + split_offset; + const auto kv_overlap = kv_score_buffer + index_overlap * (kElementSize * 4) + split_offset; + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, window_len, kv_overlap); + } + } else { + static_assert(kMode == PageMode::RingBuffer, "Unsupported PageMode"); + const auto kv_buf = 
kv_score_buffer + index * (kElementSize * 8) + split_offset; + if constexpr (kWrite) { + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 8); + } else { + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, window_len); + } + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress4Kernel { + template + static constexpr auto decode_kernel = flash_c4_decode; + template + static constexpr auto prefill_kernel = flash_c4_prefill; + template + static constexpr auto prefill_c_kernel = prefill_kernel; + template + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr uint32_t kBlockSize = 128; + static constexpr uint32_t kTileDim = kTileElements * device::kWarpThreads; + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWarpsPerBlock = kBlockSize / device::kWarpThreads; + + using Self = FlashCompress4Kernel; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional extra) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + const auto extra_ptr = _get_extra_pointer(B, device_, extra); + const auto page_size = extra_ptr != nullptr ? 
4 : 8; + + TensorMatcher({-1, page_size, kHeadDim * 4}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 4}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device_) + .verify(seq_lens); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress4DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extra = static_cast(extra_ptr), + .batch_size = batch_size, + }; + const auto kernel = extra_ptr != nullptr ? 
decode_kernel // + : decode_kernel; + const uint32_t num_blocks = div_ceil(batch_size * kNumSplit, kWarpsPerBlock); + LaunchKernel(num_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + const auto extra_ptr = _get_extra_pointer(B, device_, extra, /*is_prefill=*/true); + const auto page_size = extra_ptr != nullptr ? 4 : 8; + + TensorMatcher({-1, page_size, kHeadDim * 4}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 4}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = 
static_cast(Y.unwrap()); + const auto params = Compress4PrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .extra = static_cast(extra_ptr), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + if (const auto num_c_blocks = div_ceil(num_c * kNumSplit, kWarpsPerBlock)) { + const auto c_kernel = extra_ptr != nullptr ? prefill_c_kernel // + : prefill_c_kernel; + LaunchKernel(num_c_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerBlock)) { + const auto w_kernel = extra_ptr != nullptr ? prefill_w_kernel // + : prefill_w_kernel; + LaunchKernel(num_w_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(w_kernel, params); + } + } + + // some auxiliary functions + private: + static const void* _get_extra_pointer( + host::SymbolicSize& B, // batch_size + host::SymbolicDevice& device, + const tvm::ffi::Optional& extra, + bool is_prefill = false) { + // `extra` only has a value in page-aligned mode; nullptr selects the ring-buffer kernels + if (!extra.has_value()) return nullptr; + const auto& extra_tensor = extra.value(); + /// NOTE: the metadata layout is different for prefill and decode: + /// for prefill, the last dim holds the 4 `C4IndexBundle` fields, in order: + /// load_first_page | load_second_page | write_first_page | last_position + /// for decode, the last dim holds 1 value: the previous page index, read to load the overlap window + host::TensorMatcher({B, is_prefill ? 
4 : 1}) // extra tensor + .with_dtype() + .with_device(device) + .verify(extra_tensor); + const auto data_ptr = extra_tensor.data_ptr(); + host::RuntimeCheck(data_ptr != nullptr, "extra tensor data ptr is null"); + if (is_prefill) { + static_assert(alignof(C4IndexBundle) == 16); + host::RuntimeCheck(std::bit_cast(data_ptr) % 16 == 0, "extra tensor is not properly aligned"); + } + return data_ptr; + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4_v2.cuh new file mode 100644 index 0000000000..efa9f05100 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4_v2.cuh @@ -0,0 +1,405 @@ +/** + * \brief Dimension info for the main buffers used in C4 prefill and decode. + * + * kv_buffer: [num_indices, 4, head_dim * 4] (page size = 4, as verified by `run_decode` / `run_prefill`) + * - last dimension layout: | kv overlap | kv | score overlap | score | + * kv_input: [batch_size, head_dim * 4] + * kv_output: [batch_size, head_dim] for decode; compact [num_c_plans, head_dim] for prefill + * score_bias (ape): [8, head_dim] + * plan_c/plan_w: [variable length] + * + * For prefill, batch_size = num_q_tokens (applies to the ragged kv_input) + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include +#include + +namespace { + +using PlanD = device::compress::DecodePlan; +using PlanC = device::compress::CompressPlan; +using PlanW = device::compress::WritePlan; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int32_t kTileElements = 4; + +/// \brief Need to improve register usage to reduce latency +#define C4_KERNEL __global__ __launch_bounds__(128, 4) +#define WRITE_KERNEL __global__ __launch_bounds__(128, 16) + +struct Compress4DecodeParams { + void* __restrict__ kv_buffer; + const void* __restrict__ kv_input; + void* __restrict__ kv_output; + const void* __restrict__ score_bias; + const PlanD* __restrict__ plan_d; + uint32_t batch_size; +}; + +struct Compress4PrefillParams { + void* 
__restrict__ kv_buffer; + const void* __restrict__ kv_input; + void* __restrict__ kv_output; + const void* __restrict__ score_bias; + const PlanC* __restrict__ plan_c; + const PlanW* __restrict__ plan_w; + uint32_t num_compress; + uint32_t num_write; +}; + +template +struct C4Trait { + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 128 + static constexpr int64_t kHeadDim = kHeadDim_; + static constexpr int64_t kOverlapOffset = kHeadDim; // added to reach the non-overlap half (`kv` / `score`) + static constexpr int64_t kScoreOffset = kHeadDim * 2; // distance from a kv section to its score section + static constexpr int64_t kElementSize = kHeadDim * 4; // one token: 4 sections of kHeadDim each + static constexpr int64_t kPageElementSize = 4 * kElementSize; // page size = 4 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static_assert(kHeadDim % kTileDim == 0); +}; + +template +SGL_DEVICE void c4_forward( + const InFloat* kv_buf_0, // overlap [4n - 4, 4n - 1] + const InFloat* kv_buf_1, // normal [4n + 0, 4n + 3] + const InFloat* kv_src, // ragged pointer at position = 4n + 3 + OutFloat* kv_out, + const InFloat* score_bias, + const bool should_overlap, + const int32_t buffer_len) { + using namespace device; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + /// NOTE: load one tile_dim (<= head_dim) at a time + const auto gmem_in = tile::Memory::warp(); + StorageIn kv[8]; + StorageIn score[8]; + StorageIn bias[8]; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + bias[i] = gmem_in.load(score_bias + i * Trait::kHeadDim); + } + + if (should_overlap) { + const auto kv_start = kv_src - 7 * Trait::kElementSize; // point to start +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + const auto src = i < buffer_len ? 
kv_buf_0 : kv_start; + const auto base = src + i * Trait::kElementSize; + kv[i] = gmem_in.load(base); + score[i] = gmem_in.load(base + Trait::kScoreOffset); + } + } else { + [[unlikely]]; + constexpr float kFloatNegInf = -FLT_MAX; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv[i].fill(cast(0.0f)); + score[i].fill(cast(kFloatNegInf)); + } + } + + const auto kv_start = kv_src - 3 * Trait::kElementSize; // point to start +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + const auto src = i + 4 < buffer_len ? kv_buf_1 : kv_start; + const auto base = src + i * Trait::kElementSize + Trait::kOverlapOffset; + kv[i + 4] = gmem_in.load(base); + score[i + 4] = gmem_in.load(base + Trait::kScoreOffset); + } + + /// NOTE: part 2: safe online softmax + weighted sum + using StorageOut = AlignedVector; + const auto gmem_out = tile::Memory::warp(); + StorageOut result; + + // consume 32 fp registers + float score_fp32[kTileElements][8]; + + // convert to fp32 and apply bias first +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + for (int32_t j = 0; j < 8; ++j) { + score_fp32[i][j] = cast(score[j][i]) + cast(bias[j][i]); + } + } + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + const auto& score = score_fp32[i]; + float max_value = score[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < 8; ++j) { + const auto fp32_score = score[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + result[i] = cast(sum_product / sum_exp_value); + } + + // overlap the store with the next iteration's load + PDLTriggerSecondary(); + gmem_out.store(kv_out, result); +} + +template +SGL_DEVICE void c4_write_decode(InFloat* kv_buf, const InFloat* kv_src) { + using namespace device; + + 
using StorageIn = AlignedVector; + const auto gmem = tile::Memory::warp(); + + StorageIn data[4]; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + data[i] = gmem.load(kv_src + Trait::kHeadDim * i); + } +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + gmem.store(kv_buf + Trait::kHeadDim * i, data[i]); + } +} + +template +C4_KERNEL void flash_c4_decode(const __grid_constant__ Compress4DecodeParams params) { + using namespace device; + using Trait = C4Trait; + + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_bid = global_wid / Trait::kNumSplit; // batch id + const uint32_t global_sid = global_wid % Trait::kNumSplit; // split id + const int64_t split_offset = global_sid * Trait::kTileDim; + if (global_bid >= params.batch_size) return; + + const auto plan = params.plan_d[global_bid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_output = static_cast(params.kv_output) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + const auto score_bias = static_cast(params.score_bias) + split_offset; + + const auto kv_src = kv_input + global_bid * Trait::kElementSize; + const auto kv_out = kv_output + global_bid * Trait::kHeadDim; + const auto kv_buf_0 = kv_buffer + plan.read_page_0 * Trait::kPageElementSize; + const auto kv_buf_1 = kv_buffer + plan.read_page_1 * Trait::kPageElementSize; + const auto kv_dst = kv_buffer + plan.write_loc * Trait::kElementSize; + + PDLWaitPrimary(); + c4_write_decode(kv_dst, kv_src); + if (plan.seq_len % 4 == 0) { + const auto need_overlap = plan.seq_len > 4; + c4_forward(kv_buf_0, kv_buf_1, kv_src, kv_out, score_bias, need_overlap, 8); + } +} + +template +C4_KERNEL void flash_c4_prefill(const __grid_constant__ Compress4PrefillParams params) { + using namespace device; + using Trait = C4Trait; + + const uint32_t global_tid = blockIdx.x * blockDim.x + 
threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_pid = global_wid / Trait::kNumSplit; // plan id + const uint32_t global_sid = global_wid % Trait::kNumSplit; // split id + const int64_t split_offset = global_sid * Trait::kTileDim; + if (global_pid >= params.num_compress) return; + + const auto plan = params.plan_c[global_pid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_output = static_cast(params.kv_output) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + const auto score_bias = static_cast(params.score_bias) + split_offset; + if (plan.is_invalid()) return; + + const auto kv_src = kv_input + plan.ragged_id * Trait::kElementSize; + // Compact output: one row per compress plan, indexed by `global_pid`. + const auto kv_out = kv_output + global_pid * Trait::kHeadDim; + const auto kv_buf_0 = kv_buffer + plan.read_page_0 * Trait::kPageElementSize; + const auto kv_buf_1 = kv_buffer + plan.read_page_1 * Trait::kPageElementSize; + const bool need_overlap = plan.seq_len > 4; // when false, c4_forward masks slots 0..3 (zero kv, -FLT_MAX score) + PDLWaitPrimary(); + c4_forward(kv_buf_0, kv_buf_1, kv_src, kv_out, score_bias, need_overlap, plan.buffer_len); +} + +template +WRITE_KERNEL void write_c4_prefill(const __grid_constant__ Compress4PrefillParams params) { + using namespace device; + using Trait = C4Trait; + using StorageIn = AlignedVector; + + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_pid = global_wid / Trait::kNumSplit; // plan id + const uint32_t global_sid = global_wid % Trait::kNumSplit; // split id + // split the contiguous `kHeadDim * 4` into `kNumSplit` tiles + // each warp handles 1 contiguous tile (in contrast, decode handles 4 strided `kTileDim` slices, one per `kHeadDim` section) + const int64_t split_offset = global_sid * (Trait::kTileDim * 4); + if (global_pid >= params.num_write) return; + + const auto plan = 
params.plan_w[global_pid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + if (plan.is_invalid()) return; + + // each warp will handle a contiguous region + const auto kv_src = kv_input + plan.ragged_id * Trait::kElementSize; + const auto kv_buf = kv_buffer + plan.write_loc * Trait::kElementSize; + const auto gmem = tile::Memory::warp(); + + PDLWaitPrimary(); + StorageIn data[4]; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + data[i] = gmem.load(kv_src, i); + } + PDLTriggerSecondary(); +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + gmem.store(kv_buf, data[i], i); + } +} + +template +struct FlashCompress4Kernel { + static constexpr auto decode_kernel = flash_c4_decode; + static constexpr auto prefill_c_kernel = flash_c4_prefill; + static constexpr auto prefill_w_kernel = write_c4_prefill; + static constexpr uint32_t kBlockSize = 128; + static constexpr uint32_t kTileDim = kTileElements * device::kWarpThreads; + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWarpsPerBlock = kBlockSize / device::kWarpThreads; + using Trait = C4Trait; + + static void run_decode( + const tvm::ffi::TensorView kv_buffer, + const tvm::ffi::TensorView kv_input, + const tvm::ffi::TensorView kv_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView plan_d_) { + using namespace host; + + auto N = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 4, Trait::kElementSize}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_buffer); + TensorMatcher({N, Trait::kElementSize}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + + 
const auto plan_d = compress::verify_plan_d(plan_d_, N, device_); + const auto batch_size = static_cast(N.unwrap()); + const auto params = Compress4DecodeParams{ + .kv_buffer = kv_buffer.data_ptr(), + .kv_input = kv_input.data_ptr(), + .kv_output = kv_output.data_ptr(), + .score_bias = ape.data_ptr(), + .plan_d = plan_d, + .batch_size = batch_size, + }; + const uint32_t num_blocks = div_ceil(batch_size * kNumSplit, kWarpsPerBlock); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_buffer, + const tvm::ffi::TensorView kv_input, + const tvm::ffi::TensorView kv_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView plan_c_, + const tvm::ffi::TensorView plan_w_) { + using namespace host; + + auto N = SymbolicSize{"num_q_tokens"}; + auto C = SymbolicSize{"num_c_plans"}; + auto W = SymbolicSize{"num_w_plans"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 4, Trait::kElementSize}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_buffer); + TensorMatcher({N, Trait::kElementSize}) // kv score input (ragged) + .with_dtype() + .with_device(device_) + .verify(kv_input); + TensorMatcher({C, kHeadDim}) // kv compressed output (compact) + .with_dtype() + .with_device(device_) + .verify(kv_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + const auto plan_c = compress::verify_plan_c(plan_c_, C, device_); + const auto plan_w = compress::verify_plan_w(plan_w_, W, device_); + const auto device = device_.unwrap(); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(C.unwrap()); + const auto num_w = static_cast(W.unwrap()); + const auto params = Compress4PrefillParams{ + .kv_buffer = kv_buffer.data_ptr(), + .kv_input = kv_input.data_ptr(), + .kv_output = kv_output.data_ptr(), + .score_bias = ape.data_ptr(), + 
.plan_c = plan_c, + .plan_w = plan_w, + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= num_w, "invalid prefill plan: num_q < num_w"); + if (const auto num_c_blocks = div_ceil(num_c * kNumSplit, kWarpsPerBlock)) { + LaunchKernel(num_c_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerBlock)) { + LaunchKernel(num_w_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c_plan.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c_plan.cuh new file mode 100644 index 0000000000..3e4aaaf5f0 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c_plan.cuh @@ -0,0 +1,839 @@ +#include +#include +#include + +#include +#include + +#include + +#include +#include + +#include +#include + +namespace host::compress { + +constexpr auto kDLUInt8 = DLDataType{.code = kDLUInt, .bits = 8, .lanes = 1}; + +using PlanC = CompressPlan; +using PlanW = WritePlan; +using PlanD = DecodePlan; + +using RID_T = int64_t; +using R2T_T = int32_t; +using F2S_T = int64_t; +using IDX_T = int64_t; + +/// NOTE: for the internal use, we pack the ragged and batch id, since both not exceed 65536 +SGL_DEVICE __host__ PlanW pack_w(uint32_t ragged_id, uint32_t batch_id, int32_t seq_len) { + return {static_cast(ragged_id | batch_id << 16), seq_len}; +} + +/// NOTE: for the internal use, we pack the ragged and batch id, since both not exceed 65536 +SGL_DEVICE uint2 unpack_w(PlanW plan) { + return {static_cast(plan.ragged_id), static_cast(plan.ragged_id >> 16)}; +} + +struct Prefill0Params { + PlanC* plan_c; + PlanW* plan_w; + const IDX_T* seq_lens_ptr; // [batch_size] + const IDX_T* extend_lens_ptr; // [batch_size] + uint32_t batch_size; + uint32_t num_q_tokens; + int32_t compress_ratio; + int32_t swa_page_size; + int32_t mtp_pad; +}; + 
+struct Prefill1Params { + PlanC* plan_c; + PlanW* plan_w; + const RID_T* rid_ptr; // [batch_size] + const R2T_T* r2t_ptr; // [num_reqs, stride_r2t] + const F2S_T* f2s_ptr; // [num_swa_slots] + int64_t stride_r2t; + uint32_t num_c; + uint32_t num_w; + uint32_t num_c_padded; + uint32_t num_w_padded; + uint32_t num_work; + int32_t swa_page_size; + int32_t ring_size; + int32_t compress_ratio; +}; + +struct DecodeParams { + PlanD* plan_d; + const RID_T* rid_ptr; // [batch_size] + const R2T_T* r2t_ptr; // [num_reqs, stride_r2t] + const F2S_T* f2s_ptr; // [num_swa_slots] + const IDX_T* seq_ptr; // [batch_size] + int64_t stride_r2t; + uint32_t batch_size; + int32_t swa_page_size; + int32_t ring_size; + int32_t compress_ratio; +}; + +struct Prefill1ParamsLegacy { + PlanC* plan_c; + PlanW* plan_w; + const RID_T* rid_ptr; // [batch_size] + uint32_t num_c; + uint32_t num_w; + uint32_t num_c_padded; + uint32_t num_w_padded; + uint32_t num_work; + int32_t compress_ratio; +}; + +struct DecodeParamsLegacy { + PlanD* plan_d; + const RID_T* rid_ptr; // [batch_size] + const IDX_T* seq_ptr; // [batch_size] + uint32_t batch_size; + int32_t compress_ratio; +}; + +inline constexpr uint32_t kMaxPrefillBatchSize = 1024; + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(device::kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { +#ifndef USE_ROCM + uint32_t n = __shfl_up_sync(device::kFullMask, val, offset); +#else + uint32_t n = __shfl_up(val, offset, 32); +#endif + if (lane_id >= offset) val += n; + } + return val; +} + +/// Warp-wide max/min for integer types. `device::warp::reduce_max` routes through +/// `dtype_trait::max` which is only specialized for FP types. 
+SGL_DEVICE uint32_t warp_reduce_max_u32(uint32_t val) { +#pragma unroll + for (uint32_t mask = 16; mask > 0; mask >>= 1) { +#ifndef USE_ROCM + val = max(val, __shfl_xor_sync(device::kFullMask, val, mask, 32)); +#else + val = max(val, __shfl_xor(val, mask, 32)); +#endif + } + return val; +} + +SGL_DEVICE uint32_t warp_reduce_min_u32(uint32_t val) { +#pragma unroll + for (uint32_t mask = 16; mask > 0; mask >>= 1) { +#ifndef USE_ROCM + val = min(val, __shfl_xor_sync(device::kFullMask, val, mask, 32)); +#else + val = min(val, __shfl_xor(val, mask, 32)); +#endif + } + return val; +} + +__global__ __launch_bounds__(1024, 1) // + void plan_compress_prefill_kernel0(const Prefill0Params params) { + using namespace device; + const auto tx = threadIdx.x; + const auto block_size = kMaxPrefillBatchSize; + constexpr auto kNumWarps = kMaxPrefillBatchSize / kWarpThreads; + const auto cr = params.compress_ratio; + const auto sps = params.swa_page_size; + const bool is_overlap = (cr == 4); + const int32_t window_size = cr * (is_overlap ? 
2 : 1); + + alignas(128) __shared__ uint32_t counter_c; + alignas(128) __shared__ uint32_t counter_w; + __shared__ int32_t s_seq_len[kMaxPrefillBatchSize]; + __shared__ int32_t s_prefix_len[kMaxPrefillBatchSize]; + __shared__ uint32_t warp_max[kNumWarps]; + __shared__ uint32_t warp_min[kNumWarps]; + __shared__ uint32_t s_max_extend; + __shared__ uint32_t s_min_extend; + + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // === Stage A: load per-batch fields, init shared scratch === + int32_t seq_len = 0, extend_len = 0, prefix_len = 0; + if (tx < params.batch_size) { + seq_len = static_cast(params.seq_lens_ptr[tx]); + extend_len = static_cast(params.extend_lens_ptr[tx]); + prefix_len = seq_len - extend_len; + s_seq_len[tx] = seq_len; + s_prefix_len[tx] = prefix_len; + } + if (tx == 0) { + counter_c = 0; + counter_w = 0; + } + if (tx < kNumWarps) { + warp_max[tx] = 0; + warp_min[tx] = 0xFFFFFFFFu; + } + + // === Stage B: min/max(extend_len) for MTP-uniform detection === + // For min, treat threads outside `batch_size` as +inf so they don't pull the min down. + const uint32_t e_for_max = static_cast(extend_len); + const uint32_t e_for_min = (tx < params.batch_size) ? e_for_max : 0xFFFFFFFFu; + warp_max[warp_id] = warp_reduce_max_u32(e_for_max); + warp_min[warp_id] = warp_reduce_min_u32(e_for_min); + __syncthreads(); + if (warp_id == 0) { + s_max_extend = warp_reduce_max_u32(warp_max[lane_id]); + s_min_extend = warp_reduce_min_u32(warp_min[lane_id]); + } + __syncthreads(); + + const auto num_q = params.num_q_tokens; + // MTP-uniform: every batch shares the same small extend_len `E`, so we can decompose + // a global token id `k` into (batch_id, j) = (k / E, k % E) and skip the per-batch loop. + const bool is_mtp_extend = (s_min_extend == s_max_extend) && (s_max_extend > 0) && (s_max_extend <= 32); + + // === Stage C: emit valid plans, slot allocation via shared-mem atomicAdd === + if (is_mtp_extend) { + // Path 1: token-driven. 
Each global token id maps to exactly one (batch_id, j). + const uint32_t E = s_max_extend; + for (uint32_t k = tx; k < num_q; k += block_size) { + const uint32_t batch_id = k / E; + const uint32_t j = k % E; + const int32_t pl = s_prefix_len[batch_id]; + const int32_t sl = s_seq_len[batch_id]; + const int32_t position = pl + static_cast(j); + const uint32_t ragged_id = k; + + if ((position + 1) % cr == 0) { + const int32_t buffer_len = window_size - min(static_cast(j) + 1, window_size); + const uint32_t out_idx = atomicAdd(&counter_c, 1u); + params.plan_c[out_idx] = { + .seq_len = static_cast(position + 1), + .ragged_id = static_cast(ragged_id), + .buffer_len = static_cast(buffer_len), + .read_page_0 = -1, + .read_page_1 = static_cast(batch_id), + }; + } + + const int32_t last_c_pos = (sl / cr) * cr; + const int32_t first_w_pos = min(last_c_pos - (is_overlap ? cr : 0), sl - params.mtp_pad); + bool do_write = position >= first_w_pos; + if (!do_write && is_overlap) do_write = (position % sps) >= (sps - cr); + if (do_write) { + const uint32_t out_idx = atomicAdd(&counter_w, 1u); + params.plan_w[out_idx] = pack_w(ragged_id, batch_id, position + 1); + } + } + } else { + // Path 2: general prefill (long extend_len). Iterate batches in an outer loop; + // the whole block sweeps each batch's tokens in parallel. + uint32_t base_e = 0; + for (uint32_t batch_id = 0; batch_id < params.batch_size; ++batch_id) { + const int32_t pl = s_prefix_len[batch_id]; + const int32_t sl = s_seq_len[batch_id]; + const int32_t el = sl - pl; + const int32_t last_c_pos = (sl / cr) * cr; + const int32_t first_w_pos = min(last_c_pos - (is_overlap ? 
cr : 0), sl - params.mtp_pad); + for (int32_t j = static_cast(tx); j < el; j += static_cast(block_size)) { + const int32_t position = pl + j; + const uint32_t ragged_id = base_e + static_cast(j); + + if ((position + 1) % cr == 0) { + const int32_t buffer_len = window_size - min(j + 1, window_size); + const uint32_t out_idx = atomicAdd(&counter_c, 1u); + params.plan_c[out_idx] = { + .seq_len = static_cast(position + 1), + .ragged_id = static_cast(ragged_id), + .buffer_len = static_cast(buffer_len), + .read_page_0 = -1, + .read_page_1 = static_cast(batch_id), + }; + } + + bool do_write = position >= first_w_pos; + if (!do_write && is_overlap) do_write = (position % sps) >= (sps - cr); + if (do_write) { + const uint32_t out_idx = atomicAdd(&counter_w, 1u); + params.plan_w[out_idx] = pack_w(ragged_id, static_cast(batch_id), position + 1); + } + } + base_e += static_cast(el); + } + } + __syncthreads(); + + // === Stage D: pad [counter_c, num_q) / [counter_w, num_q) with invalid === + const auto total_c = counter_c; + const auto total_w = counter_w; + for (uint32_t k = total_c + tx; k < num_q; k += block_size) { + params.plan_c[k] = PlanC::invalid(); + } + for (uint32_t k = total_w + tx; k < num_q; k += block_size) { + params.plan_w[k] = PlanW::invalid(); + } +} + +/// NOTE: stage 1 +__global__ void plan_compress_prefill_kernel_1(const Prefill1Params params) { + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.num_work) return; + auto plan_c = idx < params.num_c ? params.plan_c[idx] : PlanC::invalid(); + auto plan_w = idx < params.num_w ? params.plan_w[idx] : PlanW::invalid(); + + const auto compute_loc = [&](int32_t swa_loc) { + const auto swa_page = swa_loc / params.swa_page_size; + const auto ring_offset = swa_loc % params.ring_size; + return swa_page * params.ring_size + ring_offset; + }; + + if (!plan_c.is_invalid()) { // 1. in bound. 2. 
not masked + if (plan_c.buffer_len > 0) { + const auto batch_id = plan_c.read_page_1; + const auto rid = params.rid_ptr[batch_id]; + const auto mapping = params.r2t_ptr + rid * params.stride_r2t; + // `seq_len` should be ratio-aligned here + const auto position_1 = static_cast(plan_c.seq_len - 1); + // only used for c4, harmless for c128 + const auto position_0 = max(position_1 - params.compress_ratio, 0); + const auto raw_loc_0 = mapping[position_0]; + const auto raw_loc_1 = mapping[position_1]; + const auto swa_loc_0 = params.f2s_ptr[raw_loc_0]; + const auto swa_loc_1 = params.f2s_ptr[raw_loc_1]; + plan_c.read_page_0 = compute_loc(swa_loc_0) / params.compress_ratio; + plan_c.read_page_1 = compute_loc(swa_loc_1) / params.compress_ratio; + params.plan_c[idx] = plan_c; + } + } else if (idx < params.num_c_padded) { + params.plan_c[idx] = PlanC::invalid(); + } + + if (!plan_w.is_invalid()) { // 1. in bound. 2. not masked + const auto [ragged_id, batch_id] = unpack_w(plan_w); + const auto rid = params.rid_ptr[batch_id]; + const auto mapping = params.r2t_ptr + rid * params.stride_r2t; + // `seq_len` (`write_loc`) may not be aligned here + const auto position = static_cast(plan_w.write_loc - 1); + const auto raw_loc = mapping[position]; + const auto swa_loc = params.f2s_ptr[raw_loc]; + plan_w.ragged_id = ragged_id; + plan_w.write_loc = compute_loc(swa_loc); + params.plan_w[idx] = plan_w; + } else if (idx < params.num_w_padded) { + params.plan_w[idx] = PlanW::invalid(); + } +} + +__global__ void plan_compress_decode_kernel(const DecodeParams params) { + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.batch_size) return; + const auto rid = params.rid_ptr[idx]; + const auto mapping = params.r2t_ptr + rid * params.stride_r2t; + const auto compute_loc = [&](int32_t swa_loc) { + const auto swa_page = swa_loc / params.swa_page_size; + const auto ring_offset = swa_loc % params.ring_size; + return swa_page * params.ring_size + ring_offset; + }; + const 
auto seq_len = static_cast(params.seq_ptr[idx]); + const auto position_1 = static_cast(seq_len - 1); + const auto position_0 = max(position_1 - params.compress_ratio, 0); + const auto raw_loc_0 = mapping[position_0]; + const auto raw_loc_1 = mapping[position_1]; + const auto swa_loc_0 = params.f2s_ptr[raw_loc_0]; + const auto swa_loc_1 = params.f2s_ptr[raw_loc_1]; + const auto write_loc = compute_loc(swa_loc_1); + const auto read_page_0 = compute_loc(swa_loc_0) / params.compress_ratio; + const auto read_page_1 = write_loc / params.compress_ratio; + params.plan_d[idx] = { + .seq_len = static_cast(seq_len), + .write_loc = write_loc, + .read_page_0 = read_page_0, + .read_page_1 = read_page_1, + }; +} + +__global__ void plan_compress_prefill_legacy_kernel(const Prefill1ParamsLegacy params) { + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.num_work) return; + auto plan_c = idx < params.num_c ? params.plan_c[idx] : PlanC::invalid(); + auto plan_w = idx < params.num_w ? 
params.plan_w[idx] : PlanW::invalid(); + + /// Per-request ring buffer slot translation: + /// - c4: page = rid * 2 + (position / 4) % 2; slot = page * 4 + position % 4 + /// - c128: page = rid; slot = rid * 128 + position % 128 + const auto legacy_compute_page = [&](int32_t rid, int32_t position) { + if (params.compress_ratio == 4) return rid * 2 + ((position / 4) & 1); + return rid; // c128 + }; + const auto legacy_compute_loc = [&](int32_t rid, int32_t position) { + const auto remainder = position % params.compress_ratio; + return legacy_compute_page(rid, position) * params.compress_ratio + remainder; + }; + + if (!plan_c.is_invalid()) { + const auto batch_id = plan_c.read_page_1; + const auto rid = static_cast(params.rid_ptr[batch_id]); + // `seq_len` is ratio-aligned for compress events + const auto position_1 = static_cast(plan_c.seq_len) - 1; + const auto position_0 = max(position_1 - params.compress_ratio, 0); + plan_c.read_page_0 = legacy_compute_page(rid, position_0); + plan_c.read_page_1 = legacy_compute_page(rid, position_1); + params.plan_c[idx] = plan_c; + } else if (idx < params.num_c_padded) { + params.plan_c[idx] = PlanC::invalid(); + } + + if (!plan_w.is_invalid()) { + const auto [ragged_id, batch_id] = unpack_w(plan_w); + const auto rid = static_cast(params.rid_ptr[batch_id]); + // `write_loc` carries (position + 1) at this stage; may not be ratio-aligned + const auto position = static_cast(plan_w.write_loc) - 1; + plan_w.ragged_id = ragged_id; + plan_w.write_loc = legacy_compute_loc(rid, position); + params.plan_w[idx] = plan_w; + } else if (idx < params.num_w_padded) { + params.plan_w[idx] = PlanW::invalid(); + } +} + +__global__ void plan_compress_decode_legacy_kernel(const DecodeParamsLegacy params) { + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.batch_size) return; + /// Per-request ring buffer slot translation: + /// - c4: page = rid * 2 + (position / 4) % 2; slot = page * 4 + position % 4 + /// - c128: page 
= rid; slot = rid * 128 + position % 128 + const auto legacy_compute_page = [&](int32_t rid, int32_t position) { + if (params.compress_ratio == 4) return rid * 2 + ((position / 4) & 1); + return rid; // c128 + }; + const auto legacy_compute_loc = [&](int32_t rid, int32_t position) { + const auto remainder = position % params.compress_ratio; + return legacy_compute_page(rid, position) * params.compress_ratio + remainder; + }; + const auto rid = static_cast(params.rid_ptr[idx]); + const auto seq_len = static_cast(params.seq_ptr[idx]); + const auto position_1 = seq_len - 1; + const auto position_0 = max(position_1 - params.compress_ratio, 0); + const auto write_loc = legacy_compute_loc(rid, position_1); + const auto read_page_0 = legacy_compute_page(rid, position_0); + const auto read_page_1 = legacy_compute_page(rid, position_1); + params.plan_d[idx] = { + .seq_len = static_cast(seq_len), + .write_loc = write_loc, + .read_page_0 = read_page_0, + .read_page_1 = read_page_1, + }; +} + +using PrefillPlan = tvm::ffi::Tuple; + +/** + * \brief Build c4/c128 prefill plan tensors. CPU-resident. 
+ * Inputs (all CPU-resident): + * @param req_pool_indices `[batch_size]` int64_t + * @param req_to_token `[num_reqs, max_tokens_per_req]` int64_t + * @param full_to_swa `[num_swa_slots]` int64_t + * @param seq_lens `[batch_size]` int64 + * @param extend_lens `[batch_size]` int64 + * @param compress_plan `[num_q_tokens, 16]` uint8 (output) + * @param write_plan `[num_q_tokens, 8]` uint8 (output) + * @param compress_ratio 4 for c4, 128 for c128 + * @param use_cuda_graph Whether the plans will be used with cuda graph (affects padding) + * @return (compress plan tensor, write plan tensor) + */ +inline PrefillPlan plan_compress_prefill( + const tvm::ffi::TensorView req_pool_indices, // GPU + const tvm::ffi::TensorView req_to_token, // GPU + const tvm::ffi::TensorView full_to_swa, // GPU + const tvm::ffi::TensorView seq_lens, // CPU/GPU + const tvm::ffi::TensorView extend_lens, // CPU/GPU + const tvm::ffi::TensorView pin_buffer, // CPU + const uint32_t num_q_tokens, + const int32_t compress_ratio, + const int32_t swa_page_size, + const int32_t ring_size, + const bool use_cuda_graph) { + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto cpu_or_gpu = SymbolicDevice{}; + auto device_ = SymbolicDevice{}; + cpu_or_gpu.set_options(); + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device_) + .verify(req_to_token); + TensorMatcher({-1}) // + .with_dtype() + .with_device(device_) + .verify(full_to_swa); + TensorMatcher({B}) // + .with_dtype() + .with_device(cpu_or_gpu) + .verify(seq_lens) + .verify(extend_lens); + TensorMatcher({-1}) // + .with_dtype() + .with_device() + .verify(pin_buffer); + + const bool is_overlap = (compress_ratio == 4); + const int32_t window_size = compress_ratio * (is_overlap ? 
2 : 1); + + const auto seq_ptr = static_cast(seq_lens.data_ptr()); + const auto ext_ptr = static_cast(extend_lens.data_ptr()); + const auto rid_ptr = static_cast(req_pool_indices.data_ptr()); + const auto r2t_ptr = static_cast(req_to_token.data_ptr()); + const auto f2s_ptr = static_cast(full_to_swa.data_ptr()); + + const auto batch_size = static_cast(B.unwrap()); + constexpr auto kMaxTokens = static_cast(std::numeric_limits::max()); + RuntimeCheck(compress_ratio == 4 || compress_ratio == 128); + RuntimeCheck(batch_size <= num_q_tokens && num_q_tokens <= kMaxTokens); + // `swa_page_size` >= `ring_size` >= `compress_ratio` + RuntimeCheck(swa_page_size % ring_size == 0 && ring_size % compress_ratio == 0); + + const auto device = device_.unwrap(); + const auto stream = LaunchKernel::resolve_device(device); + + constexpr int32_t kMaxMTPDraftTokens = 4; + const auto mtp_pad = std::min(ring_size - compress_ratio, kMaxMTPDraftTokens); + + if (cpu_or_gpu.unwrap().device_type == kDLGPU) { + // GPU input path: kernel0 builds the (CPU-loop-equivalent) plan metadata directly + // on device, padding to num_q_tokens with invalid; kernel_1 then finalizes the + // SWA-translated read/write locations. Used for MTP / cuda-graph capture where + // a host sync would be expensive. 
+ RuntimeCheck(batch_size <= kMaxPrefillBatchSize, "GPU plan only support batch size up to ", kMaxPrefillBatchSize); + auto C = ffi::empty({num_q_tokens, sizeof(PlanC)}, kDLUInt8, device); + auto W = ffi::empty({num_q_tokens, sizeof(PlanW)}, kDLUInt8, device); + const auto params0 = Prefill0Params{ + .plan_c = static_cast(C.data_ptr()), + .plan_w = static_cast(W.data_ptr()), + .seq_lens_ptr = seq_ptr, + .extend_lens_ptr = ext_ptr, + .batch_size = batch_size, + .num_q_tokens = num_q_tokens, + .compress_ratio = compress_ratio, + .swa_page_size = swa_page_size, + .mtp_pad = mtp_pad, + }; + LaunchKernel(1, kMaxPrefillBatchSize, device)(plan_compress_prefill_kernel0, params0); + // kernel_1 sees the already-padded buffers, so num_c == num_w == num_padded == num_q_tokens. + const auto params1 = Prefill1Params{ + .plan_c = static_cast(C.data_ptr()), + .plan_w = static_cast(W.data_ptr()), + .rid_ptr = rid_ptr, + .r2t_ptr = r2t_ptr, + .f2s_ptr = f2s_ptr, + .stride_r2t = req_to_token.stride(0), + .num_c = num_q_tokens, + .num_w = num_q_tokens, + .num_c_padded = num_q_tokens, + .num_w_padded = num_q_tokens, + .num_work = num_q_tokens, + .swa_page_size = swa_page_size, + .ring_size = ring_size, + .compress_ratio = compress_ratio, + }; + const auto block_size_1 = 256; + const auto num_blocks_1 = div_ceil(params1.num_work, block_size_1); + LaunchKernel(num_blocks_1, block_size_1, device)(plan_compress_prefill_kernel_1, params1); + return PrefillPlan{std::move(C), std::move(W)}; + } + + // CPU input path: only here do we need the pinned scratch buffer. 
+ const auto pin_buffer_bytes = static_cast(pin_buffer.numel()) * sizeof(uint8_t); + RuntimeCheck(pin_buffer_bytes >= num_q_tokens * (sizeof(PlanC) + sizeof(PlanW))); + const auto plan_c_ptr = reinterpret_cast(pin_buffer.data_ptr()); + const auto plan_w_ptr = reinterpret_cast(plan_c_ptr + num_q_tokens); + + uint32_t counter = 0; + uint32_t counter_c = 0; + uint32_t counter_w = 0; + + const auto should_compress = [=](int32_t position) { return (position + 1) % compress_ratio == 0; }; + for (const auto i : irange(batch_size)) { + const int32_t seq_len = seq_ptr[i]; + const int32_t extend_len = ext_ptr[i]; + const int32_t prefix_len = seq_len - extend_len; + const int32_t last_c_pos = seq_len / compress_ratio * compress_ratio; + const int32_t first_w_pos = last_c_pos - (is_overlap ? compress_ratio : 0); + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + const auto should_write = [=](int32_t position) { + if (position >= first_w_pos) return true; + return is_overlap && position % swa_page_size >= (swa_page_size - compress_ratio); + }; + for (const auto j : irange(extend_len)) { + const int32_t position = prefix_len + j; + const int32_t ragged_id = counter + j; + if (should_compress(position)) { + const auto buffer_len = window_size - std::min(j + 1, window_size); + plan_c_ptr[counter_c++] = { + .seq_len = static_cast(position + 1), + .ragged_id = static_cast(ragged_id), + .buffer_len = static_cast(buffer_len), + // to be filled by kernel + .read_page_0 = -1, + .read_page_1 = static_cast(i), + }; + } + if (should_write(position)) { + plan_w_ptr[counter_w++] = pack_w(ragged_id, i, position + 1); + } + } + counter += extend_len; + } + RuntimeCheck(counter == num_q_tokens); + + const auto copy_to_device = [stream](void* cuda_ptr, auto* host_ptr, size_t count) { + const auto size_bytes = count * sizeof(*host_ptr); + RuntimeDeviceCheck(cudaMemcpyAsync(cuda_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice, stream)); + }; + const auto num_c_padded = use_cuda_graph ? 
num_q_tokens : counter_c; + const auto num_w_padded = use_cuda_graph ? num_q_tokens : counter_w; + auto C = ffi::empty({num_c_padded, sizeof(PlanC)}, kDLUInt8, device); + auto W = ffi::empty({num_w_padded, sizeof(PlanW)}, kDLUInt8, device); + copy_to_device(C.data_ptr(), plan_c_ptr, counter_c); + copy_to_device(W.data_ptr(), plan_w_ptr, counter_w); + const auto params = Prefill1Params{ + .plan_c = static_cast(C.data_ptr()), + .plan_w = static_cast(W.data_ptr()), + .rid_ptr = rid_ptr, + .r2t_ptr = r2t_ptr, + .f2s_ptr = f2s_ptr, + .stride_r2t = req_to_token.size(1), + .num_c = counter_c, + .num_w = counter_w, + .num_c_padded = num_c_padded, + .num_w_padded = num_w_padded, + .num_work = std::max(num_c_padded, num_w_padded), + .swa_page_size = swa_page_size, + .ring_size = ring_size, + .compress_ratio = compress_ratio, + }; + const auto block_size = 256; + const auto num_blocks = div_ceil(params.num_work, block_size); + LaunchKernel(num_blocks, block_size, device)(plan_compress_prefill_kernel_1, params); + return PrefillPlan{std::move(C), std::move(W)}; +} + +inline tvm::ffi::Tensor plan_compress_decode( + const tvm::ffi::TensorView req_pool_indices, // GPU + const tvm::ffi::TensorView req_to_token, // GPU + const tvm::ffi::TensorView full_to_swa, // GPU + const tvm::ffi::TensorView seq_lens, // CPU/GPU + const int32_t compress_ratio, + const int32_t swa_page_size, + const int32_t ring_size) { + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device_) + .verify(req_to_token); + TensorMatcher({-1}) // + .with_dtype() + .with_device(device_) + .verify(full_to_swa); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto device = device_.unwrap(); + auto D = 
ffi::empty({batch_size, sizeof(PlanD)}, kDLUInt8, device); + const auto params = DecodeParams{ + .plan_d = static_cast(D.data_ptr()), + .rid_ptr = static_cast(req_pool_indices.data_ptr()), + .r2t_ptr = static_cast(req_to_token.data_ptr()), + .f2s_ptr = static_cast(full_to_swa.data_ptr()), + .seq_ptr = static_cast(seq_lens.data_ptr()), + .stride_r2t = req_to_token.size(1), + .batch_size = batch_size, + .swa_page_size = swa_page_size, + .ring_size = ring_size, + .compress_ratio = compress_ratio, + }; + const auto block_size = 256; + const auto num_blocks = div_ceil(batch_size, block_size); + LaunchKernel(num_blocks, block_size, device)(plan_compress_decode_kernel, params); + return D; +} + +/** + * \brief Build c4/c128 prefill plan tensors for the legacy non-paged ring + * buffer. Uses only `req_pool_indices` to derive ring slots: + * - c4 (overlap): each request occupies 2 contiguous pages (8 token slots) + * - c128: each request occupies 1 page (128 token slots) + * + * Inputs: + * @param req_pool_indices `[batch_size]` int64 (GPU) + * @param seq_lens `[batch_size]` int64 (CPU) + * @param extend_lens `[batch_size]` int64 (CPU) + * @param pin_buffer pinned scratch (CPU uint8) + * @return (compress plan tensor, write plan tensor) + */ +inline PrefillPlan plan_compress_prefill_legacy( + const tvm::ffi::TensorView req_pool_indices, // GPU + const tvm::ffi::TensorView seq_lens, // CPU + const tvm::ffi::TensorView extend_lens, // CPU + const tvm::ffi::TensorView pin_buffer, // CPU + const uint32_t num_q_tokens, + const int32_t compress_ratio, + const bool use_cuda_graph) { + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({B}) // + .with_dtype() + .with_device() + .verify(seq_lens) + .verify(extend_lens); + TensorMatcher({-1}) // + .with_dtype() + .with_device() + .verify(pin_buffer); + + const auto 
pin_buffer_bytes = static_cast(pin_buffer.numel()) * sizeof(uint8_t); + RuntimeCheck(pin_buffer_bytes >= num_q_tokens * (sizeof(PlanC) + sizeof(PlanW))); + const auto plan_c_ptr = reinterpret_cast(pin_buffer.data_ptr()); + const auto plan_w_ptr = reinterpret_cast(plan_c_ptr + num_q_tokens); + + const bool is_overlap = (compress_ratio == 4); + const auto seq_ptr = static_cast(seq_lens.data_ptr()); + const auto ext_ptr = static_cast(extend_lens.data_ptr()); + const auto rid_ptr = static_cast(req_pool_indices.data_ptr()); + + const auto window_size = compress_ratio * (is_overlap ? 2 : 1); + const auto batch_size = static_cast(B.unwrap()); + constexpr auto kMaxTokens = static_cast(std::numeric_limits::max()); + RuntimeCheck(compress_ratio == 4 || compress_ratio == 128); + RuntimeCheck(batch_size <= num_q_tokens && num_q_tokens <= kMaxTokens); + + uint32_t counter = 0; + uint32_t counter_c = 0; + uint32_t counter_w = 0; + const auto should_compress = [=](int32_t position) { return (position + 1) % compress_ratio == 0; }; + for (const auto i : irange(batch_size)) { + const int32_t seq_len = seq_ptr[i]; + const int32_t extend_len = ext_ptr[i]; + const int32_t prefix_len = seq_len - extend_len; + const int32_t last_c_pos = seq_len / compress_ratio * compress_ratio; + const int32_t first_w_pos = last_c_pos - (is_overlap ? 
compress_ratio : 0); + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + const auto should_write = [=](int32_t position) { return position >= first_w_pos; }; + for (const auto j : irange(extend_len)) { + const int32_t position = prefix_len + j; + const int32_t ragged_id = counter + j; + if (should_compress(position)) { + const auto buffer_len = window_size - std::min(j + 1, window_size); + plan_c_ptr[counter_c++] = { + .seq_len = static_cast(position + 1), + .ragged_id = static_cast(ragged_id), + .buffer_len = static_cast(buffer_len), + // to be filled by kernel + .read_page_0 = -1, + .read_page_1 = static_cast(i), + }; + } + if (should_write(position)) { + plan_w_ptr[counter_w++] = pack_w(ragged_id, i, position + 1); + } + } + counter += extend_len; + } + RuntimeCheck(counter == num_q_tokens); + + const auto device = device_.unwrap(); + const auto stream = LaunchKernel::resolve_device(device); + const auto copy_to_device = [stream](void* cuda_ptr, auto* host_ptr, size_t count) { + const auto size_bytes = count * sizeof(*host_ptr); + RuntimeDeviceCheck(cudaMemcpyAsync(cuda_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice, stream)); + }; + const auto num_c_padded = use_cuda_graph ? num_q_tokens : counter_c; + const auto num_w_padded = use_cuda_graph ? 
num_q_tokens : counter_w; + auto C = ffi::empty({num_c_padded, sizeof(PlanC)}, kDLUInt8, device); + auto W = ffi::empty({num_w_padded, sizeof(PlanW)}, kDLUInt8, device); + copy_to_device(C.data_ptr(), plan_c_ptr, counter_c); + copy_to_device(W.data_ptr(), plan_w_ptr, counter_w); + const auto params = Prefill1ParamsLegacy{ + .plan_c = static_cast(C.data_ptr()), + .plan_w = static_cast(W.data_ptr()), + .rid_ptr = rid_ptr, + .num_c = counter_c, + .num_w = counter_w, + .num_c_padded = num_c_padded, + .num_w_padded = num_w_padded, + .num_work = std::max(num_c_padded, num_w_padded), + .compress_ratio = compress_ratio, + }; + const auto block_size = 256; + const auto num_blocks = div_ceil(params.num_work, block_size); + if (num_blocks > 0) { + LaunchKernel(num_blocks, block_size, device)(plan_compress_prefill_legacy_kernel, params); + } + return PrefillPlan{std::move(C), std::move(W)}; +} + +inline tvm::ffi::Tensor plan_compress_decode_legacy( + const tvm::ffi::TensorView req_pool_indices, // GPU + const tvm::ffi::TensorView seq_lens, // GPU + const int32_t compress_ratio) { + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(seq_lens); + RuntimeCheck(compress_ratio == 4 || compress_ratio == 128); + + const auto batch_size = static_cast(B.unwrap()); + const auto device = device_.unwrap(); + auto D = ffi::empty({batch_size, sizeof(PlanD)}, kDLUInt8, device); + const auto params = DecodeParamsLegacy{ + .plan_d = static_cast(D.data_ptr()), + .rid_ptr = static_cast(req_pool_indices.data_ptr()), + .seq_ptr = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + .compress_ratio = compress_ratio, + }; + const auto block_size = 256; + const auto num_blocks = div_ceil(batch_size, block_size); + LaunchKernel(num_blocks, block_size, 
device)(plan_compress_decode_legacy_kernel, params); + return D; +} + +} // namespace host::compress + +using namespace host::compress; // expose binding diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/common.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/common.cuh new file mode 100644 index 0000000000..46acaa9c46 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/common.cuh @@ -0,0 +1,208 @@ +#include +#include + +#include + +#include + +namespace host::compress { + +using PlanResult = tvm::ffi::Tuple; + +struct CompressParams { + PrefillPlan* __restrict__ compress_plan; + PrefillPlan* __restrict__ write_plan; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ extend_lens; + uint32_t batch_size; + uint32_t num_tokens; + uint32_t compress_ratio; + bool is_overlap; +}; + +inline constexpr uint32_t kBlockSize = 1024; + +#define PLAN_KERNEL __global__ __launch_bounds__(kBlockSize, 1) inline + +PLAN_KERNEL void plan_prefill_cuda(const __grid_constant__ CompressParams params) { + const auto &[ + compress_plan, write_plan, seq_lens, extend_lens, // pointers + batch_size, num_tokens, compress_ratio, is_overlap // values + ] = params; + + __shared__ uint32_t compress_counter; + __shared__ uint32_t write_counter; + + uint32_t batch_id = 0; + uint32_t counter = 0; + uint32_t extend_len = extend_lens[0]; + + const auto tid = threadIdx.x; + if (tid == 0) { + compress_counter = 0; + write_counter = 0; + } + __syncthreads(); + + for (uint32_t i = tid; i < num_tokens; i += blockDim.x) { + const uint32_t ragged_id = i; + uint32_t j = ragged_id - counter; + while (j >= extend_len) { + j -= extend_len; + batch_id += 1; + if (batch_id >= batch_size) [[unlikely]] + break; + counter += extend_len; + extend_len = extend_lens[batch_id]; + } + if (batch_id >= batch_size) [[unlikely]] + break; + const uint32_t seq_len = seq_lens[batch_id]; + const uint32_t extend_len = extend_lens[batch_id]; + const uint32_t prefix_len = seq_len - 
extend_len; + const uint32_t ratio = compress_ratio * (1 + is_overlap); + const uint32_t window_len = j + 1 < ratio ? ratio - (j + 1) : 0; + const uint32_t position = prefix_len + j; + const auto plan = PrefillPlan{ + .ragged_id = ragged_id, + .batch_id = batch_id, + .position = position, + .window_len = window_len, + }; + const uint32_t start_write_pos = [seq_len, compress_ratio, is_overlap] { + const uint32_t pos = seq_len / compress_ratio * compress_ratio; + if (!is_overlap) return pos; + return pos >= compress_ratio ? pos - compress_ratio : 0; + }(); + if ((position + 1) % compress_ratio == 0) { + const auto write_pos = atomicAdd(&compress_counter, 1); + compress_plan[write_pos] = plan; + } + if (position >= start_write_pos) { + const auto write_pos = atomicAdd(&write_counter, 1); + write_plan[write_pos] = plan; + } + } + __syncthreads(); + constexpr auto kInvalid = static_cast(-1); + const auto kInvalidPlan = PrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + const auto compress_count = compress_counter; + const auto write_count = write_counter; + for (uint32_t i = compress_count + tid; i < num_tokens; i += blockDim.x) { + compress_plan[i] = kInvalidPlan; + } + for (uint32_t i = write_count + tid; i < num_tokens; i += blockDim.x) { + write_plan[i] = kInvalidPlan; + } +} + +inline PlanResult plan_prefill_host(const CompressParams& params, const bool use_cuda_graph) { + const auto &[ + compress_ptr, write_ptr, seq_lens_ptr, extend_lens_ptr, // pointers + batch_size, num_tokens, compress_ratio, is_overlap // values + ] = params; + + uint32_t counter = 0; + uint32_t compress_counter = 0; + uint32_t write_counter = 0; + const auto ratio = compress_ratio * (1 + is_overlap); + for (const auto i : irange(batch_size)) { + const uint32_t seq_len = seq_lens_ptr[i]; + const uint32_t extend_len = extend_lens_ptr[i]; + const uint32_t prefix_len = seq_len - extend_len; + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + /// NOTE: `start_write_pos` must be a 
multiple of `compress_ratio` + const uint32_t start_write_pos = [seq_len, compress_ratio, is_overlap] { + const uint32_t pos = seq_len / compress_ratio * compress_ratio; + if (!is_overlap) return pos; + /// NOTE: to avoid unsigned integer underflow, don't use `pos - compress_ratio` + return pos >= compress_ratio ? pos - compress_ratio : 0; + }(); + /// NOTE: `position` is within [prefix_len, seq_len) + for (const auto j : irange(extend_len)) { + const uint32_t position = prefix_len + j; + const auto plan = PrefillPlan{ + .ragged_id = counter + j, + .batch_id = i, + .position = position, + .window_len = ratio - std::min(j + 1, ratio), + }; + RuntimeCheck(plan.is_valid(compress_ratio, is_overlap), "Internal error!"); + if ((position + 1) % compress_ratio == 0) { + compress_ptr[compress_counter++] = plan; + } + if (position >= start_write_pos) { + write_ptr[write_counter++] = plan; + } + } + counter += extend_len; + } + RuntimeCheck(counter == num_tokens, "input size ", counter, " != num_q_tokens ", num_tokens); + if (!use_cuda_graph) return PlanResult{compress_counter, write_counter}; + constexpr auto kInvalid = static_cast(-1); + constexpr auto kInvalidPlan = PrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + for (const auto i : irange(compress_counter, num_tokens)) { + compress_ptr[i] = kInvalidPlan; + } + for (const auto i : irange(write_counter, num_tokens)) { + write_ptr[i] = kInvalidPlan; + } + return PlanResult{num_tokens, num_tokens}; +} + +inline PlanResult plan_prefill( + const tvm::ffi::TensorView extend_lens, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const uint32_t compress_ratio, + const bool is_overlap, // for overlap transform, we have to keep 1 more extra window + const bool use_cuda_graph) { + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_tokens"}; + auto device = SymbolicDevice{}; + const bool is_cuda = [&] { + if 
(extend_lens.device().device_type == kDLCUDA) { + device.set_options(); + return true; + } else { + device.set_options(); + return false; + } + }(); + TensorMatcher({N}) // extend_lens and seq_lens + .with_dtype() + .with_device(device) + .verify(extend_lens) + .verify(seq_lens); + TensorMatcher({M, kPrefillPlanDim}) // compress_plan and write_plan + .with_dtype() + .with_device(device) + .verify(compress_plan) + .verify(write_plan); + + const auto params = CompressParams{ + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extend_lens = static_cast(extend_lens.data_ptr()), + .batch_size = static_cast(N.unwrap()), + .num_tokens = static_cast(M.unwrap()), + .compress_ratio = compress_ratio, + .is_overlap = is_overlap, + }; + + if (!is_cuda) return plan_prefill_host(params, use_cuda_graph); + /// NOTE: cuda kernel plan is naturally compatible with cuda graph + LaunchKernel(1, kBlockSize, device.unwrap())(plan_prefill_cuda, params); + return PlanResult{params.num_tokens, params.num_tokens}; +} + +} // namespace host::compress + +namespace { + +[[maybe_unused]] +constexpr auto& plan_compress_prefill = host::compress::plan_prefill; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope.cuh new file mode 100644 index 0000000000..d3953578b9 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope.cuh @@ -0,0 +1,254 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace { + +using Plan = device::compress::PrefillPlan; + +/// \brief common block size for memory-bound kernel +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct FusedNormRopeParams { + void* __restrict__ input; + const void* 
__restrict__ weight; + float eps; + uint32_t num_works; + const void* __restrict__ handle; + const float* __restrict__ freqs_cis; + uint32_t compress_ratio; +}; + +enum class ForwardMode { + CompressExtend = 0, + CompressDecode = 1, + DefaultForward = 2, +}; + +template +__global__ void fused_norm_rope(const __grid_constant__ FusedNormRopeParams params) { + using namespace device; + using enum ForwardMode; + + constexpr int64_t kMaxVecSize = 16 / sizeof(DType); + constexpr int64_t kVecSize = std::min(kMaxVecSize, kHeadDim / kWarpThreads); + constexpr int64_t kLocalSize = kHeadDim / (kWarpThreads * kVecSize); + constexpr int64_t kRopeVecSize = kRopeDim / (kWarpThreads * 2); + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + static_assert(kHeadDim % (kWarpThreads * kVecSize) == 0); + static_assert(kLocalSize * kVecSize * kWarpThreads == kHeadDim); + static_assert(kRopeDim % (kWarpThreads * 2) == 0); + static_assert(kRopeDim % (kVecSize * kLocalSize) == 0); + static_assert(kRopeSize <= kWarpThreads); + static_assert(kRopeVecSize == 1, "only support rope dim = 64"); + + const auto& [ + _input, _weight, eps, num_works, // norm + handle, freqs_cis, compress_ratio // rope + ] = params; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kNumWarps + warp_id; + + if (work_id >= num_works) return; + + DType* input; + int32_t position; + if constexpr (kMode == CompressExtend) { + const auto plan = static_cast(handle)[work_id]; + input = static_cast(_input) + plan.ragged_id * kHeadDim; + position = plan.position + 1 - compress_ratio; + if (plan.ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + } else if constexpr (kMode == CompressDecode) { + input = static_cast(_input) + work_id * kHeadDim; + const auto seq_len = static_cast(handle)[work_id]; + if (seq_len % compress_ratio != 0) return; + position = seq_len - compress_ratio; + } else if constexpr (kMode == DefaultForward) { + input = 
static_cast(_input) + work_id * kHeadDim; + position = static_cast(handle)[work_id]; + } else { + static_assert(host::dependent_false_v, "Unsupported Mode"); + } + + using Storage = AlignedVector; + __shared__ Storage s_rope_input[kNumWarps][kRopeSize]; + + // prefetch freq + const auto mem_freq = tile::Memory::warp(); + const auto freq = mem_freq.load(freqs_cis + position * kRopeDim); + + PDLWaitPrimary(); + + // part 1: norm + { + const auto gmem = tile::Memory::warp(); + Storage input_vec[kLocalSize]; + Storage weight_vec[kLocalSize]; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + input_vec[i] = gmem.load(input, i); + } + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + weight_vec[i] = gmem.load(_weight, i); + } + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto fp32_input = cast(input_vec[i][j]); + sum_of_squares += fp32_input * fp32_input; + } + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + eps); + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto fp32_input = cast(input_vec[i][j]); + const auto fp32_weight = cast(weight_vec[i][j]); + input_vec[i][j] = cast(fp32_input * norm_factor * fp32_weight); + } + } + + const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + if (i == kLocalSize - 1 && is_rope_lane) { + const auto rope_id = lane_id - (kWarpThreads - kRopeSize); + s_rope_input[warp_id][rope_id] = input_vec[i]; + } else { + gmem.store(input, input_vec[i], i); + } + } + + __syncwarp(); + } + + // part 2: rope + { + // mem elem = DType x 2 + using DTypex2_t = packed_t; + const auto mem_elem = tile::Memory::warp(); + const auto elem = mem_elem.load(s_rope_input[warp_id]); + const auto [x_real, x_imag] = cast(elem); + 
const auto [freq_real, freq_imag] = freq; + const fp32x2_t output = { + x_real * freq_real - x_imag * freq_imag, + x_real * freq_imag + x_imag * freq_real, + }; + mem_elem.store(input + (kHeadDim - kRopeDim), cast(output)); + } + + PDLTriggerSecondary(); +} + +template +struct FusedNormRopeKernel { + template + static constexpr auto fused_kernel = fused_norm_rope; + + static void forward( + const tvm::ffi::TensorView input, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView handle, + const tvm::ffi::TensorView freqs_cis, + int32_t _mode, + float eps, + uint32_t compress_ratio) { + using namespace host; + using enum ForwardMode; + + const auto mode = static_cast(_mode); + + auto B = SymbolicSize{"num_q_tokens"}; + auto N = SymbolicSize{"num_compress_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, kHeadDim}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({kHeadDim}) // weight + .with_dtype() + .with_device(device_) + .verify(weight); + TensorMatcher({-1, kRopeDim}) // freqs_cis + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + switch (mode) { + case CompressExtend: + TensorMatcher({N, compress::kPrefillPlanDim}) // plan + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio > 0); + break; + case CompressDecode: + TensorMatcher({N}) // seq_len + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio > 0); + break; + case DefaultForward: + TensorMatcher({N}) // position + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio == 0); + break; + default: + Panic("unsupported forward mode: ", static_cast(mode)); + } + + // launch kernel + const auto num_compress_tokens = static_cast(N.unwrap()); + if (num_compress_tokens == 0) return; + const auto params = FusedNormRopeParams{ + .input = input.data_ptr(), + .weight = weight.data_ptr(), + .eps = eps, + .num_works = 
num_compress_tokens, + .handle = handle.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .compress_ratio = compress_ratio, + }; + const auto num_blocks = div_ceil(num_compress_tokens, kNumWarps); + using KernelType = std::decay_t)>; + static constexpr KernelType kernel_table[3] = { + [static_cast(CompressExtend)] = fused_kernel, + [static_cast(CompressDecode)] = fused_kernel, + [static_cast(DefaultForward)] = fused_kernel, + }; + const auto kernel = kernel_table[static_cast(mode)]; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope_v2.cuh new file mode 100644 index 0000000000..a9cac17544 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope_v2.cuh @@ -0,0 +1,643 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +namespace { + +using PlanC = device::compress::CompressPlan; +using PlanD = device::compress::DecodePlan; +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::inv_scale_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +SGL_DEVICE uint8_t quant_fp4_e2m1(float x) { + const float ax = fminf(fabsf(x), 6.0f); + uint8_t idx = 0; + idx += ax > 0.25f; + idx += ax > 0.75f; + idx += ax > 1.25f; + idx += ax > 1.75f; + idx += ax > 2.5f; + idx += ax > 3.5f; + idx += ax > 5.0f; + if (x < 0.0f && idx != 0) idx |= 0x8; + return idx; +} + +constexpr uint32_t kBlockSize = 256; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct FusedNormRopeStoreParams { + void* __restrict__ input; + const void* __restrict__ handle; // plan decode / compress + const void* __restrict__ weight; + const float* __restrict__ freqs_cis; + const int32_t* __restrict__ out_loc; + uint8_t* __restrict__ kvcache; + float eps; + uint32_t 
compress_ratio; + uint32_t num_tokens; +}; + +enum class ForwardMode : bool { + CompressExtend = 0, + CompressDecode = 1, +}; + +#define INDEXER_KERNEL __global__ __launch_bounds__(kBlockSize, 8) +#define FLASHMLA_KERNEL __global__ __launch_bounds__(kBlockSize, 8) + +// ---------------------------------------------------------------------------- +// Indexer variant: kHeadDim = 128, 1 token per *warp* (8 tokens per block). +// Each warp's 32 lanes cover the full 128-elem head_dim (kVecSize = 4 each). +// Cache layout: 132 bytes/token (128 fp8 nope + 4 fp32 scale). +// ---------------------------------------------------------------------------- +template +INDEXER_KERNEL void fused_norm_rope_indexer(const __grid_constant__ FusedNormRopeStoreParams params) { + using namespace device; + using enum ForwardMode; + + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kVecSize = 4; + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + constexpr int64_t kPageBytes = 132ll << kPageBits; + static_assert(kHeadDim == kWarpThreads * kVecSize); + static_assert(kRopeDim == kWarpThreads * 2); + static_assert(kRopeSize <= kWarpThreads); + using Storage = AlignedVector; + using Float4 = AlignedVector; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kNumWarps + warp_id; + // Lanes whose 4-elem pack lies in the rope tail (= last `kRopeSize` packs). 
+ const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; + + if (work_id >= params.num_tokens) return; + + const auto input = static_cast(params.input) + work_id * kHeadDim; + int32_t position; + int32_t out_loc; + if constexpr (kMode == CompressExtend) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.is_invalid()) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[plan.ragged_id]; + } else if constexpr (kMode == CompressDecode) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.seq_len % params.compress_ratio != 0) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[work_id]; + } else { + static_assert(host::dependent_false_v, "Unsupported Mode"); + } + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + PDLWaitPrimary(); + Float4 data, freq; + + // part 1: norm + { + Storage input_vec, weight_vec; + input_vec.load(input, lane_id); + weight_vec.load(params.weight, lane_id); + if (is_rope_lane) freq.load(freqs_cis, lane_id - (kWarpThreads - kRopeSize)); + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + sum_of_squares += fp32_input * fp32_input; + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + const auto fp32_weight = cast(weight_vec[i]); + data[i] = fp32_input * norm_factor * fp32_weight; + } + } + + // part 2: rope (rope-lane only, 4 elems per lane = 2 (real, imag) pairs) + if (is_rope_lane) { + const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto y_real = data[2]; + const auto y_imag = data[3]; + const auto freq_x_real = freq[0]; + const auto freq_x_imag = freq[1]; + const auto freq_y_real = freq[2]; + const auto freq_y_imag = freq[3]; + 
data[0] = x_real * freq_x_real - x_imag * freq_x_imag; + data[1] = x_real * freq_x_imag + x_imag * freq_x_real; + data[2] = y_real * freq_y_real - y_imag * freq_y_imag; + data[3] = y_real * freq_y_imag + y_imag * freq_y_real; + } + + // part 3: hadamard transform + { + // Stage 1: butterfly (data[0], data[1]) and (data[2], data[3]). + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a1; + data[1] = a0 - a1; + data[2] = a2 + a3; + data[3] = a2 - a3; + } + // Stage 2: butterfly (data[0], data[2]) and (data[1], data[3]). + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a2; + data[1] = a1 + a3; + data[2] = a0 - a2; + data[3] = a1 - a3; + } + // Stages 3..7: cross-lane butterflies. Lower-lane (mask bit clear) keeps + // the sum, upper-lane (mask bit set) keeps the difference. shfl_xor is + // unsynchronized across early-returned lanes, but invalid-plan returns + // happen above for *all* lanes of a warp (work_id is warp-uniform), so + // the warp is intact here. +#pragma unroll + for (uint32_t mask = 1; mask < kWarpThreads; mask <<= 1) { +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { +#ifndef USE_ROCM + const float other = __shfl_xor_sync(kFullMask, data[i], mask, kWarpThreads); +#else + const float other = __shfl_xor(data[i], mask, kWarpThreads); +#endif + data[i] = (lane_id & mask) ? (other - data[i]) : (data[i] + other); + } + } + const float kHadamardScale = math::rsqrt(static_cast(kHeadDim)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] *= kHadamardScale; + } + + // part 4: per-warp UE8M0 quant + store. The whole warp emits one fp8 group + // (= 128 elements) plus a single fp32 scale, matching the indexer cache + // layout (`fused_store_indexer_cache`). 
+ { + using OutStorage = AlignedVector; + float local_max = math::abs(data[0]); +#pragma unroll + for (int i = 1; i < kVecSize; ++i) { + local_max = math::max(local_max, math::abs(data[i])); + } + const auto abs_max = warp::reduce_max(local_max); + const auto scale = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto inv_scale = 1.0f / scale; + const int32_t page = out_loc >> kPageBits; + const int32_t offset = out_loc & ((1 << kPageBits) - 1); + const auto page_ptr = params.kvcache + page * kPageBytes; + const auto value_ptr = page_ptr + offset * 128; + const auto scale_ptr = page_ptr + (128 << kPageBits) + offset * 4; + OutStorage result; + result[0] = pack_fp8(data[0] * inv_scale, data[1] * inv_scale); + result[1] = pack_fp8(data[2] * inv_scale, data[3] * inv_scale); + PDLTriggerSecondary(); + result.store(value_ptr, lane_id); + // The single fp32 scale is identical across all lanes -- write from any lane. + if (lane_id == 0) reinterpret_cast(scale_ptr)[0] = scale; + } +} + +template +INDEXER_KERNEL void fused_norm_rope_indexer_fp4(const __grid_constant__ FusedNormRopeStoreParams params) { + using namespace device; + using enum ForwardMode; + + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kVecSize = 4; + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + constexpr int64_t kPageBytes = 68ll << kPageBits; + static_assert(kHeadDim == kWarpThreads * kVecSize); + static_assert(kRopeDim == kWarpThreads * 2); + static_assert(kRopeSize <= kWarpThreads); + using Storage = AlignedVector; + using Float4 = AlignedVector; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kNumWarps + warp_id; + const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; + + if (work_id >= params.num_tokens) return; + + const auto input = static_cast(params.input) + work_id * kHeadDim; + int32_t position; + int32_t out_loc; + if constexpr (kMode == 
CompressExtend) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.is_invalid()) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[plan.ragged_id]; + } else if constexpr (kMode == CompressDecode) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.seq_len % params.compress_ratio != 0) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[work_id]; + } else { + static_assert(host::dependent_false_v, "Unsupported Mode"); + } + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + PDLWaitPrimary(); + Float4 data, freq; + + { + Storage input_vec, weight_vec; + input_vec.load(input, lane_id); + weight_vec.load(params.weight, lane_id); + if (is_rope_lane) freq.load(freqs_cis, lane_id - (kWarpThreads - kRopeSize)); + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + sum_of_squares += fp32_input * fp32_input; + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + const auto fp32_weight = cast(weight_vec[i]); + data[i] = fp32_input * norm_factor * fp32_weight; + } + } + + if (is_rope_lane) { + const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto y_real = data[2]; + const auto y_imag = data[3]; + const auto freq_x_real = freq[0]; + const auto freq_x_imag = freq[1]; + const auto freq_y_real = freq[2]; + const auto freq_y_imag = freq[3]; + data[0] = x_real * freq_x_real - x_imag * freq_x_imag; + data[1] = x_real * freq_x_imag + x_imag * freq_x_real; + data[2] = y_real * freq_y_real - y_imag * freq_y_imag; + data[3] = y_real * freq_y_imag + y_imag * freq_y_real; + } + + { + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a1; + data[1] 
= a0 - a1; + data[2] = a2 + a3; + data[3] = a2 - a3; + } + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a2; + data[1] = a1 + a3; + data[2] = a0 - a2; + data[3] = a1 - a3; + } +#pragma unroll + for (uint32_t mask = 1; mask < kWarpThreads; mask <<= 1) { +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const float other = __shfl_xor_sync(0xFFFFFFFFu, data[i], mask, kWarpThreads); + data[i] = (lane_id & mask) ? (other - data[i]) : (data[i] + other); + } + } + const float kHadamardScale = math::rsqrt(static_cast(kHeadDim)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] *= kHadamardScale; + } + + { + float local_max = math::abs(data[0]); +#pragma unroll + for (int i = 1; i < kVecSize; ++i) { + local_max = math::max(local_max, math::abs(data[i])); + } + local_max = warp::reduce_max<8>(local_max); + + const auto scale_raw = fmaxf(1e-4f, local_max) / 6.0f; + const auto scale_ue8m0 = static_cast(cast_to_ue8m0(scale_raw)); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + + const uint8_t packed0 = quant_fp4_e2m1(data[0] * inv_scale) | (quant_fp4_e2m1(data[1] * inv_scale) << 4); + const uint8_t packed1 = quant_fp4_e2m1(data[2] * inv_scale) | (quant_fp4_e2m1(data[3] * inv_scale) << 4); + const uint16_t packed = static_cast(packed0) | (static_cast(packed1) << 8); + + const int32_t page = out_loc >> kPageBits; + const int32_t offset = out_loc & ((1 << kPageBits) - 1); + const auto page_ptr = params.kvcache + page * kPageBytes; + const auto value_ptr = page_ptr + offset * 64; + const auto scale_ptr = page_ptr + (64 << kPageBits) + offset * 4; + + PDLTriggerSecondary(); + reinterpret_cast(value_ptr)[lane_id] = packed; + if ((lane_id & 7) == 0) static_cast(scale_ptr)[lane_id >> 3] = scale_ue8m0; + } +} + +// ---------------------------------------------------------------------------- +// FlashMLA variant: kHeadDim = 512, 1 token per *block* (256 threads). 
+// Each thread loads kVecSize=2 BF16, so 256 threads cover the full 512 elems. +// Cache layout: 584 bytes/token = 448 fp8 nope + 64 (=32 bf16x2) rope + 8 scale. +// ---------------------------------------------------------------------------- +template +FLASHMLA_KERNEL void fused_norm_rope_flashmla(const __grid_constant__ FusedNormRopeStoreParams params) { + using namespace device; + using enum ForwardMode; + + constexpr int64_t kHeadDim = 512; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kVecSize = 2; + // Last warp owns the rope tail. The remaining 7 warps each emit one + // 64-element fp8 group (own UE8M0 scale). + constexpr uint32_t kRopeWarp = kNumWarps - 1; + constexpr int64_t kPageBytes = host::div_ceil(584ll << kPageBits, 576) * 576; + static_assert(kHeadDim == kBlockSize * kVecSize); + static_assert(kRopeDim == kWarpThreads * kVecSize); + static_assert(kHeadDim - kRopeDim == kRopeWarp * kWarpThreads * kVecSize); + using Storage = AlignedVector; + using Float2 = AlignedVector; + + const auto tx = threadIdx.x; + const auto warp_id = tx / kWarpThreads; + const auto lane_id = tx % kWarpThreads; + const auto work_id = blockIdx.x; + + if (work_id >= params.num_tokens) return; + + const auto input = static_cast(params.input) + work_id * kHeadDim; + int32_t position; + int32_t out_loc; + if constexpr (kMode == CompressExtend) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.is_invalid()) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[plan.ragged_id]; + } else if constexpr (kMode == CompressDecode) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.seq_len % params.compress_ratio != 0) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[work_id]; + } else { + static_assert(host::dependent_false_v, "Unsupported Mode"); + } + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + PDLWaitPrimary(); + Float2 data, freq; + + // part 1: 
norm. Each thread owns one 2-elem pack (`tx`-th pack of input). + // Sum of squares is reduced across the whole block via per-warp partials. + { + __shared__ float partial_sums[kNumWarps]; + + Storage input_vec, weight_vec; + input_vec.load(input, tx); + weight_vec.load(params.weight, tx); + if (warp_id == kRopeWarp) freq.load(freqs_cis, lane_id); + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + sum_of_squares += fp32_input * fp32_input; + } + + const auto warp_sum = warp::reduce_sum(sum_of_squares); + if (lane_id == 0) partial_sums[warp_id] = warp_sum; + __syncthreads(); + // Replicate the per-warp partial sums to a full warp and reduce. Every + // lane-group of `kNumWarps` lanes ends up with the global sum. + sum_of_squares = warp::reduce_sum(partial_sums[lane_id % kNumWarps]); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + const auto fp32_weight = cast(weight_vec[i]); + data[i] = fp32_input * norm_factor * fp32_weight; + } + } + + const int32_t page = out_loc >> kPageBits; + const int32_t offset = out_loc & ((1 << kPageBits) - 1); + const auto page_ptr = params.kvcache + page * kPageBytes; + const auto value_ptr = page_ptr + offset * 576; + + PDLTriggerSecondary(); + + // part 2: rope on the rope warp (BF16 store), or per-warp FP8 quant + store. + if (warp_id == kRopeWarp) { + // Each rope-warp lane owns exactly one (real, imag) pair within the rope + // tail. Apply rotation, downcast to BF16, write to the slot's rope region. 
+ const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto freq_real = freq[0]; + const auto freq_imag = freq[1]; + data[0] = x_real * freq_real - x_imag * freq_imag; + data[1] = x_real * freq_imag + x_imag * freq_real; + const auto result = cast(fp32x2_t{data[0], data[1]}); + const auto rope_ptr = value_ptr + 448; + reinterpret_cast(rope_ptr)[lane_id] = result; + } else { + // Non-rope warp: per-warp UE8M0 group (64 elems -> 64 fp8 + 1 scale byte). + // BF16 round-trip to match the precision of the non-fused path + // (which goes through quant_to_nope_fp8_rope_bf16_pack_triton with bf16 input). + const auto x = cast(cast(data[0])); + const auto y = cast(cast(data[1])); + const auto abs_max = warp::reduce_max(fmaxf(fabs(x), fabs(y))); + const auto scale_raw = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto scale_ue8m0 = cast_to_ue8m0(scale_raw); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + const auto result = pack_fp8(x * inv_scale, y * inv_scale); + const auto scale_ptr = page_ptr + (576 << kPageBits) + offset * 8; + reinterpret_cast(value_ptr)[tx] = result; + // All lanes in this warp produce the same scale byte; let lane 0 publish. + if (lane_id == 0) static_cast(scale_ptr)[warp_id] = scale_ue8m0; + } +} + +template +struct FusedNormRopeKernel { + static constexpr int32_t kLogPageSize = std::countr_zero(kPageSize); + static constexpr bool kIsIndexer = (kHeadDim == 128); + static constexpr int64_t kIndexerBytes = 132 * kPageSize; + static constexpr int64_t kFlashMLABytes = host::div_ceil(584 * kPageSize, 576) * 576; + static constexpr int64_t kPageBytes = kIsIndexer ? kIndexerBytes : kFlashMLABytes; + + /// TODO: Let's fix the config for now. 
+ static_assert(kRopeDim == 64 && (kHeadDim == 128 || kHeadDim == 512)); + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + + template + static constexpr auto select_kernel() { + if constexpr (kIsIndexer) { + return fused_norm_rope_indexer; + } else { + return fused_norm_rope_flashmla; + } + } + + template + static constexpr auto select_fp4_kernel() { + static_assert(kIsIndexer, "FP4 fused store is only defined for the indexer"); + return fused_norm_rope_indexer_fp4; + } + + static void forward( + const tvm::ffi::TensorView input, + const tvm::ffi::TensorView plan, + const tvm::ffi::TensorView weight, + const float eps, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView out_loc, + const tvm::ffi::TensorView kvcache, + const bool is_decode, + const uint32_t compress_ratio) { + using namespace host; + using enum ForwardMode; + + const auto mode = static_cast(is_decode); + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({N, kHeadDim}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({kHeadDim}) // weight + .with_dtype() + .with_device(device_) + .verify(weight); + TensorMatcher({-1, kRopeDim}) // freqs_cis + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + TensorMatcher({-1}) // out_loc + .with_dtype() + .with_device(device_) + .verify(out_loc); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(kvcache); + + switch (mode) { + case CompressExtend: + compress::verify_plan_c(plan, N, device_); + RuntimeCheck(out_loc.size(0) >= N.unwrap()); + break; + case CompressDecode: + compress::verify_plan_d(plan, N, device_); + RuntimeCheck(out_loc.size(0) == N.unwrap()); + break; + } + + const auto num_tokens = static_cast(N.unwrap()); + if (num_tokens == 0) return; + const auto params = FusedNormRopeStoreParams{ + .input = input.data_ptr(), + 
.handle = plan.data_ptr(), + .weight = weight.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .out_loc = static_cast(out_loc.data_ptr()), + .kvcache = static_cast(kvcache.data_ptr()), + .eps = eps, + .compress_ratio = compress_ratio, + .num_tokens = num_tokens, + }; + // Indexer packs `kNumWarps` tokens per block (warp-major); FlashMLA uses + // a whole block per token (cta-major sum-reduce over head_dim=512). + const uint32_t num_blocks = kIsIndexer ? div_ceil(num_tokens, kNumWarps) : num_tokens; + const auto device = device_.unwrap(); + const auto kernel = mode == CompressExtend ? select_kernel() : select_kernel(); + LaunchKernel(num_blocks, kBlockSize, device).enable_pdl(kUsePDL)(kernel, params); + } + + static void forward_fp4( + const tvm::ffi::TensorView input, + const tvm::ffi::TensorView plan, + const tvm::ffi::TensorView weight, + const float eps, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView out_loc, + const tvm::ffi::TensorView kvcache, + const bool is_decode, + const uint32_t compress_ratio) { + using namespace host; + using enum ForwardMode; + + static_assert(kIsIndexer, "FP4 fused store is only defined for the indexer"); + constexpr int64_t kFp4PageBytes = 68 * kPageSize; + const auto mode = static_cast(is_decode); + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({N, kHeadDim}).with_dtype().with_device(device_).verify(input); + TensorMatcher({kHeadDim}).with_dtype().with_device(device_).verify(weight); + TensorMatcher({-1, kRopeDim}).with_dtype().with_device(device_).verify(freqs_cis); + TensorMatcher({-1}).with_dtype().with_device(device_).verify(out_loc); + TensorMatcher({-1, -1}).with_strides({kFp4PageBytes, 1}).with_dtype().with_device(device_).verify(kvcache); + + switch (mode) { + case CompressExtend: + compress::verify_plan_c(plan, N, device_); + RuntimeCheck(out_loc.size(0) >= N.unwrap()); + break; + case CompressDecode: + 
compress::verify_plan_d(plan, N, device_); + RuntimeCheck(out_loc.size(0) == N.unwrap()); + break; + } + + const auto num_tokens = static_cast(N.unwrap()); + if (num_tokens == 0) return; + const auto params = FusedNormRopeStoreParams{ + .input = input.data_ptr(), + .handle = plan.data_ptr(), + .weight = weight.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .out_loc = static_cast(out_loc.data_ptr()), + .kvcache = static_cast(kvcache.data_ptr()), + .eps = eps, + .compress_ratio = compress_ratio, + .num_tokens = num_tokens, + }; + const uint32_t num_blocks = div_ceil(num_tokens, kNumWarps); + const auto device = device_.unwrap(); + const auto kernel = + mode == CompressExtend ? select_fp4_kernel() : select_fp4_kernel(); + LaunchKernel(num_blocks, kBlockSize, device).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hash_topk.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hash_topk.cuh new file mode 100644 index 0000000000..90dec3c117 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hash_topk.cuh @@ -0,0 +1,214 @@ +#include +#include + +#include +#include +#include + +#include + +#include +#include + +namespace { + +[[maybe_unused]] +SGL_DEVICE float act_sqrt_softplus(float x) { + const float softplus = fmaxf(x, 0.0f) + log1pf(expf(-fabsf(x))); + return sqrtf(softplus); +} + +struct MoEHashTopKParams { + const float* __restrict__ router_logits; + const int64_t* __restrict__ input_id; + const int32_t* __restrict__ tid2eid; + int32_t* __restrict__ topk_ids; + float* __restrict__ topk_weights; + uint32_t num_tokens; + uint32_t topk; + uint32_t num_routed_experts; + uint32_t num_shared_experts; + float routed_scaling_factor; +}; + +template +__global__ void moe_hash_topk_fused(const MoEHashTopKParams __grid_constant__ params) { + using namespace device; + const auto& [ + router_logits, input_id, tid2eid, topk_ids, topk_weights, // pointers + num_tokens, 
topk, num_routed_experts, num_shared_experts, routed_scaling_factor] = + params; + + const uint32_t topk_fused = topk + num_shared_experts; + const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t warp_id = tid / kWarpThreads; + const uint32_t lane_id = tid % kWarpThreads; + if (warp_id >= num_tokens) return; + // we can safely prefetch the token id + const auto token_id = input_id[warp_id]; + + PDLWaitPrimary(); + + float routed_weight = 0.0f; + int32_t expert_id = 0; + if (lane_id < topk) { + expert_id = tid2eid[token_id * topk + lane_id]; + routed_weight = Fn(router_logits[warp_id * num_routed_experts + expert_id]); + } + + const auto routed_sum = device::warp::reduce_sum(routed_weight); + if (lane_id < topk_fused) { + const bool is_shared = lane_id >= topk; + const auto output_offset = warp_id * topk_fused + lane_id; + topk_ids[output_offset] = is_shared ? num_routed_experts + lane_id - topk : expert_id; + topk_weights[output_offset] = is_shared ? 1.0f / routed_scaling_factor : routed_weight / routed_sum; + } + + PDLTriggerSecondary(); +} + +struct TopKParams { + int32_t* __restrict__ topk_ids; + // Exactly one is active: ntn_ptr == nullptr means use ntn_value. + const int32_t* __restrict__ ntn_ptr; + int32_t ntn_value; + int64_t stride; + uint32_t topk; + uint32_t num_tokens; +}; + +__global__ void mask_topk_ids_padded_region(const TopKParams __grid_constant__ params) { + const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t warp_id = tid / device::kWarpThreads; + const uint32_t lane_id = tid % device::kWarpThreads; + if (warp_id >= params.num_tokens || lane_id >= params.topk) return; + device::PDLWaitPrimary(); + const uint32_t num = (params.ntn_ptr != nullptr) // + ? 
static_cast(params.ntn_ptr[0]) + : static_cast(params.ntn_value); + if (warp_id >= num) params.topk_ids[warp_id * params.stride + lane_id] = -1; + device::PDLTriggerSecondary(); +} + +template +struct HashTopKKernel { + static constexpr auto kernel = moe_hash_topk_fused; + + static void + run(const tvm::ffi::TensorView router_logits, + const tvm::ffi::TensorView input_id, + const tvm::ffi::TensorView tid2eid, + const tvm::ffi::TensorView topk_weights, + const tvm::ffi::TensorView topk_ids, + float routed_scaling_factor) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto E = SymbolicSize{"num_routed_experts"}; + auto K = SymbolicSize{"topk_fused"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, E}) // + .with_dtype() + .with_device(device) + .verify(router_logits); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(input_id); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device) + .verify(tid2eid); + TensorMatcher({N, K}) // + .with_dtype() + .with_device(device) + .verify(topk_weights); + TensorMatcher({N, K}) // + .with_dtype() + .with_device(device) + .verify(topk_ids); + + const auto num_tokens = static_cast(N.unwrap()); + const auto topk_fused = static_cast(K.unwrap()); + const auto topk = static_cast(tid2eid.size(1)); + const auto shared_experts = topk_fused - topk; + RuntimeCheck(topk <= topk_fused, "HashTopKKernel requires topk <= topk_fused"); + RuntimeCheck(topk_fused <= device::kWarpThreads, "HashTopKKernel requires topk_fused <= warp size"); + + const auto params = MoEHashTopKParams{ + .router_logits = static_cast(router_logits.data_ptr()), + .input_id = static_cast(input_id.data_ptr()), + .tid2eid = static_cast(tid2eid.data_ptr()), + .topk_ids = static_cast(topk_ids.data_ptr()), + .topk_weights = static_cast(topk_weights.data_ptr()), + .num_tokens = num_tokens, + .topk = topk, + .num_routed_experts = static_cast(E.unwrap()), + .num_shared_experts = 
shared_experts, + .routed_scaling_factor = routed_scaling_factor, + }; + const auto kBlockSize = 128u; + const auto kNumWarps = kBlockSize / device::kWarpThreads; + const auto num_blocks = div_ceil(num_tokens, kNumWarps); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +// TODO this may not be related to *hash* topk, thus may move +struct MaskKernel { + static constexpr auto kernel = mask_topk_ids_padded_region; + + static void run(tvm::ffi::TensorView topk_ids, tvm::ffi::TensorView num_token_non_padded) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto K = SymbolicSize{"topk"}; + auto D = SymbolicSize{"stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + TensorMatcher({N, K}) // + .with_strides({D, 1}) + .with_dtype() + .with_device(device) + .verify(topk_ids); + RuntimeCheck(num_token_non_padded.numel() == 1, "num_token_non_padded should be a scalar"); + RuntimeCheck(K.unwrap() <= device::kWarpThreads, "MaskKernel requires topk <= warp size"); + const int32_t* ntn_ptr = nullptr; + int32_t ntn_value = 0; + const auto ntn_dev = num_token_non_padded.device().device_type; + if (ntn_dev == kDLCUDA) { + RuntimeCheck(is_type(num_token_non_padded.dtype()), "num_token_non_padded on CUDA must be int32"); + ntn_ptr = static_cast(num_token_non_padded.data_ptr()); + } else if (ntn_dev == kDLCPU) { + if (is_type(num_token_non_padded.dtype())) { + ntn_value = *static_cast(num_token_non_padded.data_ptr()); + } else if (is_type(num_token_non_padded.dtype())) { + ntn_value = static_cast(*static_cast(num_token_non_padded.data_ptr())); + } else { + RuntimeCheck(false, "num_token_non_padded on CPU must be int32 or int64"); + } + } else { + RuntimeCheck(false, "num_token_non_padded must be on CPU or CUDA"); + } + + const auto num_tokens = static_cast(N.unwrap()); + const auto params = TopKParams{ + .topk_ids = static_cast(topk_ids.data_ptr()), + .ntn_ptr = ntn_ptr, + .ntn_value = 
ntn_value, + .stride = static_cast(D.unwrap()), + .topk = static_cast(K.unwrap()), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 128u; + const auto kNumWarps = kBlockSize / device::kWarpThreads; + const auto num_blocks = div_ceil(num_tokens, kNumWarps); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(true)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hisparse_transfer.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hisparse_transfer.cuh new file mode 100644 index 0000000000..aefec24372 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hisparse_transfer.cuh @@ -0,0 +1,82 @@ +#include +#include + +#include + +#include + +#include +#include + +#include + +namespace { + +/// NOTE: for offload to cpu kernel, we use persistent kernel +inline constexpr uint32_t kBlockSize = 1024; +inline constexpr uint32_t kBlockQuota = 4; + +#define OFFLOAD_KERNEL __global__ __launch_bounds__(kBlockSize, 1) + +struct OffloadParams { + void** gpu_caches; + void** cpu_caches; + const int64_t* gpu_indices; + const int64_t* cpu_indices; + uint32_t num_items; + uint32_t num_layers; +}; + +OFFLOAD_KERNEL void offload_to_cpu(const __grid_constant__ OffloadParams params) { + using namespace device::hisparse; + const auto [gpu_caches, cpu_caches, gpu_indices, cpu_indices, num_items, num_layers] = params; + const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x; + constexpr auto kNumWarps = (kBlockSize / 32) * kBlockQuota; + for (auto i = global_tid / 32; i < num_items; i += kNumWarps) { + const int32_t gpu_index = gpu_indices[i]; + const int32_t cpu_index = cpu_indices[i]; + for (auto j = 0u; j < num_layers; ++j) { + const auto gpu_cache = gpu_caches[j]; + const auto cpu_cache = cpu_caches[j]; + transfer_item( + /*dst_cache=*/cpu_cache, + /*src_cache=*/gpu_cache, + /*dst_index=*/cpu_index, + /*src_index=*/gpu_index); + } + } +} + +[[maybe_unused]] +void 
hisparse_transfer( + tvm::ffi::TensorView gpu_ptrs, + tvm::ffi::TensorView cpu_ptrs, + tvm::ffi::TensorView gpu_indices, + tvm::ffi::TensorView cpu_indices) { + using namespace host; + auto N = SymbolicSize{"num_items"}; + auto L = SymbolicSize{"num_layers"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({L}) // 1D cache pointers + .with_dtype() + .with_device(device_) + .verify(gpu_ptrs) + .verify(cpu_ptrs); + TensorMatcher({N}) // 1D indices + .with_dtype() + .with_device(device_) + .verify(gpu_indices) + .verify(cpu_indices); + const auto params = OffloadParams{ + .gpu_caches = static_cast(gpu_ptrs.data_ptr()), + .cpu_caches = static_cast(cpu_ptrs.data_ptr()), + .gpu_indices = static_cast(gpu_indices.data_ptr()), + .cpu_indices = static_cast(cpu_indices.data_ptr()), + .num_items = static_cast(N.unwrap()), + .num_layers = static_cast(L.unwrap()), + }; + LaunchKernel(kBlockQuota, kBlockSize, device_.unwrap())(offload_to_cpu, params); +} + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/main_norm_rope.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/main_norm_rope.cuh new file mode 100644 index 0000000000..8fc8d0821d --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/main_norm_rope.cuh @@ -0,0 +1,845 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::inv_scale_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +SGL_DEVICE uint8_t quant_fp4_e2m1(float x) { + const float ax = fminf(fabsf(x), 6.0f); + uint8_t idx = 0; + idx += ax > 0.25f; + idx += ax > 0.75f; + idx += ax > 1.25f; + idx += ax > 1.75f; + idx += ax > 2.5f; + idx += ax > 3.5f; + idx += ax > 5.0f; + if (x < 0.0f && idx != 0) idx |= 0x8; + return idx; +} + +// 4 warps per block: warp-per-(token, head) work-item dispatch (Q kernel). 
+constexpr uint32_t kFusedQBlockSize = 128; +constexpr uint32_t kFusedQNumWarps = kFusedQBlockSize / device::kWarpThreads; + +// 8 warps per block: block-per-token work-item dispatch (K kernel). +constexpr uint32_t kFusedKBlockSize = 256; +constexpr uint32_t kFusedKNumWarps = kFusedKBlockSize / device::kWarpThreads; + +#define Q_KERNEL __global__ __launch_bounds__(kFusedQBlockSize, 16) +#define K_KERNEL __global__ __launch_bounds__(kFusedKBlockSize, 8) + +// ============================================================================ +// Q kernel: warp-per-(token, head) rmsnorm-self + RoPE + write to q_out. +// ============================================================================ + +struct FusedQNormRopeParams { + const void* __restrict__ q_input; // (B, num_q_heads, kHeadDim) DType + void* __restrict__ q_output; // (B, num_q_heads, kHeadDim) DType + const float* __restrict__ freqs_cis; // (max_pos, kRopeDim) fp32 (re/im interleaved) + const void* __restrict__ positions; // (B,) PosT + int64_t q_input_stride_batch; + int64_t q_output_stride_batch; + uint32_t batch_size; + uint32_t num_q_heads; + float eps; +}; + +template +Q_KERNEL void fused_q_norm_rope(const __grid_constant__ FusedQNormRopeParams params) { + using namespace device; + + constexpr int64_t kMaxVecSize = 16 / sizeof(DType); + constexpr int64_t kVecSize = std::min(kMaxVecSize, kHeadDim / kWarpThreads); + constexpr int64_t kLocalSize = kHeadDim / (kWarpThreads * kVecSize); + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + static_assert(kHeadDim % (kWarpThreads * kVecSize) == 0); + static_assert(kLocalSize * kVecSize * kWarpThreads == kHeadDim); + static_assert(kRopeDim % kVecSize == 0); + static_assert(kRopeSize <= kWarpThreads); + static_assert(kRopeDim == kWarpThreads * 2, "1 (real, imag) pair per lane"); + + using Storage = AlignedVector; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * 
kFusedQNumWarps + warp_id; + + const uint32_t total_works = params.batch_size * params.num_q_heads; + if (work_id >= total_works) return; + + const uint32_t batch_id = work_id / params.num_q_heads; + const uint32_t head_id = work_id % params.num_q_heads; + const auto input_ptr = + static_cast(params.q_input) + batch_id * params.q_input_stride_batch + head_id * kHeadDim; + const auto output_ptr = + static_cast(params.q_output) + batch_id * params.q_output_stride_batch + head_id * kHeadDim; + const auto position = static_cast(static_cast(params.positions)[batch_id]); + + __shared__ Storage s_rope[kFusedQNumWarps][kRopeSize]; + + // Prefetch this lane's freq pair before the PDL gate so the wait happens + // outside the dependency chain on `position`. + const auto mem_freq = tile::Memory{lane_id, kWarpThreads}; + + PDLWaitPrimary(); + + // part 1: rmsnorm-self (no weight). + const auto gmem = tile::Memory{lane_id, kWarpThreads}; + Storage input_vec[kLocalSize]; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + input_vec[i] = gmem.load(input_ptr, i); + } + + const auto freq = mem_freq.load(params.freqs_cis + position * kRopeDim); + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto x = cast(input_vec[i][j]); + sum_of_squares += x * x; + } + } + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto x = cast(input_vec[i][j]); + input_vec[i][j] = cast(x * norm_factor); + } + } + + // Stash the rope tail (last kRopeSize lanes' last tile) into shared memory; + // write nope tiles to gmem directly. 
+ const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + if (i == kLocalSize - 1 && is_rope_lane) { + const auto rope_id = lane_id - (kWarpThreads - kRopeSize); + s_rope[warp_id][rope_id] = input_vec[i]; + } else { + gmem.store(output_ptr, input_vec[i], i); + } + } + __syncwarp(); + + PDLTriggerSecondary(); + + // part 2: RoPE on all 32 lanes -- one (real, imag) bf16x2 pair per lane. + using DType2 = packed_t; + const auto mem_elem = tile::Memory{lane_id, kWarpThreads}; + const auto elem = mem_elem.load(s_rope[warp_id]); + const auto [x_real, x_imag] = cast(elem); + const auto [freq_real, freq_imag] = freq; + const fp32x2_t rotated = { + x_real * freq_real - x_imag * freq_imag, + x_real * freq_imag + x_imag * freq_real, + }; + mem_elem.store(output_ptr + (kHeadDim - kRopeDim), cast(rotated)); +} + +template +struct FusedQNormRopeKernel { + template + static constexpr auto kernel = fused_q_norm_rope; + + static void forward( + const tvm::ffi::TensorView q_input, + const tvm::ffi::TensorView q_output, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView positions, + float eps) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto H = SymbolicSize{"num_q_heads"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, H, kHeadDim}) // + .with_strides({-1, kHeadDim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_input); + TensorMatcher({B, H, kHeadDim}) // + .with_strides({-1, kHeadDim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_output); + TensorMatcher({-1, kRopeDim}) // + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(pos_dtype) + .with_device(device_) + .verify(positions); + + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_heads = static_cast(H.unwrap()); + if (batch_size == 0) return; + + const auto params = 
FusedQNormRopeParams{ + .q_input = q_input.data_ptr(), + .q_output = q_output.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .q_input_stride_batch = q_input.stride(0), + .q_output_stride_batch = q_output.stride(0), + .batch_size = batch_size, + .num_q_heads = num_q_heads, + .eps = eps, + }; + const auto total_works = batch_size * num_q_heads; + const auto num_blocks = div_ceil(total_works, kFusedQNumWarps); + const auto k_int32 = kernel; + const auto k_int64 = kernel; + const auto k = pos_dtype.is_type() ? k_int32 : k_int64; + LaunchKernel(num_blocks, kFusedQBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(k, params); + } +}; + +// ============================================================================ +// K kernel: block-per-token rmsnorm (with kv_weight) + RoPE + FlashMLA store. +// ============================================================================ + +struct FusedKNormRopeFlashMLAParams { + const void* __restrict__ kv; // (B, kHeadDim) DType + const void* __restrict__ kv_weight; // (kHeadDim,) DType + const float* __restrict__ freqs_cis; // (max_pos, kRopeDim) fp32 + const void* __restrict__ positions; // (B,) PosT + const int32_t* __restrict__ out_loc; // (B,) int32 -> cache slot id + uint8_t* __restrict__ kvcache; // (npages, kPageBytes) uint8 + // Row stride for `kv` in elements. Required because the upstream caller often + // passes `qkv_a[..., q_lora_rank:]`, a non-contiguous slice whose stride[0] + // equals `q_lora_rank + kHeadDim` rather than `kHeadDim`. 
+ int64_t kv_stride_batch; + uint32_t batch_size; + float eps; +}; + +template +K_KERNEL void fused_k_norm_rope_flashmla(const __grid_constant__ FusedKNormRopeFlashMLAParams params) { + using namespace device; + + constexpr int64_t kVecSize = 2; + constexpr uint32_t kRopeWarp = kFusedKNumWarps - 1; + constexpr int64_t kPageBytes = host::div_ceil(584ll << kPageBits, 576) * 576; + static_assert(kHeadDim == kFusedKBlockSize * kVecSize); + static_assert(kRopeDim == kWarpThreads * kVecSize); + static_assert(kHeadDim - kRopeDim == kRopeWarp * kWarpThreads * kVecSize); + using Storage = AlignedVector; + using Float2 = AlignedVector; + + const auto tx = threadIdx.x; + const auto warp_id = tx / kWarpThreads; + const auto lane_id = tx % kWarpThreads; + const auto work_id = blockIdx.x; + if (work_id >= params.batch_size) return; + + const auto input_ptr = static_cast(params.kv) + work_id * params.kv_stride_batch; + const auto position = static_cast(static_cast(params.positions)[work_id]); + const auto out_loc = params.out_loc[work_id]; + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + PDLWaitPrimary(); + Float2 data, freq; + + // part 1: norm. Each thread owns one 2-elem pack (the `tx`-th). + // Sum-of-squares is reduced block-wide via per-warp partials. + { + __shared__ float partial_sums[kFusedKNumWarps]; + + Storage input_vec, weight_vec; + input_vec.load(input_ptr, tx); + weight_vec.load(params.kv_weight, tx); + if (warp_id == kRopeWarp) freq.load(freqs_cis, lane_id); + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto x = cast(input_vec[i]); + sum_of_squares += x * x; + } + const auto warp_sum = warp::reduce_sum(sum_of_squares); + if (lane_id == 0) partial_sums[warp_id] = warp_sum; + __syncthreads(); + // Replicate the per-warp partial sums onto all lanes of one warp and + // reduce. Every group of `kBlockItemNumWarps` lanes ends up with the + // global sum. 
+ sum_of_squares = warp::reduce_sum(partial_sums[lane_id % kFusedKNumWarps]); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto x = cast(input_vec[i]); + const auto w = cast(weight_vec[i]); + data[i] = x * norm_factor * w; + } + } + + const int32_t page = out_loc >> kPageBits; + const int32_t offset = out_loc & ((1 << kPageBits) - 1); + const auto page_ptr = params.kvcache + page * kPageBytes; + const auto value_ptr = page_ptr + offset * 576; + + PDLTriggerSecondary(); + + // part 2: rope on warp 7 (BF16 store), per-warp UE8M0 quant + store on warps 0..6. + if (warp_id == kRopeWarp) { + const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto freq_real = freq[0]; + const auto freq_imag = freq[1]; + data[0] = x_real * freq_real - x_imag * freq_imag; + data[1] = x_real * freq_imag + x_imag * freq_real; + const auto result = cast(fp32x2_t{data[0], data[1]}); + const auto rope_ptr = value_ptr + 448; + reinterpret_cast(rope_ptr)[lane_id] = result; + } else { + const auto x = data[0]; + const auto y = data[1]; + const auto abs_max = warp::reduce_max(fmaxf(fabs(x), fabs(y))); + const auto scale_raw = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto scale_ue8m0 = cast_to_ue8m0(scale_raw); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + const auto result = pack_fp8(x * inv_scale, y * inv_scale); + const auto scale_ptr = page_ptr + (576 << kPageBits) + offset * 8; + reinterpret_cast(value_ptr)[tx] = result; + if (lane_id == 0) static_cast(scale_ptr)[warp_id] = scale_ue8m0; + } +} + +template +struct FusedKNormRopeFlashMLAKernel { + static constexpr int32_t kLogPageSize = std::countr_zero(kPageSize); + static constexpr int64_t kPageBytes = host::div_ceil(584 * kPageSize, 576) * 576; + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogPageSize == kPageSize); + static_assert(kHeadDim == 512 
&& kRopeDim == 64, "FlashMLA layout requires (512, 64)"); + + template + static constexpr auto kernel = fused_k_norm_rope_flashmla; + + static void forward( + const tvm::ffi::TensorView kv, + const tvm::ffi::TensorView kv_weight, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView positions, + const tvm::ffi::TensorView out_loc, + const tvm::ffi::TensorView kvcache, + float eps) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, kHeadDim}) // + .with_strides({-1, 1}) + .with_dtype() + .with_device(device_) + .verify(kv); + TensorMatcher({kHeadDim}) // + .with_dtype() + .with_device(device_) + .verify(kv_weight); + TensorMatcher({-1, kRopeDim}) // + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(pos_dtype) + .with_device(device_) + .verify(positions); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(out_loc); + TensorMatcher({-1, -1}) // + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(kvcache); + + const auto batch_size = static_cast(B.unwrap()); + if (batch_size == 0) return; + + const auto params = FusedKNormRopeFlashMLAParams{ + .kv = kv.data_ptr(), + .kv_weight = kv_weight.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .out_loc = static_cast(out_loc.data_ptr()), + .kvcache = static_cast(kvcache.data_ptr()), + .kv_stride_batch = kv.stride(0), + .batch_size = batch_size, + .eps = eps, + }; + const auto k_int32 = kernel; + const auto k_int64 = kernel; + const auto k = pos_dtype.is_type() ? 
k_int32 : k_int64; + LaunchKernel(batch_size, kFusedKBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(k, params); + } +}; + +// ============================================================================ +// Indexer Q kernel: warp-per-(token, head) RoPE + Hadamard + fp8 act-quant. +// ============================================================================ + +struct FusedQIndexerRopeHadamardQuantParams { + const void* __restrict__ q_input; // (B, num_heads, 128) DType + void* __restrict__ q_fp8; // (B, num_heads, 128) fp8_e4m3 + // weights_out[b, h] = weight[b, h] * weight_scale * q_scale[b, h]. + // q_scale is computed internally and not exposed -- the only consumer of + // it is `weights_out`. + const void* __restrict__ weight; // (B, num_heads) DType + float* __restrict__ weights_out; // (B, num_heads) fp32 (== (B, H, 1) flat) + float weight_scale; // scalar c4_indexer.weight_scale + const float* __restrict__ freqs_cis; // (max_pos, 64) fp32 + const void* __restrict__ positions; // (B,) PosT + uint32_t batch_size; + uint32_t num_heads; +}; + +template +Q_KERNEL void fused_q_indexer_rope_hadamard_quant(const __grid_constant__ FusedQIndexerRopeHadamardQuantParams params) { + using namespace device; + + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kVecSize = 4; + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; // = 16 + static_assert(kHeadDim == kWarpThreads * kVecSize); + static_assert(kRopeDim == kWarpThreads * 2); + static_assert(kRopeSize <= kWarpThreads); + + using Storage = AlignedVector; + using Float4 = AlignedVector; + using OutStorage = AlignedVector; // 4 fp8 / lane + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kFusedQNumWarps + warp_id; + // Last `kRopeSize` lanes own the rope tail; their 4-elem packs cover the + // trailing kRopeDim elements. 
+ const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; + + const uint32_t total_works = params.batch_size * params.num_heads; + if (work_id >= total_works) return; + + const uint32_t batch_id = work_id / params.num_heads; + const auto input_ptr = static_cast(params.q_input) + work_id * kHeadDim; + const auto position = static_cast(static_cast(params.positions)[batch_id]); + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + // Lane 0 prefetches the weight scalar for this (token, head) work item. + // Weight is (B, num_heads) DType; we need one scalar per warp -- offload + // the load to lane 0 only. The multiply + store happens once the q_scale + // is known (part 4). + + PDLWaitPrimary(); + Float4 data, freq; + const auto weight_val = cast(static_cast(params.weight)[work_id]); + + // part 1: load (no norm). Each lane owns a 4-elem pack. + { + Storage input_vec; + input_vec.load(input_ptr, lane_id); + if (is_rope_lane) freq.load(freqs_cis, lane_id - (kWarpThreads - kRopeSize)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + data[i] = cast(input_vec[i]); + } + } + + // part 2: rope on rope lanes only (4 elems / lane = 2 (real, imag) pairs). + if (is_rope_lane) { + const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto y_real = data[2]; + const auto y_imag = data[3]; + const auto fxr = freq[0]; + const auto fxi = freq[1]; + const auto fyr = freq[2]; + const auto fyi = freq[3]; + data[0] = x_real * fxr - x_imag * fxi; + data[1] = x_real * fxi + x_imag * fxr; + data[2] = y_real * fyr - y_imag * fyi; + data[3] = y_real * fyi + y_imag * fyr; + } + + PDLTriggerSecondary(); + + // part 3: 128-point Hadamard (2 local stages + 5 cross-lane shfl_xor stages). + // Same recipe as `fused_norm_rope_indexer`; see comments there for the + // butterfly invariants and the early-return safety argument. 
+ { + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a1; + data[1] = a0 - a1; + data[2] = a2 + a3; + data[3] = a2 - a3; + } + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a2; + data[1] = a1 + a3; + data[2] = a0 - a2; + data[3] = a1 - a3; + } +#pragma unroll + for (uint32_t mask = 1; mask < kWarpThreads; mask <<= 1) { +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const float other = __shfl_xor_sync(0xFFFFFFFFu, data[i], mask, kWarpThreads); + data[i] = (lane_id & mask) ? (other - data[i]) : (data[i] + other); + } + } + const float kHadamardScale = math::rsqrt(static_cast(kHeadDim)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] *= kHadamardScale; + } + + { + float local_max = math::abs(data[0]); +#pragma unroll + for (int i = 1; i < kVecSize; ++i) { + local_max = math::max(local_max, math::abs(data[i])); + } + const auto abs_max = warp::reduce_max(local_max); + const auto scale = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto inv_scale = 1.0f / scale; + OutStorage result; + result[0] = pack_fp8(data[0] * inv_scale, data[1] * inv_scale); + result[1] = pack_fp8(data[2] * inv_scale, data[3] * inv_scale); + + // q_fp8 row pointer: 128 fp8 / row = 32 OutStorage / row, one per lane. 
+ auto out_row = static_cast(params.q_fp8) + work_id * kHeadDim; + result.store(out_row, lane_id); + params.weights_out[work_id] = weight_val * params.weight_scale * scale; + } +} + +template +struct FusedQIndexerRopeHadamardQuantKernel { + template + static constexpr auto kernel = fused_q_indexer_rope_hadamard_quant; + + static void forward( + const tvm::ffi::TensorView q_input, + const tvm::ffi::TensorView q_fp8, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView weights_out, + double weight_scale, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView positions) { + using namespace host; + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + + auto B = SymbolicSize{"batch_size"}; + auto H = SymbolicSize{"num_heads"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + // Caller path is `wq_b(q_lora).view(-1, H, D)` -> contiguous; the kernel + // assumes a flat `(B*H, kHeadDim)` layout for both q_input and q_fp8. + // Pin the head/innermost strides; assert the batch stride below. 
+ TensorMatcher({B, H, kHeadDim}) // + .with_strides({-1, kHeadDim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_input); + TensorMatcher({B, H, kHeadDim}) // + .with_strides({-1, kHeadDim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_fp8); + TensorMatcher({B, H}) // + .with_dtype() + .with_device(device_) + .verify(weight); + TensorMatcher({B, H, 1}) // + .with_dtype() + .with_device(device_) + .verify(weights_out); + TensorMatcher({-1, kRopeDim}) // + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(pos_dtype) + .with_device(device_) + .verify(positions); + + const auto batch_size = static_cast(B.unwrap()); + const auto num_heads = static_cast(H.unwrap()); + if (batch_size == 0) return; + + // The kernel computes row pointers as `base + work_id * kHeadDim`, so + // both inputs must be contiguous in (batch, head, elem) order. + const int64_t expected_batch_stride = static_cast(num_heads) * kHeadDim; + RuntimeCheck( + q_input.stride(0) == expected_batch_stride, + "q_input must be contiguous (B, H, kHeadDim); got stride[0]=", + q_input.stride(0)); + RuntimeCheck( + q_fp8.stride(0) == expected_batch_stride, + "q_fp8 must be contiguous (B, H, kHeadDim); got stride[0]=", + q_fp8.stride(0)); + + const auto params = FusedQIndexerRopeHadamardQuantParams{ + .q_input = q_input.data_ptr(), + .q_fp8 = q_fp8.data_ptr(), + .weight = weight.data_ptr(), + .weights_out = static_cast(weights_out.data_ptr()), + .weight_scale = static_cast(weight_scale), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .batch_size = batch_size, + .num_heads = num_heads, + }; + const auto total_works = batch_size * num_heads; + const auto num_blocks = div_ceil(total_works, kFusedQNumWarps); + const auto k_int32 = kernel; + const auto k_int64 = kernel; + const auto k = pos_dtype.is_type() ? 
k_int32 : k_int64; + LaunchKernel(num_blocks, kFusedQBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(k, params); + } +}; + +struct FusedQIndexerRopeHadamardFp4QuantParams { + const void* __restrict__ q_input; + void* __restrict__ q_fp4; + int32_t* __restrict__ q_sf; + const void* __restrict__ weight; + float* __restrict__ weights_out; + float weight_scale; + const float* __restrict__ freqs_cis; + const void* __restrict__ positions; + uint32_t batch_size; + uint32_t num_heads; +}; + +template +Q_KERNEL void +fused_q_indexer_rope_hadamard_fp4_quant(const __grid_constant__ FusedQIndexerRopeHadamardFp4QuantParams params) { + using namespace device; + + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kVecSize = 4; + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + static_assert(kHeadDim == kWarpThreads * kVecSize); + static_assert(kRopeDim == kWarpThreads * 2); + static_assert(kRopeSize <= kWarpThreads); + + using Storage = AlignedVector; + using Float4 = AlignedVector; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kFusedQNumWarps + warp_id; + const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; + + const uint32_t total_works = params.batch_size * params.num_heads; + if (work_id >= total_works) return; + + const uint32_t batch_id = work_id / params.num_heads; + const auto input_ptr = static_cast(params.q_input) + work_id * kHeadDim; + const auto position = static_cast(static_cast(params.positions)[batch_id]); + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + PDLWaitPrimary(); + Float4 data, freq; + const auto weight_val = cast(static_cast(params.weight)[work_id]); + + { + Storage input_vec; + input_vec.load(input_ptr, lane_id); + if (is_rope_lane) freq.load(freqs_cis, lane_id - (kWarpThreads - kRopeSize)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + data[i] = cast(input_vec[i]); + } + 
} + + if (is_rope_lane) { + const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto y_real = data[2]; + const auto y_imag = data[3]; + const auto fxr = freq[0]; + const auto fxi = freq[1]; + const auto fyr = freq[2]; + const auto fyi = freq[3]; + data[0] = x_real * fxr - x_imag * fxi; + data[1] = x_real * fxi + x_imag * fxr; + data[2] = y_real * fyr - y_imag * fyi; + data[3] = y_real * fyi + y_imag * fyr; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] = cast(cast(data[i])); + } + + PDLTriggerSecondary(); + + { + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a1; + data[1] = a0 - a1; + data[2] = a2 + a3; + data[3] = a2 - a3; + } + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a2; + data[1] = a1 + a3; + data[2] = a0 - a2; + data[3] = a1 - a3; + } +#pragma unroll + for (uint32_t mask = 1; mask < kWarpThreads; mask <<= 1) { +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const float other = __shfl_xor_sync(0xFFFFFFFFu, data[i], mask, kWarpThreads); + data[i] = (lane_id & mask) ? 
(other - data[i]) : (data[i] + other); + } + } + const float kHadamardScale = math::rsqrt(static_cast(kHeadDim)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] *= kHadamardScale; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] = cast(cast(data[i])); + } + + { + float local_max = math::abs(data[0]); +#pragma unroll + for (int i = 1; i < kVecSize; ++i) { + local_max = math::max(local_max, math::abs(data[i])); + } + local_max = warp::reduce_max<8>(local_max); + const auto scale_raw = fmaxf(1e-4f, local_max) / 6.0f; + const auto scale_ue8m0 = static_cast(cast_to_ue8m0(scale_raw)); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + const uint8_t packed0 = quant_fp4_e2m1(data[0] * inv_scale) | (quant_fp4_e2m1(data[1] * inv_scale) << 4); + const uint8_t packed1 = quant_fp4_e2m1(data[2] * inv_scale) | (quant_fp4_e2m1(data[3] * inv_scale) << 4); + const uint16_t packed = static_cast(packed0) | (static_cast(packed1) << 8); + auto out_row = static_cast(params.q_fp4) + work_id * (kHeadDim / 2); + reinterpret_cast(out_row)[lane_id] = packed; + if ((lane_id & 7) == 0) { + reinterpret_cast(params.q_sf + work_id)[lane_id >> 3] = scale_ue8m0; + } + params.weights_out[work_id] = weight_val * params.weight_scale; + } +} + +template +struct FusedQIndexerRopeHadamardFp4QuantKernel { + template + static constexpr auto kernel = fused_q_indexer_rope_hadamard_fp4_quant; + + static void forward( + const tvm::ffi::TensorView q_input, + const tvm::ffi::TensorView q_fp4, + const tvm::ffi::TensorView q_sf, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView weights_out, + double weight_scale, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView positions) { + using namespace host; + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kFp4Dim = kHeadDim / 2; + + auto B = SymbolicSize{"batch_size"}; + auto H = SymbolicSize{"num_heads"}; + auto device_ = SymbolicDevice{}; + 
device_.set_options(); + + TensorMatcher({B, H, kHeadDim}) + .with_strides({-1, kHeadDim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_input); + TensorMatcher({B, H, kFp4Dim}) + .with_strides({-1, kFp4Dim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_fp4); + TensorMatcher({B, H}).with_dtype().with_device(device_).verify(q_sf); + TensorMatcher({B, H}).with_dtype().with_device(device_).verify(weight); + TensorMatcher({B, H, 1}).with_dtype().with_device(device_).verify(weights_out); + TensorMatcher({-1, kRopeDim}).with_dtype().with_device(device_).verify(freqs_cis); + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}).with_dtype(pos_dtype).with_device(device_).verify(positions); + + const auto batch_size = static_cast(B.unwrap()); + const auto num_heads = static_cast(H.unwrap()); + if (batch_size == 0) return; + + const int64_t expected_q_stride = static_cast(num_heads) * kHeadDim; + const int64_t expected_fp4_stride = static_cast(num_heads) * kFp4Dim; + RuntimeCheck(q_input.stride(0) == expected_q_stride, "q_input must be contiguous"); + RuntimeCheck(q_fp4.stride(0) == expected_fp4_stride, "q_fp4 must be contiguous"); + RuntimeCheck(q_sf.stride(0) == static_cast(num_heads) && q_sf.stride(1) == 1, "q_sf must be contiguous"); + + const auto params = FusedQIndexerRopeHadamardFp4QuantParams{ + .q_input = q_input.data_ptr(), + .q_fp4 = q_fp4.data_ptr(), + .q_sf = static_cast(q_sf.data_ptr()), + .weight = weight.data_ptr(), + .weights_out = static_cast(weights_out.data_ptr()), + .weight_scale = static_cast(weight_scale), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .batch_size = batch_size, + .num_heads = num_heads, + }; + const auto total_works = batch_size * num_heads; + const auto num_blocks = div_ceil(total_works, kFusedQNumWarps); + const auto k_int32 = kernel; + const auto k_int64 = kernel; + const auto k = pos_dtype.is_type() ? 
k_int32 : k_int64; + LaunchKernel(num_blocks, kFusedQBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(k, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh new file mode 100644 index 0000000000..7d5f97824b --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh @@ -0,0 +1,219 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct MegaMoEPreDispatchParams { + const bf16_t* __restrict__ x; // [num_tokens, hidden] + const int32_t* __restrict__ topk_idx; // [num_tokens, top_k] + const float* __restrict__ topk_weights; // [num_tokens, top_k] + + fp8_e4m3_t* __restrict__ buf_x; // [padded_max, hidden] + int32_t* __restrict__ buf_x_sf; // contiguous int32 [P, G/4]; see layout comment + int64_t* __restrict__ buf_topk_idx; // [padded_max, top_k] + float* __restrict__ buf_topk_weights; // [padded_max, top_k] + + uint32_t num_tokens; + uint32_t padded_max; + uint32_t hidden; + uint32_t num_groups; // hidden / group_size + uint32_t top_k; +}; + +// kGroupSize must match sglang_per_token_group_quant_fp8_ue8m0(group_size=). 
+template +__global__ __launch_bounds__(1024, 2) void // + mega_moe_pre_dispatch_kernel(const MegaMoEPreDispatchParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kVecElems = 8; // 8 bf16 = 16B load per thread + static_assert(kGroupSize % kVecElems == 0, "group_size must be a multiple of 8"); + constexpr uint32_t kThreadsPerGroup = kGroupSize / kVecElems; + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + + const uint32_t bid = blockIdx.x; + const uint32_t tid = threadIdx.x; + + PDLWaitPrimary(); + if (bid < params.num_tokens) { + // ---- Quantize path: one CTA per valid token ---- + + const uint32_t token_id = bid; + const auto token_in = params.x + static_cast(token_id) * params.hidden; + const auto token_out = params.buf_x + static_cast(token_id) * params.hidden; + + InputVec in_vec; + in_vec.load(token_in, tid); + + float local_max = 0.0f; + float vals[kVecElems]; +#pragma unroll + for (uint32_t i = 0; i < kVecElems / 2; ++i) { + const auto [v0, v1] = cast(in_vec[i]); + vals[2 * i + 0] = v0; + vals[2 * i + 1] = v1; + local_max = fmaxf(local_max, fmaxf(fabsf(v0), fabsf(v1))); + } + + // Absmax across the kThreadsPerGroup threads that cover one group. + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + const float raw_scale = absmax / math::FP8_E4M3_MAX; + const uint32_t ue8m0_exp = cast_to_ue8m0(raw_scale); + // 2^-ue8m0_exp as fp32 (equivalent to 1 / __uint_as_float(ue8m0 << 23)). + const float inv_scale = __uint_as_float((127u + 127u - ue8m0_exp) << 23); + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < kVecElems / 2; ++i) { + out_vec[i] = pack_fp8(vals[2 * i + 0] * inv_scale, vals[2 * i + 1] * inv_scale); + } + out_vec.store(token_out, tid); + + // One thread per group writes its UE8M0 byte into the contiguous + // row-major int32-packed layout: byte address = t*num_groups + g + // (see layout comment at the top of the file). 
+ const uint32_t group_id = tid / kThreadsPerGroup; + const uint32_t within_group_id = tid % kThreadsPerGroup; + if (within_group_id == 0 && group_id < params.num_groups) { + const uint32_t byte_off = token_id * params.num_groups + group_id; + reinterpret_cast(params.buf_x_sf)[byte_off] = static_cast(ue8m0_exp); + } + + // Copy this token's topk row (no alignment assumptions; top_k is small). + if (tid < params.top_k) { + const uint32_t off = token_id * params.top_k + tid; + params.buf_topk_idx[off] = params.topk_idx[off]; + params.buf_topk_weights[off] = params.topk_weights[off]; + } + } else { + // ---- Pad path: trailing blocks fill [num_tokens, padded_max) with (-1, 0) ---- + const uint32_t copy_bid = bid - params.num_tokens; + const uint32_t pad_base = params.num_tokens * params.top_k; + const uint32_t slot = pad_base + copy_bid * blockDim.x + tid; + const uint32_t total_slots = params.padded_max * params.top_k; + + if (slot < total_slots) { + params.buf_topk_idx[slot] = -1; + params.buf_topk_weights[slot] = 0.0f; + } + } + PDLTriggerSecondary(); +} + +// ---- Host wrapper +// ------------------------------------------------------------------------------------------------------------------------ + +template +struct MegaMoEPreDispatchKernel { + static_assert(kGroupSize == 32 || kGroupSize == 64 || kGroupSize == 128, "unsupported group_size"); + static constexpr auto kernel = mega_moe_pre_dispatch_kernel(kGroupSize), kUsePDL>; + + static void + run(const tvm::ffi::TensorView x, + const tvm::ffi::TensorView topk_idx, + const tvm::ffi::TensorView topk_weights, + const tvm::ffi::TensorView buf_x, + const tvm::ffi::TensorView buf_x_sf, + const tvm::ffi::TensorView buf_topk_idx, + const tvm::ffi::TensorView buf_topk_weights) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto P = SymbolicSize{"padded_max"}; + auto H = SymbolicSize{"hidden"}; + auto K = SymbolicSize{"top_k"}; + auto G4 = 
SymbolicSize{"num_groups_div_4"}; + device.set_options(); + + TensorMatcher({M, H}) // input x + .with_dtype() + .with_device(device) + .verify(x); + TensorMatcher({M, K}) // topk_idx + .with_dtype() + .with_device(device) + .verify(topk_idx); + TensorMatcher({M, K}) // topk_weights + .with_dtype() + .with_device(device) + .verify(topk_weights); + TensorMatcher({P, H}) // buf.x + .with_dtype() + .with_device(device) + .verify(buf_x); + // buf.x_sf is the contiguous row-major int32 view from DeepGEMM's mega + // symm buffer (DeepGEMM/csrc/apis/mega.hpp): shape (P, G/4), strides + // (G/4, 1). No explicit strides required -> TensorMatcher enforces + // is_contiguous(). + TensorMatcher({P, G4}) // buf_x_sf + .with_dtype() + .with_device(device) + .verify(buf_x_sf); + TensorMatcher({P, K}) // buf.topk_idx + .with_dtype() + .with_device(device) + .verify(buf_topk_idx); + TensorMatcher({P, K}) // buf.topk_weights + .with_dtype() + .with_device(device) + .verify(buf_topk_weights); + + const auto num_tokens = static_cast(M.unwrap()); + const auto padded_max = static_cast(P.unwrap()); + const auto hidden = static_cast(H.unwrap()); + const auto top_k = static_cast(K.unwrap()); + const auto num_groups_div_4 = static_cast(G4.unwrap()); + + RuntimeCheck(num_tokens <= padded_max, "num_tokens must not exceed padded_max"); + RuntimeCheck(hidden % kGroupSize == 0, "hidden must be a multiple of group_size"); + const auto num_groups = hidden / static_cast(kGroupSize); + RuntimeCheck(num_groups == num_groups_div_4 * 4u, "num_groups must be a multiple of 4"); + RuntimeCheck(hidden % 8u == 0, "hidden must be a multiple of 8 (16B bf16 loads)"); + const auto num_threads = hidden / 8u; + RuntimeCheck(num_threads <= 1024, "hidden too large for single-block-per-row quant"); + RuntimeCheck(num_threads >= top_k, "top_k must fit into one quant CTA"); + + const auto pad_slots = (padded_max - num_tokens) * top_k; + const uint32_t num_pad_blocks = pad_slots == 0 ? 
0u : ((pad_slots + num_threads - 1u) / num_threads); + const auto num_total_blocks = num_tokens + num_pad_blocks; + + const auto params = MegaMoEPreDispatchParams{ + .x = static_cast(x.data_ptr()), + .topk_idx = static_cast(topk_idx.data_ptr()), + .topk_weights = static_cast(topk_weights.data_ptr()), + .buf_x = static_cast(buf_x.data_ptr()), + .buf_x_sf = static_cast(buf_x_sf.data_ptr()), + .buf_topk_idx = static_cast(buf_topk_idx.data_ptr()), + .buf_topk_weights = static_cast(buf_topk_weights.data_ptr()), + .num_tokens = num_tokens, + .padded_max = padded_max, + .hidden = hidden, + .num_groups = num_groups, + .top_k = top_k, + }; + + if (num_total_blocks == 0) return; + LaunchKernel(num_total_blocks, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/paged_mqa_metadata.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/paged_mqa_metadata.cuh new file mode 100644 index 0000000000..38be975558 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/paged_mqa_metadata.cuh @@ -0,0 +1,119 @@ +#include +#include + +#include +#include + +#include +#include + +namespace { + +constexpr uint32_t kBlockSize = 1024; +constexpr uint32_t kSplitKV = 256; // const for both SM90 and SM100 + +struct MetadataParams { + /// NOTE: batch_size > 0 + uint32_t batch_size; + uint32_t num_sm; + const uint32_t* __restrict__ context_lens; + uint32_t* __restrict__ schedule_metadata; + bool use_smem = true; +}; + +__global__ __launch_bounds__(kBlockSize, 1) // + void smxx_paged_mqa_logits_metadata(const MetadataParams params) { + using namespace device; + extern __shared__ uint32_t s_length[]; + static constexpr auto kNumWarps = kBlockSize / kWarpThreads; + static_assert(kNumWarps == kWarpThreads); + + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + __shared__ uint32_t s_warp_sum[kNumWarps]; + + 
uint32_t local_sum = 0; + for (uint32_t i = tx; i < params.batch_size; i += kBlockSize) { + const auto length = params.context_lens[i]; + local_sum += (length + kSplitKV - 1) / kSplitKV; + if (params.use_smem) s_length[i] = length; + } + + s_warp_sum[warp_id] = warp::reduce_sum(local_sum); + __syncthreads(); + + const auto global_sum = warp::reduce_sum(s_warp_sum[lane_id]); + if (lane_id != 0) return; + + const auto length_ptr = params.use_smem ? s_length : params.context_lens; + + const auto avg = global_sum / params.num_sm; + const auto ret = global_sum % params.num_sm; + uint32_t q = 0; + uint32_t num_work = (length_ptr[0] + kSplitKV - 1) / kSplitKV; + uint32_t sum_work = num_work; + for (auto i = warp_id; i <= params.num_sm; i += kNumWarps) { + const auto target = i * avg + min(i, ret); + while (sum_work <= target) { + if (++q >= params.batch_size) break; + num_work = (length_ptr[q] + kSplitKV - 1) / kSplitKV; + sum_work += num_work; + } + if (q >= params.batch_size) { + params.schedule_metadata[2 * i + 0] = params.batch_size; + params.schedule_metadata[2 * i + 1] = 0; + } else { + // sum > target && (sum - length) <= target + params.schedule_metadata[2 * i + 0] = q; + params.schedule_metadata[2 * i + 1] = target - (sum_work - num_work); + } + } +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +struct IndexerMetadataKernel { + static constexpr auto kMaxBatchSizeInSmem = 16384 * 2; // 128 KB smeme + static void run(tvm::ffi::TensorView seq_lens, tvm::ffi::TensorView metadata) { + using namespace host; + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_sm"}; + auto device = SymbolicDevice{}; + device.set_options(); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) 
+ .verify(seq_lens); + TensorMatcher({M, 2}) // + .with_dtype() + .with_device(device) + .verify(metadata); + const auto batch_size = static_cast(N.unwrap()); + const auto num_sm = static_cast(M.unwrap()) - 1; + RuntimeCheck(num_sm <= 1024); + const auto use_smem = batch_size <= kMaxBatchSizeInSmem; + const auto params = MetadataParams{ + .batch_size = batch_size, + .num_sm = num_sm, + .context_lens = static_cast(seq_lens.data_ptr()), + .schedule_metadata = static_cast(metadata.data_ptr()), + .use_smem = use_smem, + }; + constexpr auto kernel = smxx_paged_mqa_logits_metadata; + setup_kernel_smem_once(); + const auto smem = use_smem ? (batch_size + 1) * sizeof(uint32_t) : 0; + LaunchKernel(1, kBlockSize, device.unwrap(), smem)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/rope.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/rope.cuh new file mode 100644 index 0000000000..2239d3972d --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/rope.cuh @@ -0,0 +1,169 @@ +#include +#include + +#include +#include +#include + +#include + +#include + +namespace { + +using DType = bf16_t; +constexpr int64_t kRopeDim = 64; +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct FusedQKRopeParams { + void* __restrict__ q; + void* __restrict__ k; + const float* __restrict__ freqs_cis; + const void* __restrict__ positions; + int64_t q_stride_batch; + int64_t k_stride_batch; + int64_t q_stride_head; + int64_t k_stride_head; + uint32_t num_q_heads; + uint32_t num_k_heads; + uint32_t batch_size; +}; + +template +__global__ __launch_bounds__(kBlockSize, 16) // + void deepseek_rope_kernel(const __grid_constant__ FusedQKRopeParams param) { + using namespace device; + using DType2 = packed_t; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto global_warp_id = blockIdx.x * kNumWarps + 
warp_id; + + const auto& [ + q, k, freqs_cis, positions, // + q_stride_batch, k_stride_batch, q_stride_head, k_stride_head, // + num_q_heads, num_k_heads, batch_size + ] = param; + + const auto num_total_heads = num_q_heads + num_k_heads; + const auto head_id = global_warp_id % num_total_heads; + const auto batch_id = global_warp_id / num_total_heads; + if (batch_id >= batch_size) return; + + const auto position = static_cast(positions)[batch_id]; + const auto is_q = head_id < num_q_heads; + const auto local_head = is_q ? head_id : (head_id - num_q_heads); + const auto stride_batch = is_q ? q_stride_batch : k_stride_batch; + const auto stride_head = is_q ? q_stride_head : k_stride_head; + const auto base_ptr = is_q ? q : k; + const auto input = static_cast(pointer::offset(base_ptr, batch_id * stride_batch, local_head * stride_head)); + + const auto freq_ptr = reinterpret_cast(freqs_cis + position * kRopeDim); + const auto [f_real, f_imag] = freq_ptr[lane_id]; + PDLWaitPrimary(); + + const auto data = input[lane_id]; + const auto [x_real, x_imag] = cast(data); + fp32x2_t output; + if constexpr (kInverse) { + // (a + bi) * (c - di) = (ac + bd) + (bc - ad)i + output = { + x_real * f_real + x_imag * f_imag, + x_imag * f_real - x_real * f_imag, + }; + } else { + // (a + bi) * (c + di) = (ac - bd) + (ad + bc)i + output = { + x_real * f_real - x_imag * f_imag, + x_real * f_imag + x_imag * f_real, + }; + } + input[lane_id] = cast(output); + + PDLTriggerSecondary(); +} + +template +struct FusedQKRopeKernel { + // 4 kernel variants: {forward, inverse} x {int32, int64} + static constexpr auto kernel_fwd_i32 = deepseek_rope_kernel; + static constexpr auto kernel_fwd_i64 = deepseek_rope_kernel; + static constexpr auto kernel_inv_i32 = deepseek_rope_kernel; + static constexpr auto kernel_inv_i64 = deepseek_rope_kernel; + + static void forward( + const tvm::ffi::TensorView q, + const tvm::ffi::Optional k, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView 
positions, + bool inverse) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto Q = SymbolicSize{"num_q_heads"}; + auto K = SymbolicSize{"num_k_heads"}; + constexpr auto D = kRopeDim; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, Q, D}) // + .with_strides({-1, -1, 1}) + .with_dtype() + .with_device(device_) + .verify(q); + if (k.has_value()) { + TensorMatcher({B, K, D}) // + .with_strides({-1, -1, 1}) + .with_dtype() + .with_device(device_) + .verify(k.value()); + } else { + K.set_value(0); + } + TensorMatcher({-1, D}) // + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(pos_dtype) + .with_device(device_) + .verify(positions); + const bool pos_i32 = pos_dtype.is_type(); + + const auto batch_size = static_cast(B.unwrap()); + if (batch_size == 0) return; + + const auto num_q_heads = static_cast(Q.unwrap()); + const auto num_k_heads = static_cast(K.unwrap()); + const auto num_total_heads = num_q_heads + num_k_heads; + const auto total_warps = batch_size * num_total_heads; + const auto num_blocks = div_ceil(total_warps, kNumWarps); + + const auto elem_size = static_cast(sizeof(DType)); + const auto params = FusedQKRopeParams{ + .q = q.data_ptr(), + .k = k ? k.value().data_ptr() : nullptr, + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .q_stride_batch = q.stride(0) * elem_size, + .k_stride_batch = k ? k.value().stride(0) * elem_size : 0, + .q_stride_head = q.stride(1) * elem_size, + .k_stride_head = k ? k.value().stride(1) * elem_size : 0, + .num_q_heads = num_q_heads, + .num_k_heads = num_k_heads, + .batch_size = batch_size, + }; + + // dispatch: {inverse} x {pos_i32} + using KernelType = decltype(kernel_fwd_i32); + const KernelType kernel = + inverse ? (pos_i32 ? kernel_inv_i32 : kernel_inv_i64) : (pos_i32 ? 
kernel_fwd_i32 : kernel_fwd_i64); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh new file mode 100644 index 0000000000..be0e759445 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh @@ -0,0 +1,540 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct SiluMulQuantVarlenParams { + const bf16_t* __restrict__ input; + fp8_e4m3_t* __restrict__ output; + float* __restrict__ output_scale; + const int32_t* __restrict__ masked_m; + float swiglu_limit; // only read when kApplySwigluLimit=true + int64_t hidden_dim; + uint32_t num_tokens; + uint32_t num_experts; +}; + +constexpr uint32_t kMaxExperts = 256; + +struct alignas(16) CTAWork { + uint32_t expert_id; + uint32_t expert_token_id; + bool valid; +}; + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(device::kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset); + if (lane_id >= offset) val += n; + } + return val; +} + +template +SGL_DEVICE fp32x2_t silu_and_mul(DType2 gate, DType2 up, float limit) { + using namespace device; + // refer to as implementation. 
TL;DR: must clamp in bf16 + // https://github.com/deepseek-ai/DeepGEMM/blob/7f2a703ed51ac1f7af07f5e1453b2d3267d37d50/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh#L984-L997 + if constexpr (kApplySwigluLimit) { + static_assert(std::is_same_v); + gate = __hmin2(gate, {limit, limit}); + up = __hmax2(up, {-limit, -limit}); + up = __hmin2(up, {limit, limit}); + } + const auto [g0, g1] = cast(gate); + const auto [u0, u1] = cast(up); + const auto silu0 = g0 / (1.0f + __expf(-g0)); + const auto silu1 = g1 / (1.0f + __expf(-g1)); + const float val0 = silu0 * u0; + const float val1 = silu1 * u1; + if constexpr (kPrecise) { // I don't know if we should enable this? + return {val0, val1}; + } else { + return cast(cast(fp32x2_t{val0, val1})); + } +} + +[[maybe_unused]] +SGL_DEVICE CTAWork get_work(const SiluMulQuantVarlenParams& params) { + // Preconditions: + // 1. blockDim.x >= params.num_experts + // 2. params.num_experts <= kMaxExperts + using namespace device; + static_assert(kWarpThreads == 32); + + static __shared__ uint32_t s_warp_sum[32]; + static __shared__ CTAWork result; + + result.valid = false; + + const uint32_t tx = threadIdx.x; + const uint32_t lane_id = tx % kWarpThreads; + const uint32_t warp_id = tx / kWarpThreads; + + const uint32_t val = tx < params.num_experts ? params.masked_m[tx] : 0u; + + // Per-warp inclusive scan of masked_m. + const uint32_t warp_inclusive = warp_inclusive_sum(lane_id, val); + const uint32_t warp_exclusive = warp_inclusive - val; + + // Write each warp total. + if (lane_id == kWarpThreads - 1) s_warp_sum[warp_id] = warp_inclusive; + __syncthreads(); + const auto tmp_val = lane_id < warp_id ? 
s_warp_sum[lane_id] : 0u; + const auto prefix_exclusive = warp::reduce_sum(tmp_val) + warp_exclusive; + const auto bx = blockIdx.x; + if (prefix_exclusive <= bx && bx < prefix_exclusive + val) { + result = {tx, bx - prefix_exclusive, true}; + } + __syncthreads(); + return result; +} + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_quant_varlen_kernel(const SiluMulQuantVarlenParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kGroupSize = 128u; + constexpr uint32_t kWorkThreads = 16u; + // each thread will handle 8 elements + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + static_assert(8 * kWorkThreads == 128, "Invalid tiling"); + static_assert(!(kTransposed && !kScaleUE8M0), "transposed layout only supports ue8m0"); + + const auto [expert_id, token_id, valid] = get_work(params); + + if (!valid) return; + + const auto work_id = threadIdx.x / kWorkThreads; + + const auto offset = expert_id * params.num_tokens + token_id; + const auto input = params.input + offset * params.hidden_dim * 2; + const auto output = params.output + offset * params.hidden_dim; + [[maybe_unused]] + const auto output_scale = [&] { + const auto num_groups = params.hidden_dim / kGroupSize; + if constexpr (kTransposed) { + const auto base = reinterpret_cast(params.output_scale); + // Physical layout is [E, G//4, N] int32. Each int32 packs 4 consecutive + // group scales for the same token, so the byte address is: + // expert_offset + (group/4)*N*4 + token*4 + group%4 + return base + expert_id * num_groups * params.num_tokens + (work_id / 4u) * (params.num_tokens * 4u) + + token_id * 4u + (work_id % 4u); + } else { + return params.output_scale + offset * num_groups + work_id; + } + }(); + + PDLWaitPrimary(); + + InputVec gate_vec, up_vec; + if constexpr (kSwizzle) { + // gran=8 interleaved: every 16-element chunk on the N axis is + // [gate[0..7], up[0..7]]. 
Each thread handles 8 consecutive output + // elements, so its gate chunk lives at vec index 2*threadIdx.x and its + // up chunk at 2*threadIdx.x+1. + gate_vec.load(input, threadIdx.x * 2); + up_vec.load(input, threadIdx.x * 2 + 1); + } else { + gate_vec.load(input, threadIdx.x); + up_vec.load(input, threadIdx.x + blockDim.x); + } + + float local_max = 0.0f; + float results[8]; + +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const auto [x, y] = silu_and_mul(gate_vec[i], up_vec[i], params.swiglu_limit); + results[2 * i + 0] = x; + results[2 * i + 1] = y; + local_max = fmaxf(local_max, fmaxf(fabsf(x), fabsf(y))); + } + + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + float scale; + uint32_t ue8m0_exp; + + if constexpr (kScaleUE8M0) { + const float raw_scale = absmax / math::FP8_E4M3_MAX; + ue8m0_exp = cast_to_ue8m0(raw_scale); + scale = __uint_as_float(ue8m0_exp << 23); + } else { + scale = absmax / math::FP8_E4M3_MAX; + } + const auto inv_scale = 1.0f / scale; + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const float scaled_val0 = results[2 * i + 0] * inv_scale; + const float scaled_val1 = results[2 * i + 1] * inv_scale; + out_vec[i] = pack_fp8(scaled_val0, scaled_val1); + } + + PDLTriggerSecondary(); + + out_vec.store(output, threadIdx.x); + if constexpr (kTransposed) { + *output_scale = ue8m0_exp; + } else { + *output_scale = scale; + } +} + +struct SiluAndMulClampParams { + const void* __restrict__ input; + void* __restrict__ output; + float swiglu_limit; +}; + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_clamp_kernel(const SiluAndMulClampParams __grid_constant__ params) { + using namespace device; + static_assert(sizeof(DType) == 2, "only fp16/bf16 supported"); + using DType2 = packed_t; + constexpr auto kVecSize = 16 / sizeof(DType); + static_assert(kVecSize % 2 == 0 && kVecSize > 0); + using Vec = AlignedVector; + const auto bid = 
blockIdx.x; + const auto tile = tile::Memory::cta(); + const float limit = params.swiglu_limit; + + PDLWaitPrimary(); + const auto gate = tile.load(params.input, bid * 2 + 0); + const auto up = tile.load(params.input, bid * 2 + 1); + Vec out; + +#pragma unroll + for (uint32_t i = 0; i < kVecSize / 2; ++i) { + out[i] = cast(silu_and_mul(cast(gate[i]), cast(up[i]), limit)); + } + + tile.store(params.output, out, bid); + PDLTriggerSecondary(); +} + +// ---- Host wrapper +// ------------------------------------------------------------------------------------------------------------------------ + +template +struct SiluAndMulMaskedPostQuantKernel { + static_assert(kGroupSize == 128); + static constexpr auto kernel_normal = + silu_mul_quant_varlen_kernel; + static constexpr auto kernel_transposed = + silu_mul_quant_varlen_kernel; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView output_scale, + const tvm::ffi::TensorView masked_m, + const uint32_t topk, + const bool transposed, + const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto E = SymbolicSize{"num_experts"}; + auto T = SymbolicSize{"num_tokens_padded"}; + auto D = SymbolicSize{"hidden_dim x 2"}; + auto N = SymbolicSize{"hidden_dim"}; + auto G = SymbolicSize{"num_groups"}; + device.set_options(); + + TensorMatcher({E, T, D}) // input + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({E, T, N}) // output + .with_dtype() + .with_device(device) + .verify(output); + if (!transposed) { + TensorMatcher({E, T, G}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + } else { + RuntimeCheck(kScaleUE8M0, "transposed layout only supports scale_ue8m0=true"); + auto G_ = SymbolicSize{"G // 4"}; + TensorMatcher({E, G_, T}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + G.set_value(G_.unwrap() * 4); + } + TensorMatcher({E}) // + .with_dtype() + 
.with_device(device) + .verify(masked_m); + + const auto num_experts = static_cast(E.unwrap()); + const auto num_tokens = static_cast(T.unwrap()); + const auto num_groups = static_cast(G.unwrap()); + const auto hidden_dim = N.unwrap(); + + RuntimeCheck(D.unwrap() == 2 * hidden_dim, "invalid dimension"); + RuntimeCheck(hidden_dim % kGroupSize == 0); + RuntimeCheck(num_experts <= kMaxExperts, "num_experts exceeds maximum (256)"); + RuntimeCheck(num_groups * kGroupSize == hidden_dim, "invalid num_groups"); + + const auto params = SiluMulQuantVarlenParams{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .output_scale = static_cast(output_scale.data_ptr()), + .masked_m = static_cast(masked_m.data_ptr()), + .swiglu_limit = static_cast(swiglu_limit), + .hidden_dim = hidden_dim, + .num_tokens = num_tokens, + .num_experts = num_experts, + }; + + const auto num_threads = hidden_dim / 8; + RuntimeCheck(num_threads % device::kWarpThreads == 0); + RuntimeCheck(num_threads >= num_experts); + const auto kernel = transposed ? 
kernel_transposed : kernel_normal; + LaunchKernel(num_tokens * topk, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +template +struct SiluAndMulClampKernel { + static constexpr auto kernel = silu_mul_clamp_kernel; + + static void run(const tvm::ffi::TensorView input, const tvm::ffi::TensorView output, const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"gate_up_dim"}; // 2 * out_dim + auto H = SymbolicSize{"out_dim"}; + device.set_options(); + + TensorMatcher({M, D}) // input (gate || up) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({M, H}) // output + .with_dtype() + .with_device(device) + .verify(output); + RuntimeCheck(D.unwrap() == 2 * H.unwrap(), "input last dim must be 2 * output last dim"); + + constexpr uint32_t kVecSize = 16 / sizeof(DType); + const auto out_dim = static_cast(H.unwrap()); + const auto num_tokens = static_cast(M.unwrap()); + RuntimeCheck(out_dim % kVecSize == 0, "out_dim must be divisible by vector size"); + const auto num_threads = out_dim / kVecSize; + RuntimeCheck(num_threads <= 1024, "out_dim too large for single-block-per-row launch"); + + const auto params = SiluAndMulClampParams{ + .input = input.data_ptr(), + .output = output.data_ptr(), + .swiglu_limit = static_cast(swiglu_limit), + }; + LaunchKernel(num_tokens, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +struct SiluMulQuantContigParams { + const bf16_t* __restrict__ input; + fp8_e4m3_t* __restrict__ output; + float* __restrict__ output_scale; + float swiglu_limit; // only read when kApplySwigluLimit=true + int64_t hidden_dim; + uint32_t num_tokens; + uint32_t scale_row_stride_int32; // only used when kTransposed=true +}; + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_quant_contig_kernel(const SiluMulQuantContigParams __grid_constant__ 
params) { + using namespace device; + + constexpr uint32_t kGroupSize = 128u; + constexpr uint32_t kWorkThreads = 16u; + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + static_assert(8 * kWorkThreads == 128, "Invalid tiling"); + static_assert(!(kTransposed && !kScaleUE8M0), "transposed layout only supports ue8m0"); + + const auto token_id = blockIdx.x; + const auto work_id = threadIdx.x / kWorkThreads; + + const auto input = params.input + token_id * params.hidden_dim * 2; + const auto output = params.output + token_id * params.hidden_dim; + [[maybe_unused]] + const auto output_scale = [&] { + const auto num_groups = params.hidden_dim / kGroupSize; + if constexpr (kTransposed) { + // Physical layout is (G//4_pad, M_pad) int32; each int32 packs 4 + // consecutive UE8M0 exponents for the same token. Byte address: + // (work_id / 4) * M_pad * 4 + token * 4 + (work_id % 4). + const auto base = reinterpret_cast(params.output_scale); + return base + (work_id / 4u) * (params.scale_row_stride_int32 * 4u) + token_id * 4u + (work_id % 4u); + } else { + return params.output_scale + token_id * num_groups + work_id; + } + }(); + + PDLWaitPrimary(); + + InputVec gate_vec, up_vec; + if constexpr (kSwizzle) { + gate_vec.load(input, threadIdx.x * 2); + up_vec.load(input, threadIdx.x * 2 + 1); + } else { + gate_vec.load(input, threadIdx.x); + up_vec.load(input, threadIdx.x + blockDim.x); + } + + float local_max = 0.0f; + float results[8]; + +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const auto [x, y] = silu_and_mul(gate_vec[i], up_vec[i], params.swiglu_limit); + results[2 * i + 0] = x; + results[2 * i + 1] = y; + local_max = fmaxf(local_max, fmaxf(fabsf(x), fabsf(y))); + } + + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + float scale; + uint32_t ue8m0_exp; + + if constexpr (kScaleUE8M0) { + const float raw_scale = absmax / math::FP8_E4M3_MAX; + ue8m0_exp = cast_to_ue8m0(raw_scale); + scale = 
__uint_as_float(ue8m0_exp << 23); + } else { + scale = absmax / math::FP8_E4M3_MAX; + } + const auto inv_scale = 1.0f / scale; + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const float scaled_val0 = results[2 * i + 0] * inv_scale; + const float scaled_val1 = results[2 * i + 1] * inv_scale; + out_vec[i] = pack_fp8(scaled_val0, scaled_val1); + } + + PDLTriggerSecondary(); + + out_vec.store(output, threadIdx.x); + if constexpr (kTransposed) { + *output_scale = ue8m0_exp; + } else { + *output_scale = scale; + } +} + +template +struct SiluAndMulContigPostQuantKernel { + static_assert(kGroupSize == 128); + static constexpr auto kernel_normal = + silu_mul_quant_contig_kernel; + static constexpr auto kernel_transposed = + silu_mul_quant_contig_kernel; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView output_scale, + const bool transposed, + const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_dim x 2"}; + auto N = SymbolicSize{"hidden_dim"}; + auto G = SymbolicSize{"num_groups"}; + device.set_options(); + + TensorMatcher({M, D}) // input (gate/up, natural or gran=8 interleaved on last dim) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({M, N}) // fp8 output + .with_dtype() + .with_device(device) + .verify(output); + + const auto hidden_dim = N.unwrap(); + RuntimeCheck(D.unwrap() == 2 * hidden_dim, "invalid dimension"); + RuntimeCheck(hidden_dim % kGroupSize == 0); + const auto num_groups = static_cast(hidden_dim / kGroupSize); + + uint32_t scale_row_stride_int32 = 0; + if (!transposed) { + G.set_value(num_groups); + TensorMatcher({M, G}) // (M, G) fp32 natural row-major + .with_dtype() + .with_device(device) + .verify(output_scale); + } else { + RuntimeCheck(kScaleUE8M0, "transposed layout only supports scale_ue8m0=true"); + 
RuntimeCheck(num_groups % 4 == 0, "transposed layout requires num_groups % 4 == 0"); + auto G_ = SymbolicSize{"G // 4"}; + G_.set_value(num_groups / 4); + auto M_pad = SymbolicSize{"M padded"}; + TensorMatcher({M, G_}) // `.transpose(-1,-2)[:M,:]` view of (G//4_pad, M_pad) int32 + .with_strides({int64_t{1}, M_pad}) // col-major transposed + .with_dtype() + .with_device(device) + .verify(output_scale); + scale_row_stride_int32 = static_cast(M_pad.unwrap()); + } + + const auto num_tokens = static_cast(M.unwrap()); + + const auto params = SiluMulQuantContigParams{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .output_scale = static_cast(output_scale.data_ptr()), + .swiglu_limit = static_cast(swiglu_limit), + .hidden_dim = hidden_dim, + .num_tokens = num_tokens, + .scale_row_stride_int32 = scale_row_stride_int32, + }; + + const auto num_threads = hidden_dim / 8; + RuntimeCheck(num_threads % device::kWarpThreads == 0); + const auto kernel = transposed ? 
kernel_transposed : kernel_normal; + LaunchKernel(num_tokens, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/store.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/store.cuh new file mode 100644 index 0000000000..49f6f55963 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/store.cuh @@ -0,0 +1,205 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::inv_scale_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct FusedStoreCacheParam { + const void* __restrict__ input; + void* __restrict__ cache; + const void* __restrict__ indices; + uint32_t num_tokens; +}; + +template +__global__ void fused_store_flashmla_cache(const __grid_constant__ FusedStoreCacheParam param) { + using namespace device; + + /// NOTE: 584 = 576 + 8 + constexpr int64_t kPageBytes = host::div_ceil(584 << kPageBits, 576) * 576; + + // each warp handles 64 elements, 8 warps, each block handles 1 row + const auto& [input, cache, indices, num_tokens] = param; + const uint32_t bid = blockIdx.x; + const uint32_t tid = threadIdx.x; + const uint32_t wid = tid / 32; + + PDLWaitPrimary(); + + // prefetch the index + const auto index = static_cast(indices)[bid]; + // always load the value from input (don't store if invalid) + using Float2 = packed_t; + const auto elems = static_cast(input)[tid + bid * 256]; + if (wid != 7) { + const auto [x, y] = cast(elems); + const auto abs_max = warp::reduce_max(fmaxf(fabs(x), fabs(y))); + const auto scale_raw = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto scale_ue8m0 = cast_to_ue8m0(scale_raw); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + const auto result = pack_fp8(x * inv_scale, y * inv_scale); + const int32_t page = index >> kPageBits; + 
const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 576); + const auto scale_ptr = pointer::offset(page_ptr, 576 << kPageBits, offset * 8); + static_cast(value_ptr)[tid] = result; + static_cast(scale_ptr)[wid] = scale_ue8m0; + } else { + const auto result = cast(elems); + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 576, 448); + static_cast(value_ptr)[tid - 7 * 32] = result; + } + + PDLTriggerSecondary(); +} + +template +__global__ void fused_store_indexer_cache(const __grid_constant__ FusedStoreCacheParam param) { + using namespace device; + + /// NOTE: 132 = 128 + 4 + constexpr int64_t kPageBytes = 132 << kPageBits; + + // each warp handles 128 elements, 1 warp, each block handles multiple rows + const auto& [input, cache, indices, num_tokens] = param; + const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto global_wid = global_tid / 32; + const auto lane_id = threadIdx.x % 32; + + if (global_wid >= num_tokens) return; + + PDLWaitPrimary(); + + // prefetch the index + const auto index = static_cast(indices)[global_wid]; + // always load the value from input (don't store if invalid) + using Float2 = packed_t; + using InStorage = AlignedVector; + using OutStorage = AlignedVector; + const auto elems = static_cast(input)[global_tid]; + const auto [x0, x1] = cast(elems[0]); + const auto [y0, y1] = cast(elems[1]); + const auto local_max = fmaxf(fmaxf(fabs(x0), fabs(x1)), fmaxf(fabs(y0), fabs(y1))); + const auto abs_max = warp::reduce_max(local_max); + // use normal fp32 scale + const auto scale = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto inv_scale = 1.0f / scale; + const int32_t page = index >> kPageBits; + const int32_t offset = 
index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 128); + const auto scale_ptr = pointer::offset(page_ptr, 128 << kPageBits, offset * 4); + OutStorage result; + result[0] = pack_fp8(x0 * inv_scale, x1 * inv_scale); + result[1] = pack_fp8(y0 * inv_scale, y1 * inv_scale); + static_cast(value_ptr)[lane_id] = result; + static_cast(scale_ptr)[0] = scale; + + PDLTriggerSecondary(); +} + +template +struct FusedStoreCacheFlashMLAKernel { + static constexpr int32_t kLogSize = std::countr_zero(kPageSize); + static constexpr int64_t kPageBytes = host::div_ceil(584 * kPageSize, 576) * 576; + static constexpr auto kernel = fused_store_flashmla_cache; + + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogSize == kPageSize); + + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({N, 512}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(cache); + TensorMatcher({N}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + const auto num_tokens = static_cast(N.unwrap()); + const auto params = FusedStoreCacheParam{ + .input = input.data_ptr(), + .cache = cache.data_ptr(), + .indices = indices.data_ptr(), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 256; + const auto num_blocks = num_tokens; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +template +struct FusedStoreCacheIndexerKernel { + static constexpr int32_t kLogSize = std::countr_zero(kPageSize); + static constexpr int64_t kPageBytes = 132 * kPageSize; + 
static constexpr auto kernel = fused_store_indexer_cache; + + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogSize == kPageSize); + + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({N, 128}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(cache); + TensorMatcher({N}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + const auto num_tokens = static_cast(N.unwrap()); + const auto params = FusedStoreCacheParam{ + .input = input.data_ptr(), + .cache = cache.data_ptr(), + .indices = indices.data_ptr(), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 128; + const auto num_blocks = div_ceil(num_tokens * 32, kBlockSize); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v1.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v1.cuh new file mode 100644 index 0000000000..b1ccd24b20 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v1.cuh @@ -0,0 +1,340 @@ +#include +#include + +#include + +#include +#include + +#include +#include + +namespace { + +#ifndef SGL_TOPK +#define SGL_TOPK 512 +#endif + +constexpr uint32_t kTopK = SGL_TOPK; +constexpr uint32_t kTopKBlockSize = SGL_TOPK; +constexpr uint32_t kSMEM = 16 * 1024 * sizeof(uint32_t); // 64KB (bytes) + +struct TopKParams { + const float* __restrict__ scores; + const int32_t* __restrict__ seq_lens; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + int32_t* __restrict__ raw_indices; // optional: output raw abs 
position indices before page transform + const int64_t score_stride; + const int64_t page_table_stride; + uint32_t page_bits; +}; + +SGL_DEVICE uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +SGL_DEVICE uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) { + const uint32_t mask = (1u << page_bits) - 1u; + return (page_table[i >> page_bits] << page_bits) | (i & mask); +} + +[[maybe_unused]] +SGL_DEVICE void naive_transform( + const float* __restrict__, // unused + const int32_t* __restrict__ page_table, + int32_t* __restrict__ indices, + int32_t* __restrict__ raw_indices, // optional: output raw abs position indices + const uint32_t length, + const uint32_t page_bits) { + static_assert(kTopK <= kTopKBlockSize); + if (const auto tx = threadIdx.x; tx < length) { + indices[tx] = page_to_indices(page_table, tx, page_bits); + if (raw_indices != nullptr) { + raw_indices[tx] = tx; + } + } else if (kTopK == kTopKBlockSize || tx < kTopK) { + indices[tx] = -1; // fill invalid indices to -1 + if (raw_indices != nullptr) { + raw_indices[tx] = -1; + } + } +} + +[[maybe_unused]] +SGL_DEVICE void radix_topk(const float* __restrict__ input, int32_t* __restrict__ output, const uint32_t length) { + constexpr uint32_t RADIX = 256; + constexpr uint32_t BLOCK_SIZE = kTopKBlockSize; + constexpr uint32_t SMEM_INPUT_SIZE = kSMEM / (2 * sizeof(int32_t)); + + alignas(128) __shared__ uint32_t _s_histogram_buf[2][RADIX + 32]; + alignas(128) __shared__ uint32_t s_counter; + alignas(128) __shared__ uint32_t s_threshold_bin_id; + alignas(128) __shared__ uint32_t s_num_input[2]; + alignas(128) __shared__ int32_t 
s_last_remain; + + extern __shared__ uint32_t s_input_idx[][kSMEM / (2 * sizeof(int32_t))]; + + const uint32_t tx = threadIdx.x; + uint32_t remain_topk = kTopK; + auto& s_histogram = _s_histogram_buf[0]; + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int32_t i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (tx < RADIX) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = _s_histogram_buf[k][tx]; + if (tx + j < RADIX) { + value += _s_histogram_buf[k][tx + j]; + } + _s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + if (remain_topk == 0) { + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const uint32_t bin = convert_to_uint8(input[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw_input = input[idx]; + const uint32_t bin = convert_to_uint8(raw_input); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (pos < SMEM_INPUT_SIZE) { + [[likely]] s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = 
(bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto raw_num_input = s_num_input[r_idx]; + const auto num_input = raw_num_input < SMEM_INPUT_SIZE ? raw_num_input : SMEM_INPUT_SIZE; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = remain_topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + + if (remain_topk == 0) { + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + output[kTopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (pos < SMEM_INPUT_SIZE) { + /// NOTE: (dark) fuse the histogram computation here + [[likely]] s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = 
(bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +template +__global__ void topk_transform_kernel(const __grid_constant__ TopKParams params) { + const auto &[ + scores, seq_lens, page_table, page_indices, raw_indices, // pointers + score_stride, page_table_stride, page_bits // sizes + ] = params; + const uint32_t work_id = blockIdx.x; + + /// NOTE: dangerous prefetch seq_len before PDL wait + const uint32_t seq_len = seq_lens[work_id]; + const auto score_ptr = scores + work_id * score_stride; + const auto page_ptr = page_table + work_id * page_table_stride; + const auto indices_ptr = page_indices + work_id * kTopK; + const auto raw_indices_ptr = raw_indices != nullptr ? raw_indices + work_id * kTopK : nullptr; + + device::PDLWaitPrimary(); + + if (seq_len <= kTopK) { + naive_transform(score_ptr, page_ptr, indices_ptr, raw_indices_ptr, seq_len, page_bits); + } else { + __shared__ int32_t s_topk_indices[kTopK]; + radix_topk(score_ptr, s_topk_indices, seq_len); + static_assert(kTopK <= kTopKBlockSize); + const auto tx = threadIdx.x; + if (kTopK == kTopKBlockSize || tx < kTopK) { + indices_ptr[tx] = page_to_indices(page_ptr, s_topk_indices[tx], page_bits); + if (raw_indices_ptr != nullptr) { + raw_indices_ptr[tx] = s_topk_indices[tx]; + } + } + } + + device::PDLTriggerSecondary(); +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +template +struct TopKKernel { + static constexpr auto kernel = topk_transform_kernel; + + static void transform( + const tvm::ffi::TensorView scores, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView page_table, + const tvm::ffi::TensorView page_indices, + const uint32_t page_size, + 
const tvm::ffi::Optional raw_indices) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto S = SymbolicSize{"score_stride"}; + auto P = SymbolicSize{"page_table_stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({B, -1}) // strided scores + .with_strides({S, 1}) + .with_dtype() + .with_device(device) + .verify(scores); + TensorMatcher({B}) // seq_lens, must be contiguous + .with_dtype() + .with_device(device) + .verify(seq_lens); + TensorMatcher({B, -1}) // strided page table + .with_strides({P, 1}) + .with_dtype() + .with_device(device) + .verify(page_table); + TensorMatcher({B, kTopK}) // output, must be contiguous + .with_dtype() + .with_device(device) + .verify(page_indices); + + int32_t* raw_indices_ptr = nullptr; + if (raw_indices.has_value()) { + TensorMatcher({B, kTopK}) // optional raw indices output, must be contiguous + .with_dtype() + .with_device(device) + .verify(raw_indices.value()); + raw_indices_ptr = static_cast(raw_indices.value().data_ptr()); + } + + RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2"); + const auto page_bits = static_cast(std::countr_zero(page_size)); + const auto batch_size = static_cast(B.unwrap()); + const auto params = TopKParams{ + .scores = static_cast(scores.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .page_table = static_cast(page_table.data_ptr()), + .page_indices = static_cast(page_indices.data_ptr()), + .raw_indices = raw_indices_ptr, + .score_stride = S.unwrap(), + .page_table_stride = P.unwrap(), + .page_bits = page_bits, + }; + constexpr auto kSMEM_ = kSMEM + sizeof(int32_t); // align up a little + setup_kernel_smem_once(); + LaunchKernel(batch_size, kTopKBlockSize, device.unwrap(), kSMEM_).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v2.cuh new file mode 100644 index 
0000000000..8c4a526575 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v2.cuh @@ -0,0 +1,493 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace { + +#ifndef SGL_TOPK +#define SGL_TOPK 512 +#endif + +inline constexpr uint32_t K = SGL_TOPK; + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +namespace impl = device::top512; +using Large = impl::ClusterTopK; +using Medium = impl::StreamingTopK; +using Small = impl::RegisterTopK; + +using Metadata = Large::Metadata; +constexpr uint32_t kBlockSize = impl::kBlockSize; +constexpr uint32_t kNumClusters = 15; // based on hardware limits +constexpr uint32_t kClusterSize = Large::kClusterSize; +constexpr uint32_t kMax2PassLength = Small::kMax2PassLength; +constexpr uint32_t kMaxSupportedLength = Large::kMaxLength; + +/// Common metadata lives at metadata[0] (first row of the [batch_size+1, 4] tensor). +/// Per-item metadata starts at metadata[1..batch_size]. The plan kernel writes both. 
+struct alignas(16) GlobalMetadata { + uint32_t cluster_threshold; // decided per-batch in plan kernel + uint32_t num_cluster_items; // N = number of items routed to the cluster path + uint32_t reserved[2]; +}; +static_assert(sizeof(GlobalMetadata) == sizeof(Metadata), "layout: row 0 must occupy one Metadata-sized slot"); + +// optimize occupancy for prefill +#define SMALL_TOPK_KERNEL __global__ __launch_bounds__(kBlockSize, 2) +// cluster at y dim +#define LARGE_CLUSTER __cluster_dims__(1, kClusterSize, 1) +// stage-1 is persistent cluster, and shared memory usage is huge (can not 2) +#define LARGE_TOPK_STAGE_1 __global__ __launch_bounds__(kBlockSize, 1) LARGE_CLUSTER +// stage-2 is non-persistent non-cluster, with less shared memory and higher occupancy +#define LARGE_TOPK_STAGE_2 __global__ __launch_bounds__(kBlockSize, 2) +// fused into 1 stage when batch-size <= kNumPersistentClusters +#define FUSED_COMBINE_KERNEL __global__ __launch_bounds__(kBlockSize, 1) LARGE_CLUSTER +// plan runs once as a single block before the combine kernels +#define PLAN_KERNEL __global__ __launch_bounds__(kBlockSize, 1) + +struct TopKParams { + const uint32_t* __restrict__ seq_lens; + const float* __restrict__ scores; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + int64_t score_stride; + int64_t page_table_stride; + uint8_t* __restrict__ workspace; // [batch, kWorkspaceBytes] -- internally allocated + /// Pointer to the full metadata tensor: metadata[0] is GlobalMetadata, metadata[1..] + /// are per-item entries (at most kNumClusters * rounds of them). 
+ const Metadata* __restrict__ metadata = nullptr; + int64_t workspace_stride; // bytes per batch + uint32_t batch_size; + uint32_t page_bits; + + SGL_DEVICE const float* get_scores(const uint32_t batch_id) const { + return scores + batch_id * score_stride; + } + SGL_DEVICE impl::TransformParams get_transform(const uint32_t batch_id, int32_t* indices) const { + return { + .page_table = page_table + batch_id * page_table_stride, + .indices_in = indices, + .indices_out = page_indices + batch_id * K, + .page_bits = page_bits, + }; + } + SGL_DEVICE const GlobalMetadata& get_global_metadata() const { + return *reinterpret_cast(metadata); + } + SGL_DEVICE const Metadata& get_item_metadata(uint32_t work_id) const { + return metadata[1 + work_id]; // +1 to skip the GlobalMetadata row + } +}; + +SGL_DEVICE uint2 partition_work(uint32_t length, uint32_t rank) { + constexpr uint32_t kTMAAlign = 4; + const auto total_units = (length + kTMAAlign - 1) / kTMAAlign; + const auto base = total_units / kClusterSize; + const auto extra = total_units % kClusterSize; + const auto local_units = base + (rank < extra ? 1u : 0u); + const auto offset_units = rank * base + min(rank, extra); + const auto offset = offset_units * kTMAAlign; + const auto finish = min(offset + local_units * kTMAAlign, length); + return {offset, finish - offset}; +} + +/// Persistent scheduler. A single block: +/// 1. Decides a cluster_threshold from the real seq_lens distribution (or +/// uses the caller-supplied `static_cluster_threshold` when non-zero). +/// 2. Writes that threshold + N into metadata[0] (the GlobalMetadata row). +/// 3. Compacts items with seq_len > threshold into metadata[1..N+1), laid out +/// to match the persistent consumer's round-robin stride (kNumClusters). +/// Entries for clusters that get no work are zero-filled. 
+PLAN_KERNEL void topk_plan( + const uint32_t* __restrict__ seq_lens, + Metadata* __restrict__ metadata, + const uint32_t batch_size, + const uint32_t static_cluster_threshold) { + // Candidate thresholds, strictly increasing. Picked to give the auto-heuristic + // reasonable granularity without needing a full sort. Must all be >= kMax2PassLength. + + struct Pair { + uint32_t threshold; + uint32_t max_batch_size; + }; + /// NOTE: only tuned on B200 + constexpr Pair kCandidates[] = { + {32768, 30}, + {40960, 45}, + {49152, 45}, + {65536, 60}, + {98304, 60}, + {131072, 75}, + {196608, 90}, + {262144, 105}, + }; + constexpr uint32_t kNumCandidates = std::size(kCandidates); + constexpr uint32_t kMinBatchSize = kCandidates[0].max_batch_size; + static_assert(kCandidates[0].threshold == kMax2PassLength); + static_assert(kCandidates[kNumCandidates - 1].threshold == kMaxSupportedLength); + + __shared__ uint32_t s_count; // final N after compaction + __shared__ uint32_t s_counts[kNumCandidates]; + __shared__ uint32_t s_threshold; + + const auto tx = threadIdx.x; + if (tx == 0) s_count = 0; + if (tx < kNumCandidates) s_counts[tx] = 0; + __syncthreads(); + + // --- Phase 1: decide threshold ------------------------------------------ + if (static_cluster_threshold > 0) { + if (tx == 0) s_threshold = static_cluster_threshold; + } else if (batch_size <= kMinBatchSize) { + if (tx == 0) s_threshold = kMax2PassLength; // always prefer cluster + } else { + // Count items above each candidate threshold. Monotonically non-increasing in T. + for (uint32_t i = tx; i < batch_size; i += kBlockSize) { + const uint32_t sl = seq_lens[i]; + assert(sl <= kMaxSupportedLength); + uint32_t count = 0; +#pragma unroll + for (uint32_t j = 0; j < kNumCandidates; ++j) { + count += (sl > kCandidates[j].threshold ? 
1 : 0); + } + if (count > 0) { + atomicAdd(&s_counts[count - 1], 1); + } + } + __syncthreads(); + if (tx == 0) { + uint32_t accum = 0; + uint32_t chosen = kMaxSupportedLength; +#pragma unroll + for (uint32_t i = 0; i < kNumCandidates; ++i) { + const auto j = kNumCandidates - 1 - i; + accum += s_counts[j]; + /// NOTE: `accum` increasing, while `max_batch_size` decreasing + if (accum > kCandidates[j].max_batch_size) break; + chosen = kCandidates[j].threshold; + } + s_threshold = chosen; + } + } + __syncthreads(); + // sanity check: below 2 pass threshold, must fits in small path + const auto cluster_threshold = max(s_threshold, kMax2PassLength); + + // --- Phase 2: compact items with seq_len > threshold into metadata[1..] - + // Per-item rows live at metadata[1 + pos]; metadata[0] is the GlobalMetadata row. + for (uint32_t i = tx; i < batch_size; i += kBlockSize) { + const uint32_t sl = seq_lens[i]; + if (sl > cluster_threshold) { + const auto pos = atomicAdd(&s_count, 1); + metadata[1 + pos] = {i, sl, false}; + } + } + __syncthreads(); + const auto N = s_count; + + // --- Phase 3: has_next + sentinels + GlobalMetadata --------------------- + for (uint32_t i = tx; i < N; i += kBlockSize) { + if (i + kNumClusters < N) metadata[1 + i].has_next = true; + } + // Zero-fill the first kNumClusters sentinel slots that got no valid entry. + if (tx < kNumClusters && tx >= N) metadata[1 + tx] = {0, 0, false}; + // Write global metadata (row 0). 
+ if (tx == 0) { + auto* g = reinterpret_cast(metadata); + *g = { + .cluster_threshold = cluster_threshold, + .num_cluster_items = N, + .reserved = {0, 0}, + }; + } +} + +SMALL_TOPK_KERNEL void // short context +topk_short_transform(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + const auto batch_id = blockIdx.x; + const auto seq_len = params.seq_lens[batch_id]; + const auto transform = params.get_transform(batch_id, s_topk_indices); + // trivial case + if (seq_len <= K) { + impl::trivial_transform(transform, seq_len, K); + } else { + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem, /*use_pdl=*/true); + device::PDLTriggerSecondary(); + Small::transform(transform); + } +} + +LARGE_TOPK_STAGE_1 void // long context, middle to large batch size +topk_combine_preprocess(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + uint32_t work_id = blockIdx.x; + uint32_t batch_id; + uint32_t seq_len; + bool has_next; + uint32_t length; + uint32_t offset; + const auto cluster_rank = blockIdx.y; + + const auto prefetch_metadata = [&] { + const auto metadata = params.get_item_metadata(work_id); + batch_id = metadata.batch_id; + seq_len = metadata.seq_len; + has_next = metadata.has_next; + work_id += kNumClusters; // advance to the next item for this cluster + }; + const auto launch_prologue = [&] { + const auto partition = partition_work(seq_len, cluster_rank); + offset = partition.x; + length = partition.y; + Large::stage1_prologue(params.get_scores(batch_id) + offset, length, smem); + }; + + device::PDLWaitPrimary(); + device::PDLTriggerSecondary(); + + prefetch_metadata(); + if (seq_len == 0) return; + Large::stage1_init(smem); + launch_prologue(); + while (true) { + const auto this_length = length; + const auto this_offset = offset; + const auto need_prefetch = has_next; + const auto 
transform = params.get_transform(batch_id, s_topk_indices); + const auto ws = params.workspace + batch_id * params.workspace_stride; + if (need_prefetch) prefetch_metadata(); + Large::stage1(s_topk_indices, this_length, smem, /*reuse=*/true); + if (need_prefetch) launch_prologue(); + Large::stage1_epilogue(transform, this_offset, ws, smem); + if (!need_prefetch) break; + } +} + +LARGE_TOPK_STAGE_2 void // long context, middle to large batch size +topk_combine_transform(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + const auto batch_id = blockIdx.x; + const auto seq_len = params.seq_lens[batch_id]; + const auto cluster_threshold = params.get_global_metadata().cluster_threshold; + const auto transform = params.get_transform(batch_id, s_topk_indices); + if (seq_len <= K) { + impl::trivial_transform(transform, seq_len, K); + } else if (seq_len <= kMax2PassLength) { + if (seq_len <= Small::kMax1PassLength) { + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem); + } else { + __syncwarp(); + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem); + } + Small::transform(transform); + } else if (seq_len <= cluster_threshold) { + Medium::run(params.get_scores(batch_id), seq_len, s_topk_indices, smem); + Medium::transform(transform, smem); + } else { + const auto ws = params.workspace + batch_id * params.workspace_stride; + device::PDLWaitPrimary(); + Large::transform(transform, ws, smem); + } +} + +FUSED_COMBINE_KERNEL void // long context, small batch size +topk_fused_transform(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + const auto batch_id = blockIdx.x; + const auto cluster_rank = blockIdx.y; + const auto seq_len = params.seq_lens[batch_id]; + const auto transform = params.get_transform(batch_id, s_topk_indices); + if (seq_len <= K) { + if (cluster_rank != 0) 
return; // only first rank work + impl::trivial_transform(transform, seq_len, K); + } else if (seq_len <= Small::kMax1PassLength) { + if (cluster_rank != 0) return; // only first rank work + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem, /*use_pdl=*/true); + Small::transform(transform); + } else { + const auto [offset, length] = partition_work(seq_len, cluster_rank); + const auto ws = params.workspace + batch_id * params.workspace_stride; + Large::stage1_init(smem); + device::PDLWaitPrimary(); + Large::stage1_prologue(params.get_scores(batch_id) + offset, length, smem); + Large::stage1(s_topk_indices, length, smem); + Large::stage1_epilogue(transform, offset, ws, smem); + cooperative_groups::this_cluster().sync(); + if (cluster_rank != 0) return; // only first rank do the stage-2 + Large::transform(transform, ws, smem); + } +} + +struct CombinedTopKKernel { + static constexpr auto kStage1SMEM = sizeof(Large::Smem) + 128; + static constexpr auto kStage2SMEM = std::max(sizeof(Small::Smem), sizeof(Medium::Smem)) + 128; + + static void plan( // + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView metadata, + const uint32_t static_cluster_threshold) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto Bp1 = SymbolicSize{"batch_size_plus_1"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(seq_lens); + TensorMatcher({Bp1, 4}) // + .with_dtype() + .with_device(device_) + .verify(metadata); + + const auto batch_size = static_cast(B.unwrap()); + RuntimeCheck(Bp1.unwrap() == B.unwrap() + 1); + if (batch_size <= kNumClusters) return; // metadata unused in fused path + + const auto device = device_.unwrap(); + constexpr auto kernel = topk_plan; + LaunchKernel(1, kBlockSize, device)( // + kernel, + static_cast(seq_lens.data_ptr()), + static_cast(metadata.data_ptr()), + batch_size, + static_cluster_threshold); + } + + static void 
transform( + const tvm::ffi::TensorView scores, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView page_table, + const tvm::ffi::TensorView page_indices, + const uint32_t page_size, + const tvm::ffi::TensorView workspace, + const tvm::ffi::TensorView metadata) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto Bp1 = SymbolicSize{"batch_size_plus_1"}; + auto L = SymbolicSize{"max_seq_len"}; + auto S = SymbolicSize{"score_stride"}; + auto P = SymbolicSize{"page_table_stride"}; + auto W = SymbolicSize{"workspace_stride"}; + constexpr auto D = Large::kWorkspaceInts; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, L}) // + .with_strides({S, 1}) + .with_dtype() + .with_device(device_) + .verify(scores); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(seq_lens); + TensorMatcher({B, -1}) // + .with_strides({P, 1}) + .with_dtype() + .with_device(device_) + .verify(page_table); + TensorMatcher({B, K}) // + .with_dtype() + .with_device(device_) + .verify(page_indices); + TensorMatcher({B, D}) // + .with_strides({W, 1}) + .with_dtype() + .with_device(device_) + .verify(workspace); + TensorMatcher({Bp1, 4}) // + .with_dtype() + .with_device(device_) + .verify(metadata); + + const auto page_bits = static_cast(std::countr_zero(page_size)); + const auto batch_size = static_cast(B.unwrap()); + const auto max_seq_len = static_cast(L.unwrap()); + const auto device = device_.unwrap(); + RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2"); + RuntimeCheck(S.unwrap() % 4 == 0, "score_stride must be a multiple of 4 (TMA 16-byte alignment)"); + RuntimeCheck(Bp1.unwrap() == B.unwrap() + 1, "invalid metadata shape"); + + // NOTE: this should be fixed later + // RuntimeCheck(max_seq_len <= kMaxSupportedLength, max_seq_len, " exceeds the maximum supported length"); + + const auto params = TopKParams{ + .seq_lens = static_cast(seq_lens.data_ptr()), + .scores = 
static_cast(scores.data_ptr()), + .page_table = static_cast(page_table.data_ptr()), + .page_indices = static_cast(page_indices.data_ptr()), + .score_stride = S.unwrap(), + .page_table_stride = P.unwrap(), + .workspace = static_cast(workspace.data_ptr()), + .metadata = static_cast(metadata.data_ptr()), + .workspace_stride = W.unwrap() * static_cast(sizeof(int32_t)), + .batch_size = batch_size, + .page_bits = page_bits, + }; + + if (max_seq_len <= Small::kMax1PassLength) { + // All items fit in the short path -- no stage-1 needed + constexpr auto kernel = topk_short_transform; + setup_kernel_smem_once(); + LaunchKernel(batch_size, kBlockSize, device, kStage2SMEM) // + .enable_pdl(true)(kernel, params); + } else { + // Some items may be large -- launch stage-1 + main + if (batch_size <= kNumClusters) { + // can fuse into 1 stage + constexpr auto kernel = topk_fused_transform; + constexpr auto kSMEM = std::max(kStage1SMEM, kStage2SMEM); + setup_kernel_smem_once(); + LaunchKernel({batch_size, kClusterSize}, kBlockSize, device, kSMEM) + .enable_cluster({1, kClusterSize}) + .enable_pdl(true)(kernel, params); + } else { + // stage 1 + stage 2 + constexpr auto kernel_stage_1 = topk_combine_preprocess; + setup_kernel_smem_once(); + const auto num_clusters = std::min(batch_size, kNumClusters); + LaunchKernel({num_clusters, kClusterSize}, kBlockSize, device, kStage1SMEM) + .enable_cluster({1, kClusterSize}) + .enable_pdl(true)(kernel_stage_1, params); + constexpr auto kernel_stage_2 = topk_combine_transform; + setup_kernel_smem_once(); + LaunchKernel(batch_size, kBlockSize, device, kStage2SMEM) // + .enable_pdl(true)(kernel_stage_2, params); + } + } + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/dsv4/__init__.py b/lightllm/third_party/sglang_jit/dsv4/__init__.py new file mode 100644 index 0000000000..507b225167 --- /dev/null +++ b/lightllm/third_party/sglang_jit/dsv4/__init__.py @@ -0,0 +1,8 @@ +from .elementwise import fused_k_norm_rope_flashmla, 
fused_q_norm_rope +from .topk import topk_transform_512 + +__all__ = [ + "fused_k_norm_rope_flashmla", + "fused_q_norm_rope", + "topk_transform_512", +] diff --git a/lightllm/third_party/sglang_jit/dsv4/elementwise.py b/lightllm/third_party/sglang_jit/dsv4/elementwise.py new file mode 100644 index 0000000000..07011b0479 --- /dev/null +++ b/lightllm/third_party/sglang_jit/dsv4/elementwise.py @@ -0,0 +1,215 @@ +from typing import Optional, Tuple + +import torch + +from lightllm.third_party.sglang_jit.jit_utils import ( + cache_once, + is_arch_support_pdl, + load_jit, + make_cpp_args, +) +from lightllm.third_party.sglang_jit.runtime_utils import is_hip + +from .utils import make_name + +_is_hip = is_hip() + + +@cache_once +def _jit_fused_rope_module(): + args = make_cpp_args(is_arch_support_pdl()) + return load_jit( + make_name("fused_rope"), + *args, + cuda_files=["deepseek_v4/rope.cuh"], + cuda_wrappers=[("forward", f"FusedQKRopeKernel<{args}>::forward")], + ) + + +@cache_once +def _jit_main_q_norm_rope_module( + dtype: torch.dtype, + head_dim: int, + rope_dim: int, +): + """Main MLA path Q kernel: rmsnorm-self + RoPE, warp per (token, head).""" + args = make_cpp_args(dtype, head_dim, rope_dim, is_arch_support_pdl()) + return load_jit( + make_name("main_q_norm_rope"), + *args, + cuda_files=["deepseek_v4/main_norm_rope.cuh"], + cuda_wrappers=[ + ("forward", f"FusedQNormRopeKernel<{args}>::forward"), + ], + ) + + +@cache_once +def _jit_main_k_norm_rope_flashmla_module( + dtype: torch.dtype, + head_dim: int, + rope_dim: int, + page_size: int, +): + """Main MLA path K kernel: rmsnorm + RoPE + write to FlashMLA paged cache.""" + args = make_cpp_args(dtype, head_dim, rope_dim, page_size, is_arch_support_pdl()) + return load_jit( + make_name("main_k_norm_rope_flashmla"), + *args, + cuda_files=["deepseek_v4/main_norm_rope.cuh"], + cuda_wrappers=[ + ("forward", f"FusedKNormRopeFlashMLAKernel<{args}>::forward"), + ], + ) + + +@cache_once +def 
_jit_main_q_indexer_rope_hadamard_quant_module(dtype: torch.dtype): + """C4 indexer Q kernel: RoPE + 128-pt Hadamard + fp8 act-quant""" + args = make_cpp_args(dtype, is_arch_support_pdl()) + return load_jit( + make_name("main_q_indexer_rope_hadamard_quant"), + *args, + cuda_files=["deepseek_v4/main_norm_rope.cuh"], + cuda_wrappers=[ + ("forward", f"FusedQIndexerRopeHadamardQuantKernel<{args}>::forward"), + ], + ) + + +@cache_once +def _jit_main_q_indexer_rope_hadamard_fp4_quant_module(dtype: torch.dtype): + args = make_cpp_args(dtype, is_arch_support_pdl()) + return load_jit( + make_name("main_q_indexer_rope_hadamard_fp4_quant"), + *args, + cuda_files=["deepseek_v4/main_norm_rope.cuh"], + cuda_wrappers=[ + ("forward", f"FusedQIndexerRopeHadamardFp4QuantKernel<{args}>::forward"), + ], + ) + + +def fused_rope_inplace( + q: torch.Tensor, + k: Optional[torch.Tensor], + freqs_cis: torch.Tensor, + positions: torch.Tensor, + inverse: bool = False, +) -> None: + """Apply rotary embeddings to both Q and K in a single fused CUDA kernel. 
+ + Args: + q: [batch_size, num_q_heads, rope_dim] bfloat16 + k: [batch_size, num_k_heads, rope_dim] bfloat16 or None + freqs_cis: [max_seq_len, rope_dim // 2] complex64 (full table) + positions: [batch_size] int32 or int64, indices into freqs_cis + inverse: if True, apply inverse rotation (conjugate freqs) + """ + if _is_hip: + from sglang.srt.layers.deepseek_v4_rope import apply_rotary_emb_triton + + apply_rotary_emb_triton(q, freqs_cis, positions=positions, inverse=inverse) + if k is not None: + apply_rotary_emb_triton(k, freqs_cis, positions=positions, inverse=inverse) + return + + freqs_real = torch.view_as_real(freqs_cis).flatten(-2).contiguous() + module = _jit_fused_rope_module() + module.forward(q, k, freqs_real, positions, inverse) + + +def fused_q_norm_rope( + q_input: torch.Tensor, + q_output: torch.Tensor, + eps: float, + freqs_cis: torch.Tensor, + positions: torch.Tensor, +) -> None: + freqs_real = torch.view_as_real(freqs_cis).flatten(-2) + head_dim = q_input.shape[-1] + rope_dim = freqs_real.shape[-1] + module = _jit_main_q_norm_rope_module(q_input.dtype, head_dim, rope_dim) + module.forward(q_input, q_output, freqs_real, positions, eps) + + +def fused_q_indexer_rope_hadamard_quant( + q_input: torch.Tensor, + weight: torch.Tensor, + weight_scale: float, + freqs_cis: torch.Tensor, + positions: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_real = torch.view_as_real(freqs_cis).flatten(-2) + q_fp8 = torch.empty(q_input.shape, dtype=torch.float8_e4m3fn, device=q_input.device) + weights_out = torch.empty((*q_input.shape[:-1], 1), dtype=torch.float32, device=q_input.device) + if _is_hip: + torch.ops.sgl_kernel.dsv4_fused_q_indexer_rope_hadamard_quant( + q_input, + q_fp8, + weight, + weights_out, + float(weight_scale), + freqs_real, + positions, + ) + else: + module = _jit_main_q_indexer_rope_hadamard_quant_module(q_input.dtype) + module.forward( + q_input, + q_fp8, + weight, + weights_out, + float(weight_scale), + freqs_real, + positions, 
+ ) + return q_fp8, weights_out + + +def fused_q_indexer_rope_hadamard_fp4_quant( + q_input: torch.Tensor, + weight: torch.Tensor, + weight_scale: float, + freqs_cis: torch.Tensor, + positions: torch.Tensor, +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + if _is_hip: + raise RuntimeError("DeepSeek V4 FP4 indexer requires the CUDA fused Q path.") + freqs_real = torch.view_as_real(freqs_cis).flatten(-2) + q_fp4 = torch.empty( + (*q_input.shape[:-1], q_input.shape[-1] // 2), + dtype=torch.int8, + device=q_input.device, + ) + q_sf = torch.empty(q_input.shape[:-1], dtype=torch.int32, device=q_input.device) + weights_out = torch.empty((*q_input.shape[:-1], 1), dtype=torch.float32, device=q_input.device) + module = _jit_main_q_indexer_rope_hadamard_fp4_quant_module(q_input.dtype) + module.forward( + q_input, + q_fp4, + q_sf, + weight, + weights_out, + float(weight_scale), + freqs_real, + positions, + ) + return (q_fp4, q_sf), weights_out + + +def fused_k_norm_rope_flashmla( + kv: torch.Tensor, + kv_weight: torch.Tensor, + eps: float, + freqs_cis: torch.Tensor, + positions: torch.Tensor, + out_loc: torch.Tensor, + kvcache: torch.Tensor, + page_size: int, +) -> None: + freqs_real = torch.view_as_real(freqs_cis).flatten(-2) + head_dim = kv.shape[-1] + rope_dim = freqs_real.shape[-1] + module = _jit_main_k_norm_rope_flashmla_module(kv.dtype, head_dim, rope_dim, page_size) + module.forward(kv, kv_weight, freqs_real, positions, out_loc, kvcache, eps) diff --git a/lightllm/third_party/sglang_jit/dsv4/topk.py b/lightllm/third_party/sglang_jit/dsv4/topk.py new file mode 100644 index 0000000000..1bfce7cef3 --- /dev/null +++ b/lightllm/third_party/sglang_jit/dsv4/topk.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import Optional + +import torch + +from lightllm.third_party.sglang_jit.jit_utils import ( + cache_once, + is_arch_support_pdl, + is_hip_runtime, + load_jit, + make_cpp_args, +) + +from .utils import make_name + + +@cache_once +def 
_jit_topk_v1_module(topk: int): + args = make_cpp_args(is_arch_support_pdl()) + assert topk in (512, 1024), "Only support topk=512 or 1024" + return load_jit( + make_name(f"topk_v1_{topk}"), + *args, + cuda_files=["deepseek_v4/topk_v1.cuh"], + cuda_wrappers=[("topk_transform", f"TopKKernel<{args}>::transform")], + extra_cuda_cflags=[f"-DSGL_TOPK={topk}"], + ) + + +@cache_once +def _jit_topk_v2_module(topk: int): + return load_jit( + make_name(f"topk_v2_{topk}"), + cuda_files=["deepseek_v4/topk_v2.cuh"], + cuda_wrappers=[ + ("topk_transform", "CombinedTopKKernel::transform"), + ("topk_plan", "CombinedTopKKernel::plan"), + ], + extra_cuda_cflags=[f"-DSGL_TOPK={topk}"], + ) + + +def topk_transform_512( + scores: torch.Tensor, + seq_lens: torch.Tensor, + page_tables: torch.Tensor, + out_page_indices: torch.Tensor, + page_size: int, + out_raw_indices: Optional[torch.Tensor] = None, +) -> None: + if is_hip_runtime(): + torch.ops.sgl_kernel.deepseek_v4_topk_transform_512( + scores, seq_lens, page_tables, out_page_indices, page_size, out_raw_indices + ) + else: + module = _jit_topk_v1_module(out_page_indices.shape[1]) + module.topk_transform(scores, seq_lens, page_tables, out_page_indices, page_size, out_raw_indices) + + +_WORKSPACE_INTS_PER_BATCH = 2 + 1024 * 2 +_PLAN_METADATA_INTS_PER_BATCH = 4 + + +def plan_topk_v2(seq_lens: torch.Tensor, static_threshold: int = 0) -> torch.Tensor: + module = _jit_topk_v2_module(512) # does not matter + bs = seq_lens.shape[0] + metadata = seq_lens.new_empty(bs + 1, _PLAN_METADATA_INTS_PER_BATCH) + module.topk_plan(seq_lens, metadata, static_threshold) + return metadata + + +def topk_transform_512_v2( + scores: torch.Tensor, + seq_lens: torch.Tensor, + page_tables: torch.Tensor, + out_page_indices: torch.Tensor, + page_size: int, + metadata: torch.Tensor, +) -> None: + module = _jit_topk_v2_module(out_page_indices.shape[1]) + bs = scores.shape[0] + workspace = seq_lens.new_empty(bs, _WORKSPACE_INTS_PER_BATCH) + module.topk_transform( + 
scores, + seq_lens, + page_tables, + out_page_indices, + page_size, + workspace, + metadata, + ) diff --git a/lightllm/third_party/sglang_jit/dsv4/utils.py b/lightllm/third_party/sglang_jit/dsv4/utils.py new file mode 100644 index 0000000000..8085074f6c --- /dev/null +++ b/lightllm/third_party/sglang_jit/dsv4/utils.py @@ -0,0 +1,2 @@ +def make_name(name: str) -> str: + return f"dpsk_v4_{name}" diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/atomic.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/atomic.cuh new file mode 100644 index 0000000000..c9da765f4a --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/atomic.cuh @@ -0,0 +1,35 @@ +/// \file atomic.cuh +/// \brief Device-side atomic operations. + +#pragma once +#include + +namespace device::atomic { + +/** + * \brief Atomically computes the maximum of `*addr` and `value`, storing the + * result in `*addr`. + * \param addr Pointer to the value in global/shared memory to be updated. + * \param value The value to compare against. + * \return The old value at `*addr` before the update. + * \note On CUDA, this uses `atomicMax`/`atomicMin` on the reinterpreted + * integer representation. On ROCm, a CAS loop is used as a fallback. + */ +SGL_DEVICE float max(float* addr, float value) { +#ifndef USE_ROCM + float old; + old = (value >= 0) ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + return old; +#else + int* addr_as_i = (int*)addr; + int old = *addr_as_i, assumed; + do { + assumed = old; + old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +#endif +} + +} // namespace device::atomic diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/cta.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/cta.cuh new file mode 100644 index 0000000000..b47a4a27b2 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/cta.cuh @@ -0,0 +1,40 @@ +/// \file cta.cuh +/// \brief CTA (Cooperative Thread Array / thread-block) level primitives. + +#pragma once +#include +#include +#include + +namespace device::cta { + +/** + * \brief Compute the maximum of `value` across all threads in the CTA. + * + * Uses a two-level reduction: first within each warp via `warp::reduce_max`, + * then across warps using shared memory. The final result is stored in + * `smem[0]`. + * + * \tparam T Numeric type (must be supported by `warp::reduce_max`). + * \param value Per-thread input value. + * \param smem Shared memory buffer (must have at least `blockDim.x / 32` + * elements). + * \param min_value Identity element for max (default 0.0f). + * \note This function does NOT issue a trailing `__syncthreads()`. + * Callers must synchronize before reading `smem[0]`. + */ +template +SGL_DEVICE void reduce_max(T value, float* smem, float min_value = 0.0f) { + const uint32_t warp_id = threadIdx.x / kWarpThreads; + smem[warp_id] = warp::reduce_max(value); + __syncthreads(); + if (warp_id == 0) { + const auto tx = threadIdx.x; + const auto local_value = tx * kWarpThreads < blockDim.x ? 
smem[tx] : min_value; + const auto max_value = warp::reduce_max(local_value); + smem[0] = max_value; + } + // no extra sync; it is caller's responsibility to sync if needed +} + +} // namespace device::cta diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress.cuh new file mode 100644 index 0000000000..02b166d01c --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress.cuh @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include + +#include +#include + +#include + +namespace device::compress { + +struct alignas(16) PrefillPlan { + uint32_t ragged_id; + uint32_t batch_id; + uint32_t position; + uint32_t window_len; // must be in `[0, compress_ratio * (1 + is_overlap))` + + bool is_valid(const uint32_t ratio, const bool is_overlap) const { + const uint32_t max_window_len = ratio * (1 + is_overlap); + return window_len < max_window_len; + } +}; + +} // namespace device::compress + +namespace host::compress { + +using device::compress::PrefillPlan; +using PrefillPlanTensorDtype = uint8_t; +inline constexpr int64_t kPrefillPlanDim = 16; + +static_assert(alignof(PrefillPlan) == sizeof(PrefillPlan)); +static_assert(sizeof(PrefillPlan) == kPrefillPlanDim * sizeof(PrefillPlanTensorDtype)); + +} // namespace host::compress diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress_v2.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress_v2.cuh new file mode 100644 index 0000000000..3e87127c5f --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress_v2.cuh @@ -0,0 +1,99 @@ +#pragma once + +#include +#include + +#include + +#include +#include + +#include + +namespace device::compress { + +/// \brief Per-batch decode plan. Layout: 16 bytes. 
+struct alignas(16) DecodePlan { + uint32_t seq_len; + int32_t write_loc; + int32_t read_page_0; + int32_t read_page_1; +}; + +/// \brief Per-token compress plan (used by c4/c128 prefill). Layout: 16 bytes. +struct alignas(16) CompressPlan { + uint32_t seq_len; + uint16_t ragged_id; + uint16_t buffer_len; + int32_t read_page_0; + /// \brief Stage 0 (CPU): batch_id (used to look up page table). + /// \brief Stage 1 (GPU): final state-pool write location. + int32_t read_page_1; + + static SGL_DEVICE __host__ CompressPlan invalid() { + return CompressPlan{-1u, 0, 0, -1, -1}; + } + + SGL_DEVICE __host__ bool is_invalid() const { + return seq_len == -1u; + } +}; + +/// \brief Per-token write plan (used by c4/c128 prefill). Layout: 8 bytes. +struct alignas(8) WritePlan { + /// \brief Stage 0 (CPU): packed `(batch_id << 16) | ragged_id`. + /// \brief Stage 1 (GPU): just `ragged_id`. + uint32_t ragged_id; + /// \brief Stage 0 (CPU): position + 1 (used to look up state slot). + /// \brief Stage 1 (GPU): final state-pool write location. 
+ int32_t write_loc; + + static SGL_DEVICE __host__ WritePlan invalid() { + return WritePlan{-1u, -1}; + } + + SGL_DEVICE __host__ bool is_invalid() const { + return ragged_id == -1u; + } +}; + +} // namespace device::compress + +namespace host::compress { + +using device::compress::CompressPlan; +using device::compress::DecodePlan; +using device::compress::WritePlan; + +static_assert(alignof(DecodePlan) == sizeof(DecodePlan)); +static_assert(sizeof(DecodePlan) == 16); +static_assert(alignof(CompressPlan) == sizeof(CompressPlan)); +static_assert(sizeof(CompressPlan) == 16); +static_assert(alignof(WritePlan) == sizeof(WritePlan)); +static_assert(sizeof(WritePlan) == 8); + +inline auto verify_plan_d(tvm::ffi::TensorView t, SymbolicSize& N, SymbolicDevice& device) -> const DecodePlan* { + TensorMatcher({N, sizeof(DecodePlan)}) // + .with_dtype() + .with_device(device) + .verify(t); + return static_cast(t.data_ptr()); +} + +inline auto verify_plan_c(tvm::ffi::TensorView t, SymbolicSize& N, SymbolicDevice& device) -> const CompressPlan* { + TensorMatcher({N, sizeof(CompressPlan)}) // + .with_dtype() + .with_device(device) + .verify(t); + return static_cast(t.data_ptr()); +} + +inline auto verify_plan_w(tvm::ffi::TensorView t, SymbolicSize& N, SymbolicDevice& device) -> const WritePlan* { + TensorMatcher({N, sizeof(WritePlan)}) // + .with_dtype() + .with_device(device) + .verify(t); + return static_cast(t.data_ptr()); +} + +} // namespace host::compress diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/fp8_utils.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/fp8_utils.cuh new file mode 100644 index 0000000000..53a62755b4 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/fp8_utils.cuh @@ -0,0 +1,112 @@ +#pragma once + +#include +#include +#include + +#include +#ifndef USE_ROCM +#include +#endif + +// Small helpers shared by the DeepSeek-V4 FP8/UE8M0 quantization kernels +// 
(silu_and_mul_masked_post_quant, store, mega_moe_pre_dispatch, ...). +// All functions are `SGL_DEVICE` (= `__forceinline__ __device__`) so +// including this header in multiple translation units is ODR-safe. + +namespace deepseek_v4::fp8 { + +// Round `x` up (never down) to the next power of two representable in UE8M0: the exponent is bumped whenever any mantissa bit is set. Returns the raw +// 8-bit biased exponent; the actual fp32 scale is `2^(exp - 127)` +// (i.e. `__uint_as_float(exp << 23)`). The sign bit of `x` is ignored. +SGL_DEVICE int32_t cast_to_ue8m0(float x) { + uint32_t u = __float_as_uint(x); + int32_t exp = int32_t((u >> 23) & 0xFF); + uint32_t mant = u & 0x7FFFFF; + return exp + (mant != 0); +} + +// 1 / 2^(exp - 127) as fp32. Equivalent to `1.0f / __uint_as_float(exp << 23)`. +SGL_DEVICE float inv_scale_ue8m0(int32_t exp) { + return __uint_as_float((127 + 127 - exp) << 23); +} + +// Clamp to [-FP8_E4M3_MAX, FP8_E4M3_MAX]. +// Uses platform-specific max from type.cuh (448 for E4M3FN, 224 for E4M3FNUZ). +SGL_DEVICE float fp8_e4m3_clip(float val) { + return fmaxf(fminf(val, kFP8E4M3Max), -kFP8E4M3Max); +} + +#ifndef USE_ROCM +// Pack two fp32 values into a single fp8x2_e4m3 with clamping. +SGL_DEVICE fp8x2_e4m3_t pack_fp8(float x, float y) { + return fp8x2_e4m3_t{fp32x2_t{fp8_e4m3_clip(x), fp8_e4m3_clip(y)}}; +} +#else +// Software float -> FP8 E4M3 conversion for ROCm/HIP. +// Supports both E4M3FN (MI350X, gfx950) and E4M3FNUZ (MI300X, gfx942). 
+SGL_DEVICE uint8_t cvt_float_to_fp8_e4m3(float val) { + val = fp8_e4m3_clip(val); + if (val == 0.0f) return 0; + + uint32_t f32 = __float_as_uint(val); + uint8_t sign = static_cast<uint8_t>((f32 >> 31) << 7); + int32_t exp32 = static_cast<int32_t>((f32 >> 23) & 0xFF) - 127; + uint32_t mant23 = f32 & 0x7FFFFF; + +#if HIP_FP8_TYPE_FNUZ + // E4M3FNUZ: bias=8, max=240, no negative zero, NaN=0x80 + constexpr int32_t kBias = 8; + constexpr int32_t kMaxExp = 15; // largest biased exponent; still encodes finite values + constexpr int32_t kMinSubnormExp = -10; // min subnormal exponent + constexpr int32_t kMinNormExp = -7; // min normal exponent + constexpr uint8_t kSaturate = 0x7Fu; // max normal = 0_1111_111 = 240.0 +#else + // E4M3FN: bias=7, max=448, NaN=0x7F + constexpr int32_t kBias = 7; + constexpr int32_t kMaxExp = 15; // largest biased exponent; finite for mantissa 0..6 + constexpr int32_t kMinSubnormExp = -9; + constexpr int32_t kMinNormExp = -6; + constexpr uint8_t kSaturate = 0x7Eu; // max normal = 0_1111_110 = 448.0 +#endif + + int32_t exp8; + uint8_t mant3; + + if (exp32 < kMinSubnormExp) { + return sign; // magnitude below the smallest subnormal: flush to signed zero. NOTE(review): on FNUZ, sign=0x80 is the NaN encoding listed above; confirm negative underflow should not return 0 instead + } else if (exp32 < kMinNormExp) { + // Subnormal range + int32_t shift = -(kBias - 1) - exp32; // 1..3 + uint32_t subnorm_mant = (0x800000 | mant23) >> (shift + 20); + uint32_t round_bit = ((0x800000 | mant23) >> (shift + 19)) & 1; + subnorm_mant += round_bit; + mant3 = static_cast<uint8_t>(subnorm_mant & 0x07); + exp8 = 0; + if (subnorm_mant > 7) { + exp8 = 1; + mant3 = 0; + } + } else { + exp8 = exp32 + kBias; + mant3 = static_cast<uint8_t>(mant23 >> 20); + uint32_t round_bit = (mant23 >> 19) & 1; + mant3 += round_bit; + if (mant3 > 7) { + mant3 = 0; + exp8++; + } + if (exp8 > kMaxExp || (exp8 == kMaxExp && mant3 > (kSaturate & 0x07))) return sign | kSaturate; // saturate only past the max-normal encoding; exp8 == kMaxExp alone is still a finite value (e.g. 256.0 on E4M3FN) + } + return sign | (static_cast<uint8_t>(exp8) << 3) | mant3; +} + +// Pack two fp32 values into a single fp8x2_e4m3 (uint16_t on HIP). 
+SGL_DEVICE fp8x2_e4m3_t pack_fp8(float x, float y) { + uint8_t x8 = cvt_float_to_fp8_e4m3(x); + uint8_t y8 = cvt_float_to_fp8_e4m3(y); + return static_cast(x8) | (static_cast(y8) << 8); +} +#endif + +} // namespace deepseek_v4::fp8 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/kvcacheio.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/kvcacheio.cuh new file mode 100644 index 0000000000..0a3acc4773 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/kvcacheio.cuh @@ -0,0 +1,96 @@ +#include +#include + +#include + +#include + +namespace device::hisparse { + +/// NOTE: We call nope+rope as a "value" here. +/// GPU Cache layout: +/// VALUE 0, VALUE 1, ..., VALUE 63, +/// SCALE 0, SCALE 1, ..., SCALE 63, +/// [Padding to align to 576 bytes] +/// CPU Cache follow a trivial linear layout without any padding. +inline constexpr int64_t kGPUPageSize = 64; +inline constexpr int64_t kGPUPageBits = 6; // log2(kGPUPageSize) +inline constexpr int64_t kValueBytes = 576; +inline constexpr int64_t kScaleBytes = 8; +/// NOTE: FlashMLA requires each page to be aligned to 576 bytes +inline constexpr int64_t kCPUItemBytes = kValueBytes + kScaleBytes; +inline constexpr int64_t kGPUPageBytes = host::div_ceil(kCPUItemBytes * kGPUPageSize, 576) * 576; +inline constexpr int64_t kGPUScaleOffset = kValueBytes * kGPUPageSize; + +struct PointerInfo { + int64_t* value_ptr; + int64_t* scale_ptr; +}; + +SGL_DEVICE PointerInfo get_pointer_gpu(void* cache, int32_t index) { + using namespace device; + static_assert(1 << kGPUPageBits == kGPUPageSize); + const int32_t page_num = index >> kGPUPageBits; + const int32_t page_offset = index & (kGPUPageSize - 1); + const auto page_ptr = pointer::offset(cache, page_num * kGPUPageBytes); + const auto value_ptr = pointer::offset(page_ptr, page_offset * kValueBytes); + const auto scale_ptr = pointer::offset(page_ptr, kGPUScaleOffset + page_offset * kScaleBytes); + return 
{static_cast(value_ptr), static_cast(scale_ptr)}; +} + +SGL_DEVICE PointerInfo get_pointer_cpu(void* cache, int32_t index) { + using namespace device; + const auto value_ptr = pointer::offset(cache, index * kCPUItemBytes); + const auto scale_ptr = pointer::offset(value_ptr, kValueBytes); + return {static_cast(value_ptr), static_cast(scale_ptr)}; +} + +enum class TransferDirection { + DeviceToDevice = 0, + DeviceToHost = 1, + HostToDevice = 2, +}; + +template +SGL_DEVICE void transfer_item(void* dst_cache, void* src_cache, const int32_t dst_index, const int32_t src_index) { + constexpr bool is_dst_device = (direction != TransferDirection::DeviceToHost); + constexpr bool is_src_device = (direction != TransferDirection::HostToDevice); + constexpr auto dst_fn = is_dst_device ? get_pointer_gpu : get_pointer_cpu; + constexpr auto src_fn = is_src_device ? get_pointer_gpu : get_pointer_cpu; + + const auto [dst_value_ptr, dst_scale_ptr] = dst_fn(dst_cache, dst_index); + const auto [src_value_ptr, src_scale_ptr] = src_fn(src_cache, src_index); + + int64_t local_items[2]; + const int64_t* tail_src_ptr; + int64_t* tail_dst_ptr; + + const int32_t lane_id = threadIdx.x % 32; + + for (int i = 0; i < 2; ++i) { + const auto j = lane_id + i * 32; + local_items[i] = src_value_ptr[j]; + } + + if (lane_id < 8) { // handle the tail element safely + const auto last_id = 64 + lane_id; + tail_src_ptr = src_value_ptr + last_id; + tail_dst_ptr = dst_value_ptr + last_id; + } else { // broadcast load/store is safe + tail_src_ptr = src_scale_ptr; + tail_dst_ptr = dst_scale_ptr; + } + + const auto tail_item = *tail_src_ptr; + + // store first 512 bytes of value + for (int i = 0; i < 2; ++i) { + const auto j = lane_id + i * 32; + dst_value_ptr[j] = local_items[i]; + } + + // store the tail element + *tail_dst_ptr = tail_item; +} + +} // namespace device::hisparse diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/cluster.cuh 
b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/cluster.cuh new file mode 100644 index 0000000000..e58214c951 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/cluster.cuh @@ -0,0 +1,257 @@ +#pragma once +#include +#include +#include + +#include "common.cuh" +#include "ptx.cuh" +#include +#include + +namespace device::top512 { + +template +struct ClusterTopK { + static constexpr uint32_t kClusterSize = 8; + static constexpr uint32_t kHistBits = 10; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kRadixBins = 256; + static constexpr uint32_t kElemPerStage = 8; + static constexpr uint32_t kSizePerStage = kElemPerStage * kBlockSize; + static constexpr uint32_t kNumStages = 4; + static constexpr uint32_t kMaxLength = kClusterSize * kNumStages * kSizePerStage; + static constexpr uint32_t kStoreLane = kBlockSize - 1; + static constexpr uint32_t kAboveBits = 11; + + // --------------------------------------------------------------------------- + // Shared memory layouts + // --------------------------------------------------------------------------- + + struct Smem { + uint64_t barrier[kNumStages]; + uint32_t local_above_equal[kClusterSize]; + uint32_t prefix_above_equal; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + alignas(128) MatchBin match; + alignas(128) uint32_t warp_sum[kNumWarps]; + uint32_t histogram[kHistBins]; + alignas(128) float score_buffer[kNumStages][kSizePerStage]; + Tie tie_buffer[kMaxTies]; + }; + + struct alignas(16) Metadata { + uint32_t batch_id; + uint32_t seq_len; + bool has_next; + }; + + struct WorkSpace { + uint2 metadata; // {num_above, num_ties} + Tie ties[kMaxTies]; + }; + + static constexpr uint32_t kWorkspaceInts = sizeof(WorkSpace) / sizeof(uint32_t); + + // --------------------------------------------------------------------------- + // Stage 1: histogram + cluster reduce + find threshold + scatter + // 
--------------------------------------------------------------------------- + + SGL_DEVICE static void stage1_init(void* _smem) { + const auto tx = threadIdx.x; + __builtin_assume(tx < kBlockSize); + const auto smem = static_cast(_smem); + if (tx < kHistBins) smem->histogram[tx] = 0; + if (tx < kNumStages) ptx::mbarrier_init(&smem->barrier[tx], 1); + __syncthreads(); + } + + SGL_DEVICE static void stage1_prologue(const float* scores, uint32_t length, void* _smem) { + if (threadIdx.x == 0) { + const auto smem = static_cast(_smem); + const auto num_stages = (length + kSizePerStage - 1) / kSizePerStage; + const auto length_aligned = (length + 3u) & ~3u; // align to 4 for TMA +#pragma unroll + for (uint32_t stage = 0; stage < kNumStages; stage++) { + if (stage >= num_stages) break; + const auto offset = stage * kSizePerStage; + const auto size = min(kSizePerStage, length_aligned - offset); + const auto size_bytes = size * sizeof(float); + const auto bar = &smem->barrier[stage]; + ptx::tma_load(smem->score_buffer[stage], scores + offset, size_bytes, bar); + ptx::mbarrier_arrive_expect_tx(bar, size_bytes); + } + } + } + + SGL_DEVICE static void stage1(int32_t* indices, uint32_t length, void* _smem, bool reuse = false) { + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + __builtin_assume(tx < kBlockSize); + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // Initialize shared memory histogram, counters, and barriers +#pragma unroll + for (uint32_t stage = 0; stage < kNumStages; stage++) { + const auto offset = stage * kSizePerStage; + if (offset >= length) break; + const auto size = min(kSizePerStage, length - offset); + if (lane_id == 0) ptx::mbarrier_wait(&smem->barrier[stage], 0); + __syncwarp(); +#pragma unroll + for (uint32_t i = 0; i < kElemPerStage; ++i) { + const auto idx = tx + i * kBlockSize; + if (idx >= size) break; + const auto score = smem->score_buffer[stage][idx]; + const auto bin = 
extract_coarse_bin(score); + atomicAdd(&smem->histogram[bin], 1); + } + } + + static_assert(kHistBins <= kBlockSize); + + // 2-shot all-reduce + { + auto cluster = cooperative_groups::this_cluster(); + cluster.sync(); + const auto cluster_rank = blockIdx.y; + const auto kLocalSize = kHistBins / kClusterSize; + const auto offset = kLocalSize * cluster_rank; + + const auto src_tx = tx / kClusterSize; + const auto src_rank = tx % kClusterSize; + + if (tx < kHistBins) { + const auto addr = &smem->histogram[offset + src_tx]; + const auto src_addr = cluster.map_shared_rank(addr, src_rank); + *src_addr = warp::reduce_sum(*src_addr); + } + cluster.sync(); + } + + // now each block holds the whole histogram, find the threshold bin + { + const auto value = tx < kHistBins ? smem->histogram[tx] : 0; + const auto warp_inc = warp_inclusive_sum(lane_id, value); + if (lane_id == kWarpThreads - 1) { + smem->warp_sum[warp_id] = warp_inc; + } + + __syncthreads(); + const auto tmp = smem->warp_sum[lane_id]; + // total_length = sum of all bins in the globally-reduced histogram + // (problem.length is block-local; after cluster reduction we need the global total) + const auto total_length = warp::reduce_sum(tmp); + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? 
tmp : 0); + prefix_sum += warp_inc; + const auto above = total_length - prefix_sum; + if (tx < kHistBins && above < K && above + value >= K) { + smem->counter_gt = smem->counter_eq = 0; + smem->match = { + .bin = tx, + .above_count = above, + .equal_count = value, + }; + } + __syncthreads(); + } + + const auto [thr_bin, num_above, num_equal] = smem->match; + + // write above and equal results to global memory +#pragma unroll + for (uint32_t stage = 0; stage < kNumStages; stage++) { + const auto offset = stage * kSizePerStage; + if (offset >= length) break; +#pragma unroll + for (uint32_t i = 0; i < kElemPerStage; ++i) { + const auto buf_idx = tx + i * kBlockSize; + const auto global_idx = offset + buf_idx; + if (global_idx >= length) break; + const auto score = smem->score_buffer[stage][buf_idx]; + const auto bin = extract_coarse_bin(score); + if (bin > thr_bin) { + indices[atomicAdd(&smem->counter_gt, 1)] = global_idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (pos < kMaxTies) smem->tie_buffer[pos] = {global_idx, score}; + } + } + } + if (reuse) { + const auto num_stages = (length + kSizePerStage - 1) / kSizePerStage; + if (tx < kHistBins) smem->histogram[tx] = 0; + if (tx < num_stages) ptx::mbarrier_arrive(&smem->barrier[tx]); + } + __syncthreads(); + } + + // --------------------------------------------------------------------------- + // Stage 1 epilogue: cross-block prefix sum + page translate + tie store + // --------------------------------------------------------------------------- + + SGL_DEVICE static void stage1_epilogue(const TransformParams params, const uint32_t offset, void* _ws, void* _smem) { + auto cluster = cooperative_groups::this_cluster(); + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto local_above = smem->counter_gt; + const auto local_equal = smem->counter_eq; + const auto cluster_rank = blockIdx.y; + + constexpr uint32_t kAboveMask = (1 << kAboveBits) - 1; + 
static_assert(kAboveMask >= K); + + // Pack local counts -- NO alignment rounding (contiguous layout) + static_assert(kMaxTies <= kBlockSize); + const auto idx_above = tx < local_above ? params.indices_in[tx] : 0; + const auto tie_value = tx < local_equal ? smem->tie_buffer[tx] : Tie{0, 0.0f}; + + // push to remote shared memory, can reduce latency of reading remote + if (tx < kClusterSize) { + const auto value = (local_equal << kAboveBits) | local_above; + const auto dst_addr = cluster.map_shared_rank(smem->local_above_equal, tx); + dst_addr[cluster_rank] = value; + } + // after this last sync, only read local shared memory + // so that it is safe when peer rank has already exited the kernel + cluster.sync(); + if (tx < kClusterSize) { + const auto value = tx < cluster_rank ? smem->local_above_equal[tx] : 0; + const auto kActiveMask = (1u << kClusterSize) - 1; + smem->prefix_above_equal = warp::reduce_sum(value, kActiveMask); + } + __syncthreads(); + + const auto prefix_packed = smem->prefix_above_equal; + const auto prefix_above = prefix_packed & kAboveMask; + const auto prefix_equal = prefix_packed >> kAboveBits; + + // Page-translate above elements + if (tx < local_above) { + params.write(tx + prefix_above, idx_above + offset); + } + // Contiguous tie store via regular global writes (no TMA, no gaps) + const auto ws = static_cast(_ws); + if (tx < local_equal && tx + prefix_equal < kMaxTies) { + ws->ties[tx + prefix_equal] = {tie_value.idx + offset, tie_value.score}; + } + // The last cluster rank (kClusterSize - 1) writes global metadata {num_above, num_ties}: its exclusive prefix plus its local counts are the cluster-wide totals + if (cluster_rank == kClusterSize - 1 && tx == 0) { + const auto sum_above = prefix_above + local_above; + const auto sum_equal = prefix_equal + local_equal; + ws->metadata = make_uint2(sum_above, sum_equal); + } + } + + SGL_DEVICE static void transform(const TransformParams params, const void* _ws, void* _smem) { + const auto ws = static_cast(_ws); + const auto meta = &ws->metadata; + const auto [num_above, num_equal] = *meta; + if 
(num_above >= K || num_equal == 0) return; + const auto clamped_ties = min(num_equal, kMaxTies); + tie_handle_transform(ws->ties, clamped_ties, num_above, K, params, _smem); + } +}; + +} // namespace device::top512 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/common.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/common.cuh new file mode 100644 index 0000000000..d553032d79 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/common.cuh @@ -0,0 +1,176 @@ +#pragma once +#include +#include +#include +#include + +#include + +namespace device::top512 { + +inline constexpr uint32_t kMaxTopK = 1024; +inline constexpr uint32_t kBlockSize = 1024; +inline constexpr uint32_t kNumWarps = kBlockSize / kWarpThreads; +inline constexpr uint32_t kMaxTies = 1024; // == kBlockSize: 1 element per thread in stage2 +static constexpr uint32_t kRadixBins = 256; +static_assert(kMaxTopK <= kBlockSize && kMaxTies <= kBlockSize); + +// always use float4 to load from global memory +using Vec4 = AlignedVector; + +SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) { + const uint32_t mask = (1u << page_bits) - 1u; + return (page_table[i >> page_bits] << page_bits) | (i & mask); +} + +struct TransformParams { + const int32_t* __restrict__ page_table; + const int32_t* __restrict__ indices_in; + int32_t* __restrict__ indices_out; + uint32_t page_bits; + + SGL_DEVICE void transform(const uint32_t idx) const { + indices_out[idx] = page_to_indices(page_table, indices_in[idx], page_bits); + } + SGL_DEVICE void write(const uint32_t dst, const uint32_t src) const { + indices_out[dst] = page_to_indices(page_table, src, page_bits); + } +}; + +struct alignas(16) MatchBin { + uint32_t bin; + uint32_t above_count; + uint32_t equal_count; +}; + +struct alignas(8) Tie { + uint32_t idx; + float score; +}; + +struct TieHandleSmem { + alignas(128) uint32_t counter; 
// output position counter + alignas(128) MatchBin match; + uint32_t histogram[kRadixBins]; // 256-bin radix histogram + uint32_t warp_sum[kNumWarps]; // for 2-pass prefix sum +}; + +template +SGL_DEVICE uint32_t extract_coarse_bin(float x) { + static_assert(0 < kBits && kBits < 15); + const auto hx = cast(x); + const uint16_t bits = *reinterpret_cast(&hx); + const uint16_t key = (bits & 0x8000) ? ~bits : bits | 0x8000; + return key >> (16 - kBits); +} + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset); + if (lane_id >= offset) val += n; + } + return val; +} + +/// Order-preserving float32 -> uint32 for radix select +SGL_DEVICE uint32_t extract_exact_bin(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +SGL_DEVICE void trivial_transform(const TransformParams& params, uint32_t length, uint32_t K) { + if (const auto tx = threadIdx.x; tx < length) { + params.write(tx, tx); + } else if (tx < K) { + params.indices_out[tx] = -1; + } +} + +SGL_DEVICE void tie_handle_transform( + const Tie* __restrict__ ties, // + const uint32_t num_ties, + const uint32_t num_above, + const uint32_t K, + const TransformParams params, + void* _smem) { + auto* smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // Each thread loads one element (or becomes inactive) + const bool has_elem = tx < num_ties; + const auto tie = has_elem ? 
ties[tx] : Tie{0, 0.0f}; + const uint32_t key = extract_exact_bin(tie.score); + const uint32_t idx = tie.idx; + bool active = has_elem; + uint32_t topk_remain = K - num_above; + uint32_t write_pos = K; + + smem->counter = 0; + __syncthreads(); + + // Number of warps covering the 256-bin histogram (256/32 = 8) + constexpr uint32_t kRadixWarps = kRadixBins / kWarpThreads; + +#pragma unroll + for (int round = 0; round < 4; round++) { + const uint32_t shift = 24 - round * 8; + const uint32_t bin = (key >> shift) & 0xFFu; + + // 1. Build histogram + if (tx < kRadixBins) smem->histogram[tx] = 0; + __syncthreads(); + if (active) atomicAdd(&smem->histogram[bin], 1); + __syncthreads(); + + // 2. v2-style 2-pass prefix sum on 256 bins + // Only first 256 threads (8 warps) carry histogram bins. + // Other threads get hist_val=0 and harmless prefix results. + uint32_t hist_val = 0; + uint32_t warp_inc = 0; + if (tx < kRadixBins) { + hist_val = smem->histogram[tx]; + warp_inc = warp_inclusive_sum(lane_id, hist_val); + if (lane_id == kWarpThreads - 1) smem->warp_sum[warp_id] = warp_inc; + } + __syncthreads(); + if (tx < kRadixBins) { + // Inter-warp prefix (only the first kRadixWarps warp totals matter) + const auto tmp = (lane_id < kRadixWarps) ? smem->warp_sum[lane_id] : 0; + const auto total = warp::reduce_sum(tmp); + const auto inter = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + const auto prefix = inter + warp_inc; // inclusive prefix through this bin + const auto above = total - prefix; // elements in bins ABOVE this one + // 3. Find threshold bin + if (above < topk_remain && above + hist_val >= topk_remain) { + smem->match = {tx, above, topk_remain - above}; + } + } + __syncthreads(); + + const auto [thr, n_above, _] = smem->match; + + // 4. 
Scatter + if (active) { + if (bin > thr) { + write_pos = num_above + atomicAdd(&smem->counter, 1); + active = false; + } else if (bin < thr) { + active = false; + } else if (round == 3) { + write_pos = K - atomicAdd(&smem->match.equal_count, -1u); + } + // my_bin == thr && round < 3: stay active for next round + } + + topk_remain -= n_above; + if (topk_remain == 0) break; + } + + if (write_pos < K) params.write(write_pos, idx); +} + +} // namespace device::top512 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/ptx.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/ptx.cuh new file mode 100644 index 0000000000..73eef555f4 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/ptx.cuh @@ -0,0 +1,54 @@ +#pragma once +#include + +#include + +#include + +namespace device::top512 { + +namespace ptx { + +SGL_DEVICE void mbarrier_wait(uint64_t* addr, uint32_t phase) { + while (!cuda::ptx::mbarrier_try_wait_parity(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, addr, phase)) + ; +} + +SGL_DEVICE void mbarrier_init(uint64_t* addr, uint32_t arrives) { + cuda::ptx::mbarrier_init(addr, arrives); +} + +SGL_DEVICE void mbarrier_arrive_expect_tx(uint64_t* addr, uint32_t tx) { + cuda::ptx::mbarrier_arrive_expect_tx(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, cuda::ptx::space_shared, addr, tx); +} + +SGL_DEVICE void mbarrier_arrive(uint64_t* addr) { + cuda::ptx::mbarrier_arrive(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, cuda::ptx::space_shared, addr); +} + +SGL_DEVICE void tma_load(void* dst, const void* src, uint32_t num_bytes, uint64_t* mbar) { + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, cuda::ptx::space_global, dst, src, num_bytes, mbar); +} + +SGL_DEVICE uint32_t elect_sync() { + uint32_t pred = 0; + asm volatile( + "{\n\t" + ".reg .pred %%px;\n\t" + "elect.sync _|%%px, %1;\n\t" + "@%%px mov.s32 %0, 1;\n\t" + "}" + : "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +} + +SGL_DEVICE 
bool elect_sync_cta(uint32_t tx) { + const auto warp_id = tx / 32; + const auto uniform_warp_id = __shfl_sync(0xFFFFFFFF, warp_id, 0); + return (uniform_warp_id == 0 && elect_sync()); +} + +} // namespace ptx + +} // namespace device::top512 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/register.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/register.cuh new file mode 100644 index 0000000000..77d7361ee8 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/register.cuh @@ -0,0 +1,302 @@ +#pragma once + +#include +#include +#include + +#include "common.cuh" +#include "ptx.cuh" +#include +#include + +namespace device::top512 { + +template +struct RegisterTopK { + static constexpr uint32_t kHistBits = 12; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kVecsPerThread = 4; + static constexpr uint32_t kMaxTolerance = 0; + static constexpr uint32_t kMax1PassLength = kVecsPerThread * 4 * kBlockSize; + static constexpr uint32_t kMaxExtraLength = kMax1PassLength; + static constexpr uint32_t kMax2PassLength = kMax1PassLength + kMaxExtraLength; + + struct Smem { + using HistVec = AlignedVector; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + uint64_t mbarrier; // for cp.async + MatchBin match; + uint32_t warp_sum[kNumWarps]; + union { + uint32_t histogram[kHistBins]; + HistVec histogram_vec[kBlockSize]; + Tie tie_buffer[kMaxTies]; + }; + alignas(16) float score_buffer[kMaxExtraLength]; + }; + + template + SGL_DEVICE static void + run(const float* scores, // + int32_t* indices, + const uint32_t length, + void* _smem, + const bool use_pdl = false) { + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // Initialize shared memory histogram + { + typename Smem::HistVec hist_vec; + hist_vec.fill(0); + 
smem->histogram_vec[tx] = hist_vec; + if (tx == 0) { + smem->counter_gt = smem->counter_eq = 0; + if constexpr (kIs2Pass) { + ptx::mbarrier_init(&smem->mbarrier, 1); + } + } + __syncthreads(); + } + + if (use_pdl) device::PDLWaitPrimary(); + + // Load scores into registers + Vec4 local[kVecsPerThread]; +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { + const uint32_t base = (tx + v * kBlockSize) * 4; + if (base >= length) break; + local[v].load(scores, tx + v * kBlockSize); + } + + // Fetch the next chunk of scores + if constexpr (kIs2Pass) { + if (ptx::elect_sync_cta(tx)) { + const auto length_aligned = (length + 3u - kMax1PassLength) & ~3u; + const auto size_bytes = length_aligned * sizeof(float); + ptx::tma_load(smem->score_buffer, scores + kMax1PassLength, size_bytes, &smem->mbarrier); + ptx::mbarrier_arrive_expect_tx(&smem->mbarrier, size_bytes); + } + __syncwarp(); // avoid warp divergence on + } + + // Accumulate histogram via shared-memory atomics +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + if constexpr (!kIs2Pass) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e; + if (idx >= length) goto LABEL_ACC_FINISH; + } + atomicAdd(&smem->histogram[extract_coarse_bin(local[v][e])], 1); + } + } + if constexpr (kIs2Pass) { + // 16K ~ 32K. 
`i` is a float4 index + if (lane_id == 0) ptx::mbarrier_wait(&smem->mbarrier, 0); + __syncwarp(); + for (uint32_t i = tx; i + kMax1PassLength < length; i += kBlockSize) { + const auto val = smem->score_buffer[i]; + atomicAdd(&smem->histogram[extract_coarse_bin(val)], 1); + } + } + [[maybe_unused]] LABEL_ACC_FINISH: + __syncthreads(); + + // Phase 2: Exclusive prefix scan -> find threshold bin + { + constexpr uint32_t kItems = kHistBins / kBlockSize; + uint32_t orig[kItems]; + const auto hist_vec = smem->histogram_vec[tx]; + uint32_t tmp_local_sum = 0; + +#pragma unroll + for (uint32_t i = 0; i < kItems; ++i) { + orig[i] = hist_vec[i]; + tmp_local_sum += orig[i]; + } + + const auto warp_inc = warp_inclusive_sum(lane_id, tmp_local_sum); + const auto warp_exc = warp_inc - tmp_local_sum; + if (lane_id == kWarpThreads - 1) { + smem->warp_sum[warp_id] = warp_inc; + } + + __syncthreads(); + + const auto tmp = smem->warp_sum[lane_id]; + // Exactly one bin satisfies: above < K && above + count >= K + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + prefix_sum += warp_exc; +#pragma unroll + for (uint32_t i = 0; i < kItems; ++i) { + prefix_sum += orig[i]; + const auto above = length - prefix_sum; + if (above < K && above + orig[i] >= K) { + smem->match = { + .bin = tx * kItems + i, + .above_count = above, + .equal_count = orig[i], + }; + } + } + __syncthreads(); + } + + const auto [thr_bin, num_above, num_equal] = smem->match; + + // Phase 3: Scatter + // Elements strictly above threshold go directly to output. + // Tied elements: simple path admits first-come; tiebreak path collects into tie_buffer. 
+ const bool need_tiebreak = (num_equal + num_above > K + kMaxTolerance); + const auto topk_indices = indices; + const auto tie_buffer = smem->tie_buffer; + +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e; + if constexpr (!kIs2Pass) { + if (idx >= length) goto LABEL_SCATTER_DONE; + } + const uint32_t bin = extract_coarse_bin(local[v][e]); + if (bin > thr_bin) { + topk_indices[atomicAdd(&smem->counter_gt, 1)] = idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (need_tiebreak) { + if (pos < kMaxTies) { + tie_buffer[pos] = {.idx = idx, .score = local[v][e]}; + } + } else { + if (const auto which = pos + num_above; which < K) { + topk_indices[which] = idx; + } + } + } + } + // prefetch the next scores + if constexpr (kIs2Pass) { + local[v].load(smem->score_buffer, tx + v * kBlockSize); + } + } + + // 16K ~ 32K, already in registers: similar loop as above but read from smem->score_buffer + if constexpr (kIs2Pass) { +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e + kMax1PassLength; + if (idx >= length) goto LABEL_SCATTER_DONE; + const uint32_t bin = extract_coarse_bin(local[v][e]); + if (bin > thr_bin) { + topk_indices[atomicAdd(&smem->counter_gt, 1)] = idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (need_tiebreak) { + if (pos < kMaxTies) { + tie_buffer[pos] = {.idx = idx, .score = local[v][e]}; + } + } else { + if (const auto which = pos + num_above; which < K) { + topk_indices[which] = idx; + } + } + } + } + } + } + + [[maybe_unused]] LABEL_SCATTER_DONE: + if (!need_tiebreak) return; + + // Phase 4: Tie-breaking within the threshold bin. + // Assume num_ties <= kBlockSize (at most 1 block of ties). 
+ // Each thread takes one tied element, computes its rank (number of + // elements with strictly higher score, breaking exact float ties by + // original index), and writes to output if rank < topk_remain. + __syncthreads(); + static_assert(kMaxTies <= kBlockSize); + + const uint32_t num_ties = min(num_equal, kMaxTies); + const uint32_t topk_remain = K - num_above; + + const auto is_greater = [](const Tie& a, const Tie& b) { + return (a.score > b.score) || (a.score == b.score && a.idx < b.idx); + }; + + if (num_ties <= kWarpThreads) { + static_assert(kWarpThreads <= kNumWarps); + if (lane_id >= num_ties || warp_id >= num_ties) return; // some threads are idle + /// NOTE: use long long to avoid mask overflow when num_ties == 32 + const uint32_t mask = (1ull << num_ties) - 1u; + const auto tie = tie_buffer[lane_id]; + const auto target_tie = tie_buffer[warp_id]; + const bool pred = is_greater(tie, target_tie); + const auto rank = static_cast(__popc(__ballot_sync(mask, pred))); + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target_tie.idx; + } + } else if (num_ties <= kWarpThreads * 2) { + // 64 x 64 topk implementation: each thread takes 2 elements + const auto lane_id_1 = lane_id + kWarpThreads; + const auto warp_id_1 = warp_id + kWarpThreads; + const auto invalid = Tie{.idx = 0xFFFFFFFF, .score = -FLT_MAX}; + const auto tie_0 = tie_buffer[lane_id]; + const auto tie_1 = lane_id_1 < num_ties ? 
tie_buffer[lane_id_1] : invalid; + if (true) { + const auto target = tie_buffer[warp_id]; + const bool pred_0 = is_greater(tie_0, target); + const bool pred_1 = is_greater(tie_1, target); + const auto rank_0 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_0))); + const auto rank_1 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_1))); + const auto rank = rank_0 + rank_1; + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target.idx; + } + } + if (warp_id_1 < num_ties) { + const auto target = tie_buffer[warp_id_1]; + const bool pred_0 = is_greater(tie_0, target); + const bool pred_1 = is_greater(tie_1, target); + const auto rank_0 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_0))); + const auto rank_1 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_1))); + const auto rank = rank_0 + rank_1; + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target.idx; + } + } + } else { + /// NOTE: Based on my observation, this path is very rarely reached + [[unlikely]]; + // Block-level: each thread reads from tie_buffer in shared memory + for (auto i = warp_id; i < num_ties; i += kNumWarps) { + const auto target_tie = tie_buffer[i]; + uint32_t local_rank = 0; + for (auto j = lane_id; j < num_ties; j += kWarpThreads) { + const auto tie = tie_buffer[j]; + if (is_greater(tie, target_tie)) local_rank++; + } + // sum the rank across the warp + const auto rank = warp::reduce_sum(local_rank); + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target_tie.idx; + } + } + } + } + + SGL_DEVICE static void transform(const TransformParams params) { + __syncthreads(); + if (const auto tx = threadIdx.x; tx < K) params.transform(tx); + } +}; + +} // namespace device::top512 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/streaming.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/streaming.cuh new file mode 100644 index 
0000000000..4462b89a19 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/streaming.cuh @@ -0,0 +1,213 @@ +#pragma once + +#include +#include +#include + +#include "common.cuh" +#include "ptx.cuh" +#include +#include + +namespace device::top512 { + +template +struct StreamingTopK { + static constexpr uint32_t kHistBits = 12; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kRadixBins = 256; + static constexpr uint32_t kElemPerStage = 8; + static constexpr uint32_t kSizePerStage = kElemPerStage * kBlockSize; + static constexpr uint32_t kNumStages = 2; // double buffer + + static constexpr uint32_t kHistItems = kHistBins / kBlockSize; // 4 + static_assert(kHistItems * kBlockSize == kHistBins); + using HistVec = AlignedVector; + + struct Smem { + uint64_t barrier[2][kNumStages]; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + alignas(128) MatchBin match; + alignas(128) uint32_t warp_sum[kNumWarps]; + union { + uint32_t histogram[kHistBins]; + HistVec histogram_vec[kBlockSize]; + Tie tie_buffer[kMaxTies]; + }; + union { + float score_buffer[kNumStages][kSizePerStage]; + TieHandleSmem stage2; // reuse smem for tie handling in phase D + }; + }; + + // --------------------------------------------------------------------------- + // Helpers + // --------------------------------------------------------------------------- + + /// NOTE: length must be 4-aligned since we load 4 floats/thread. Caller should round up. 
+ template + SGL_DEVICE static void issue_tma(const float* scores, uint32_t stage, uint32_t length, Smem* smem) { + const auto buf_idx = stage % kNumStages; + const auto offset = stage * kSizePerStage; + const auto size = min(kSizePerStage, length - offset); + const auto size_bytes = size * sizeof(float); + const auto bar = &smem->barrier[kIsScatter][buf_idx]; + ptx::tma_load(smem->score_buffer[buf_idx], scores + offset, size_bytes, bar); + ptx::mbarrier_arrive_expect_tx(bar, size_bytes); + } + + // --------------------------------------------------------------------------- + // Unified streaming pass. Used for both phase A (kIsScatter=false) and + // phase C (kIsScatter=true). Each buffer is reused across iterations via the + // reuse-arrive trick (same pattern as ClusterTopKImpl::stage1). + // --------------------------------------------------------------------------- + + template + SGL_DEVICE static void stream_pass( + const float* scores, + const uint32_t length, + const uint32_t thr_bin, // ignored when !kIsScatter + int32_t* s_topk_indices, // ignored when !kIsScatter + Smem* smem) { + const auto tx = threadIdx.x; + const auto num_iters = (length + kSizePerStage - 1) / kSizePerStage; + const auto lane_id = tx % kWarpThreads; + + // Initial double-buffer TMA prologue. 
+ const auto length_aligned = (length + 3u) & ~3u; + if (tx == 0) { +#pragma unroll + for (uint32_t i = 0; i < kNumStages; i++) { + if (i >= num_iters) break; + issue_tma(scores, i, length_aligned, smem); + } + } + + for (uint32_t iter = 0; iter < num_iters; iter++) { + const auto buf_idx = iter % kNumStages; + const auto offset = iter * kSizePerStage; + const auto this_size = min(kSizePerStage, length - offset); + + if (lane_id == 1) { + const auto phase_bit = (iter / kNumStages) & 1; + ptx::mbarrier_wait(&smem->barrier[kIsScatter][buf_idx], phase_bit); + } + __syncwarp(); + +#pragma unroll + for (uint32_t i = 0; i < kElemPerStage; i++) { + const auto local_idx = tx + i * kBlockSize; + if (local_idx >= this_size) break; + const auto score = smem->score_buffer[buf_idx][local_idx]; + const auto bin = extract_coarse_bin(score); + if constexpr (kIsScatter) { + const auto global_idx = offset + local_idx; + if (bin > thr_bin) { + const auto pos = atomicAdd(&smem->counter_gt, 1); + if (pos < K) s_topk_indices[pos] = global_idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (pos < kMaxTies) smem->tie_buffer[pos] = {global_idx, score}; + } + } else { + atomicAdd(&smem->histogram[bin], 1); + } + } + + __syncthreads(); + if (tx == 0) { + if (const auto next_iter = iter + kNumStages; next_iter < num_iters) { + issue_tma(scores, next_iter, length_aligned, smem); + } + } + } + } + + // --------------------------------------------------------------------------- + // Phase B: find the threshold bin via a warp-level prefix scan. + // Same structure as SmallTopKImpl's phase 2 (4 bins/thread, warp_sum relay). 
+ // --------------------------------------------------------------------------- + + SGL_DEVICE static void find_threshold(uint32_t length, Smem* smem) { + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + uint32_t orig[kHistItems]; + const auto hist_vec = smem->histogram_vec[tx]; + uint32_t local_sum = 0; +#pragma unroll + for (uint32_t i = 0; i < kHistItems; ++i) { + orig[i] = hist_vec[i]; + local_sum += orig[i]; + } + + const auto warp_inc = warp_inclusive_sum(lane_id, local_sum); + const auto warp_exc = warp_inc - local_sum; + if (lane_id == kWarpThreads - 1) smem->warp_sum[warp_id] = warp_inc; + __syncthreads(); + + const auto tmp = smem->warp_sum[lane_id]; + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + prefix_sum += warp_exc; +#pragma unroll + for (uint32_t i = 0; i < kHistItems; ++i) { + prefix_sum += orig[i]; + const auto above = length - prefix_sum; + if (above < K && above + orig[i] >= K) { + smem->match = { + .bin = tx * kHistItems + i, + .above_count = above, + .equal_count = orig[i], + }; + } + } + __syncthreads(); + } + + SGL_DEVICE static void run(const float* scores, const uint32_t length, int32_t* topk_indices, void* _smem) { + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + __builtin_assume(tx < kBlockSize); + + // Init histogram, barriers, counters. + { + HistVec zero; + zero.fill(0); + smem->histogram_vec[tx] = zero; + if (tx < 2 * kNumStages) { + const auto base_barrier = &smem->barrier[0][0]; + ptx::mbarrier_init(&base_barrier[tx], 1); + } + if (tx == 0) { + smem->counter_gt = 0; + smem->counter_eq = 0; + } + __syncthreads(); + } + + // Phase A: histogram pass (pipelined TMA stream). + stream_pass(scores, length, 0, nullptr, smem); + + // Phase B: locate threshold bin & re-init barriers + find_threshold(length, smem); + + // Phase C: scatter pass. 
+ stream_pass(scores, length, smem->match.bin, topk_indices, smem); + } + + SGL_DEVICE static void transform(const TransformParams params, void* _smem) { + // Phase D: page-translate above entries, then refine ties. + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto num_above = smem->match.above_count; + if (tx < num_above) params.transform(tx); + const auto num_equal = smem->counter_eq; + if (num_above >= K || num_equal == 0) return; + const auto clamped_ties = min(num_equal, kMaxTies); + tie_handle_transform(smem->tie_buffer, clamped_ties, num_above, K, params, &smem->stage2); + } +}; + +} // namespace device::top512 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/common.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/common.cuh new file mode 100644 index 0000000000..e0ce2dc086 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/common.cuh @@ -0,0 +1,120 @@ +#pragma once +#include + +namespace device::distributed { + +inline constexpr uint32_t kMaxNumGPU = 8; + +struct alignas(128) Semaphore { + public: + constexpr Semaphore() : m_flag(0), m_counter(0) {} + + template + SGL_DEVICE uint32_t get() const { + uint32_t val; + if constexpr (kFence) { + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(val) : "l"(&m_flag)); + } else { + asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(val) : "l"(&m_flag)); + } + return val; + } + + template + SGL_DEVICE uint32_t add(uint32_t val) { + uint32_t old_val; + if constexpr (kFence) { + asm volatile("atom.release.sys.global.add.u32 %0, [%1], %2;" : "=r"(old_val) : "l"(&m_flag), "r"(val)); + } else { + asm volatile("atom.global.add.u32 %0, [%1], %2;" : "=r"(old_val) : "l"(&m_flag), "r"(val)); + } + return old_val; + } + + // Only called by the owning GPU - plain load is sufficient + SGL_DEVICE uint32_t get_counter() const { + return m_counter; + } + + // Only called by the owning GPU - plain store is 
sufficient + SGL_DEVICE void set_counter(uint32_t val) { + m_counter = val; + } + + private: + uint32_t m_flag; + uint32_t m_counter; +}; + +struct PullController { + public: + using SignalType = Semaphore; + + PullController(void** signals, uint32_t num_gpu) { + for (uint32_t i = 0; i < num_gpu; ++i) { + m_signals[i] = static_cast(signals[i]); + } + } + + /// Synchronize all GPUs. + /// When kFence is true, establishes happens-before across GPUs using + /// release/acquire semantics, ensuring prior writes are visible system-wide. + template + SGL_DEVICE void sync(uint32_t rank, uint32_t num_gpu) const { + // For fenced sync: ensure all threads in this block have completed their writes, + // so the signaling thread's release carries them transitively. + static_assert(!(kFence && kStart), "Start stage does not need to wait fence"); + if constexpr (kFence || !kStart) __syncthreads(); + constexpr auto kStage = kStart ? 1 : 2; + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + if (lane_id == 0 && warp_id < num_gpu) { + auto& signal = m_signals[warp_id][blockIdx.x]; + signal.add(1); + if (warp_id == rank) { + const auto target = num_gpu * kStage; + /// NOTE: correctness here: + /// - base is only read/updated locally by the owning GPU + const auto base = signal.get_counter(); + while (signal.get() - base < target) + ; + if constexpr (!kStart) { + signal.set_counter(base + target); + } + } + } + if constexpr (kStart) __syncthreads(); + } + + private: + Semaphore* __restrict__ m_signals[kMaxNumGPU]; +}; + +struct PushController { + public: + using SignalType = uint32_t; + static constexpr int64_t kNumStages = 2; + + PushController(void* ptr) : m_local_signal(static_cast(ptr)) {} + + SGL_DEVICE SignalType epoch() const { + return m_local_signal[blockIdx.x]; + } + + SGL_DEVICE void exit() const { + __syncthreads(); + if (threadIdx.x == 0) { + this->exit_unsafe(blockIdx.x); + } + } + + SGL_DEVICE void exit_unsafe(uint32_t 
which) const { + auto& signal = m_local_signal[which]; + signal = (signal + 1) % kNumStages; + } + + private: + SignalType* m_local_signal; +}; + +} // namespace device::distributed diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/custom_all_reduce.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/custom_all_reduce.cuh new file mode 100644 index 0000000000..239fac71a1 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/custom_all_reduce.cuh @@ -0,0 +1,354 @@ +#pragma once +#include + +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace host::distributed { + +using device::distributed::PullController, device::distributed::PushController; + +struct AllReduceData { + constexpr AllReduceData() {} + void* __restrict__ input[device::distributed::kMaxNumGPU]; +}; + +using ExternHandle = tvm::ffi::Array; + +inline ExternHandle to_extern_handle(void* ptr) { + ExternHandle array; + cudaIpcMemHandle_t handle; + RuntimeDeviceCheck(cudaIpcGetMemHandle(&handle, ptr)); + for (size_t i = 0; i < sizeof(handle); ++i) { + array.push_back(handle.reserved[i]); + } + return array; +} + +inline void* from_extern_handle(const ExternHandle& array) { + cudaIpcMemHandle_t handle; + RuntimeCheck(array.size() == sizeof(handle), "Invalid IPC handle size: ", array.size()); + for (size_t i = 0; i < sizeof(handle); ++i) { + handle.reserved[i] = array[i]; + } + void* ptr; + RuntimeDeviceCheck(cudaIpcOpenMemHandle(&ptr, handle, cudaIpcMemLazyEnablePeerAccess)); + return ptr; +} + +struct HandleHash { + std::size_t operator()(const cudaIpcMemHandle_t& handle) const { + return std::hash{}({handle.reserved, sizeof(handle.reserved)}); + } +}; + +struct HandleEqual { + bool operator()(const cudaIpcMemHandle_t& a, const cudaIpcMemHandle_t& b) const { + return std::memcmp(a.reserved, b.reserved, sizeof(a.reserved)) 
== 0; + } +}; + +/** + * \brief The control plane of the custom all-reduce implementation. + * It manages the internal state and synchronization of the participating GPUs. + */ +struct CustomAllReduceBase : public tvm::ffi::Object { + public: + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("sgl.CustomAllReduce", CustomAllReduceBase, tvm::ffi::Object); + + static constexpr bool _type_mutable = true; + using InputPair = tvm::ffi::Tuple; // (offset, ipc handle) + + CustomAllReduceBase( + uint32_t rank, + uint32_t num_gpu, + uint32_t max_num_cta_pull, + uint32_t max_num_cta_push, + int64_t pull_buffer_size, + int64_t push_buffer_size, + int64_t graph_buffer_count) + : m_pull_buffer_bytes(pull_buffer_size), + m_push_buffer_bytes(push_buffer_size), + m_graph_buffer_count(graph_buffer_count), + m_rank(rank), + m_num_gpu(num_gpu), + m_max_num_cta_pull(max_num_cta_pull), + m_max_num_cta_push(max_num_cta_push), + // default config for pull kernel, can be updated by `configure()` + m_num_cta(max_num_cta_pull), + m_cta_size(256) { + RuntimeCheck(pull_buffer_size % 128 == 0, "Pull buffer size should be aligned to 128 bytes"); + RuntimeCheck(push_buffer_size % 128 == 0, "Push buffer size should be aligned to 128 bytes"); + RuntimeCheck(rank < num_gpu, "Invalid rank: ", rank); + const int64_t kU32Max = static_cast(std::numeric_limits::max()); + const int64_t push_buffer_size_all = push_all_ranks_bytes(); + RuntimeCheck(pull_buffer_size <= kU32Max, "Pull buffer size is too large: ", pull_buffer_size); + RuntimeCheck(push_buffer_size_all <= kU32Max, "Push buffer size is too large: ", push_buffer_size_all); + RuntimeDeviceCheck(cudaMalloc(&m_storage, storage_bytes())); + } + + ExternHandle share_storage() { + return to_extern_handle(m_storage); + } + + tvm::ffi::Array share_graph_inputs() { + tvm::ffi::Array result; + const auto new_inputs_count = registered_count() - m_cum_registered_count; + RuntimeCheck(new_inputs_count >= 0, "Invalid new count: ", new_inputs_count); + 
result.reserve(new_inputs_count); + std::unordered_map ipc_cache; + const auto get_handle = [&](void* ptr) -> ExternHandle { + const auto it = ipc_cache.find(ptr); + if (it != ipc_cache.end()) return it->second; + const auto handle = to_extern_handle(ptr); + ipc_cache.try_emplace(ptr, handle); + return handle; + }; + for (const auto ptr : std::span(m_graph_capture_inputs).subspan(m_cum_registered_count)) { + // note: must share the base address of each allocation, or we get wrong address + void* base_ptr; + const auto cu_result = cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr); + RuntimeCheck(cu_result == CUDA_SUCCESS, "failed to get pointer attr"); + const auto offset = reinterpret_cast(ptr) - reinterpret_cast(base_ptr); + result.push_back(InputPair{offset, get_handle(base_ptr)}); + } + return result; + } + + void post_init(tvm::ffi::Array ipc_storages) { + RuntimeCheck(ipc_storages.size() == m_num_gpu, "Invalid array size: ", ipc_storages.size()); + m_peer_storage.resize(m_num_gpu); + for (const auto i : irange(m_num_gpu)) { + if (i == m_rank) { + m_peer_storage[i] = m_storage; + } else { + m_peer_storage[i] = from_extern_handle(ipc_storages[i]); + } + } + + // set signal buffer to zero + const auto pull_signal = get_pull_signal(m_storage); + RuntimeDeviceCheck(cudaMemset(pull_signal, 0, pull_signal_bytes())); + + // update the pull controller and data pointer + RuntimeCheck(!m_pull_ctrl.has_value(), "Controller is already initialized"); + m_pull_ctrl.emplace(m_peer_storage.data(), m_num_gpu); + AllReduceData data; + for (const auto i : irange(m_num_gpu)) { + data.input[i] = get_pull_buffer(m_peer_storage[i]); + } + const auto default_data_ptr = get_data_ptr(); + RuntimeDeviceCheck(cudaMemcpy(default_data_ptr, &data, sizeof(AllReduceData), cudaMemcpyHostToDevice)); + + // update the push controller and data pointer + RuntimeCheck(!m_push_ctrl.has_value(), "Controller is already initialized"); + const auto push_signal = 
get_push_signal(m_storage); + RuntimeDeviceCheck(cudaMemset(push_signal, 0, push_signal_bytes())); + m_push_ctrl.emplace(push_signal); + const auto push_buffer = get_push_buffer(m_storage); + RuntimeDeviceCheck(cudaMemset(push_buffer, 0, push_all_ranks_bytes())); + } + + void register_inputs(tvm::ffi::Array> ipc_graph_inputs) { + RuntimeCheck(ipc_graph_inputs.size() == m_num_gpu); + const auto new_registered_count = registered_count() - m_cum_registered_count; + RuntimeCheck(new_registered_count >= 0, "Invalid registered count: ", new_registered_count); + if (new_registered_count == 0) return; // avoid `m_get_data_ptr()` out-of-bounds + std::vector data; + data.resize(new_registered_count); + const auto open_cached = [&](const ExternHandle& h) -> void* { + RuntimeCheck(h.size() == sizeof(cudaIpcMemHandle_t), "Invalid IPC handle size: ", h.size()); + cudaIpcMemHandle_t handle; + for (size_t i = 0; i < sizeof(handle); ++i) + handle.reserved[i] = h[i]; + const auto [it, success] = m_ipc_cache.try_emplace(handle, nullptr); + if (success) { + void* ptr; + RuntimeDeviceCheck(cudaIpcOpenMemHandle(&ptr, handle, cudaIpcMemLazyEnablePeerAccess)); + it->second = ptr; + } + return it->second; + }; + for (const auto i : irange(ipc_graph_inputs.size())) { + const auto& array = ipc_graph_inputs[i]; + RuntimeCheck(int64_t(array.size()) == new_registered_count); + if (i == m_rank) { + for (const auto j : irange(new_registered_count)) { + data[j].input[i] = m_graph_capture_inputs[m_cum_registered_count + j]; + } + } else { + for (const auto j : irange(new_registered_count)) { + /// NOTE: structural binding will cause intern compiler error... 
+ const auto elem = array[j]; + const auto offset = elem.get<0>(); + const auto ipc_handle = elem.get<1>(); + data[j].input[i] = pointer::offset(open_cached(ipc_handle), offset); + } + } + } + + const auto new_registered_bytes = sizeof(AllReduceData) * new_registered_count; + const auto dst_ptr = get_data_ptr(m_cum_registered_count); + m_cum_registered_count += new_registered_count; + RuntimeDeviceCheck(cudaMemcpy(dst_ptr, data.data(), new_registered_bytes, cudaMemcpyHostToDevice)); + } + + void set_cuda_graph_capture(bool enabled) { + m_is_graph_capturing = enabled; + } + + void free_ipc_handles() { + for (const auto& pair : m_ipc_cache) { + host::RuntimeDeviceCheck(cudaIpcCloseMemHandle(pair.second)); + } + m_ipc_cache.clear(); + } + + void free_storage() { + host::RuntimeDeviceCheck(cudaFree(m_storage)); + m_storage = nullptr; + } + + tvm::ffi::Tuple configure_pull(uint32_t num_cta, uint32_t cta_size) { + using host::RuntimeCheck; + const auto min_cta_size = m_num_gpu * device::kWarpThreads; + RuntimeCheck(num_cta > 0 && num_cta <= m_max_num_cta_pull, "Invalid number of CTAs: ", num_cta); + RuntimeCheck(cta_size >= min_cta_size, "Block size must be at least ", min_cta_size); + const auto old_num_cta = m_num_cta; + const auto old_block_size = m_cta_size; + m_num_cta = num_cta; + m_cta_size = cta_size; + return tvm::ffi::Tuple{old_num_cta, old_block_size}; + } + + protected: + AllReduceData* allocate_graph_capture_input(void* data_ptr) { + const auto count = registered_count(); + RuntimeCheck(count < m_graph_buffer_count, "Graph buffer overflow, increase `graph_buffer_count`!"); + m_graph_capture_inputs.push_back(data_ptr); + return get_data_ptr(count); + } + AllReduceData* get_data_ptr(int64_t which = -1) { + const auto count = registered_count(); + RuntimeCheck(which >= -1 && which < count, "Invalid graph buffer index: ", which, ", count: ", count); + const auto start = get_pull_params(m_storage); + return static_cast(start) + (1 + which); + } + int64_t 
registered_count() const { + return static_cast(m_graph_capture_inputs.size()); + } + int64_t pull_signal_bytes() const { + return _align_bytes(sizeof(PullController::SignalType) * m_max_num_cta_pull); + } + int64_t push_signal_bytes() const { + return _align_bytes(sizeof(PushController::SignalType) * m_max_num_cta_push); + } + int64_t graph_param_bytes() const { + return _align_bytes(sizeof(AllReduceData) * (1 + m_graph_buffer_count)); // 1 for default + } + int64_t push_all_ranks_bytes() const { + return _align_bytes(PushController::kNumStages * m_num_gpu * m_push_buffer_bytes); + } + int64_t storage_bytes() const { + return _get_offset_impl(5); + } + void* get_pull_signal(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(0)); + } + void* get_push_signal(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(1)); + } + void* get_pull_params(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(2)); + } + void* get_pull_buffer(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(3)); + } + void* get_push_buffer(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(4)); + } + int64_t _get_offset_impl(int64_t which) const { + // | SignalArray (pull + push) | GraphBuffers (pull params) | Buffers (pull + push) | + const int64_t offset_map[5] = { + /*[0]=*/pull_signal_bytes(), + /*[1]=*/push_signal_bytes(), + /*[2]=*/graph_param_bytes(), + /*[3]=*/m_pull_buffer_bytes, + /*[4]=*/push_all_ranks_bytes(), + }; + RuntimeCheck(which >= 0 && which <= 5, "Invalid offset index: ", which); + return std::accumulate(offset_map, offset_map + which, int64_t(0)); + } + static int64_t _align_bytes(int64_t size) { + return div_ceil(size, 128) * 128; + } + + const int64_t m_pull_buffer_bytes; + const int64_t m_push_buffer_bytes; + const int64_t m_graph_buffer_count; + const uint32_t m_rank; + const uint32_t m_num_gpu; + const uint32_t m_max_num_cta_pull; + const uint32_t m_max_num_cta_push; + // these 2 config should only 
affect pull kernel + uint32_t m_num_cta; + uint32_t m_cta_size; + // other states + bool m_is_graph_capturing = false; + int64_t m_cum_registered_count = 0; + std::optional m_pull_ctrl; + std::optional m_push_ctrl; + void* m_storage = nullptr; + std::vector m_graph_capture_inputs; + std::vector m_peer_storage; + std::unordered_map m_ipc_cache; +}; + +struct CustomAllReduceRef : public tvm::ffi::ObjectRef { + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(CustomAllReduceRef, tvm::ffi::ObjectRef, CustomAllReduceBase); +}; + +} // namespace host::distributed + +namespace device::distributed { + +template +SGL_DEVICE auto reduce_impl(AlignedVector (&storage)[M]) -> AlignedVector { + fp32x2_t acc[N] = {}; +#pragma unroll // unroll num gpu + for (uint32_t i = 0; i < M; ++i) { +#pragma unroll // unroll vec + for (uint32_t j = 0; j < N; ++j) { + const auto [x, y] = cast(storage[i][j]); + auto& [x_acc, y_acc] = acc[j]; + x_acc += x; + y_acc += y; + } + } + + AlignedVector result; +#pragma unroll + for (uint32_t j = 0; j < N; ++j) { + result[j] = cast(acc[j]); + } + + return result; +} + +} // namespace device::distributed diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/ffi.h b/lightllm/third_party/sglang_jit/include/sgl_kernel/ffi.h new file mode 100644 index 0000000000..17d9048d4c --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/ffi.h @@ -0,0 +1,104 @@ +#pragma once +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace host::ffi { + +using tvm::ffi::Tensor, tvm::ffi::TensorView, tvm::ffi::ShapeView; + +inline Tensor empty(ShapeView shape, DLDataType dtype, DLDevice device) { + return Tensor::FromEnvAlloc(::TVMFFIEnvTensorAlloc, shape, dtype, device); +} + +inline Tensor empty_like(TensorView tensor) { + return empty(tensor.shape(), tensor.dtype(), tensor.device()); +} + +struct _dummy_deleter { + void operator()(void*) const {} +}; + +// template + +template +struct 
FromBlobContext { + [[no_unique_address]] Fn deleter; + int64_t dimension; + int64_t* get_shape() { + return reinterpret_cast(this + 1); + } + int64_t* get_stride() { + return this->get_shape() + dimension; + } +}; + +template +inline Tensor from_blob( + void* data, + ShapeView shape, + DLDataType dtype, + DLDevice device, + Fn&& deleter = {}, + std::optional stride = {}, + uint64_t byte_offset = 0) { + using Context = FromBlobContext>; + const auto ndim = shape.size(); + const auto ctx = [&] { + auto ptr = std::malloc(sizeof(Context) + sizeof(int64_t) * ndim * 2); + auto ctx = static_cast(ptr); + std::construct_at(ctx, std::forward(deleter), static_cast(ndim)); + stdr::copy_n(shape.data(), ndim, ctx->get_shape()); + if (stride.has_value()) { + RuntimeCheck(stride->size() == ndim, "Stride ndim mismatch!"); + stdr::copy_n(stride->data(), ndim, ctx->get_stride()); + } else { + int64_t stride_val = 1; + for (const auto i : irange(ndim)) { + const auto j = ndim - 1 - i; + ctx->get_stride()[j] = stride_val; + stride_val *= shape[j]; + } + } + return ctx; + }(); + const auto tensor = DLTensor{ + .data = data, + .device = device, + .ndim = static_cast(ndim), + .dtype = dtype, + .shape = ctx->get_shape(), + .strides = ctx->get_stride(), + .byte_offset = byte_offset, + }; + const auto blob_deleter = [](DLManagedTensor* self) { + auto ctx = static_cast(self->manager_ctx); + ctx->deleter(self->dl_tensor.data); + std::destroy_at(ctx); + std::free(ctx); + }; + auto managed_tensor = DLManagedTensor{tensor, ctx, blob_deleter}; + return Tensor::FromDLPack(&managed_tensor); +} + +template +inline Tensor from_blob_like( + void* data, + TensorView t, + Fn&& deleter = {}, + bool is_contiguous = false, // if override to true, the stride will be ignored + uint64_t byte_offset = 0) { + const auto stride = is_contiguous ? 
std::nullopt : std::optional{t.strides()}; + return from_blob(data, t.shape(), t.dtype(), t.device(), std::forward(deleter), stride, byte_offset); +} + +} // namespace host::ffi diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/impl/norm.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/impl/norm.cuh new file mode 100644 index 0000000000..cd024acd46 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/impl/norm.cuh @@ -0,0 +1,168 @@ +#pragma once +#include +#include +#include +#include +#include + +#include +#include + +namespace host::norm { + +/** + * \brief Check if the given configuration is supported. + * \tparam T Element type (only fp16_t/bf16_t is supported) + * \tparam kDim Dimension size (usually hidden size) + */ +template +inline constexpr bool is_config_supported() { + if (!std::is_same_v && !std::is_same_v) return false; + if (kDim <= 256) { + return (kDim == 64 || kDim == 128 || kDim == 256); + } else { + return (kDim % 256 == 0 && kDim <= 8192); + } +} + +/** + * \brief Determine whether to use cta norm based on dimension size. + * TL;DR: use warp norm for dim <= 256, cta norm otherwise. + * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size (usually hidden size) + * \note This function assumes that the configuration is supported. + * \see `is_config_supported` + */ +template +inline constexpr bool should_use_cta() { + static_assert(is_config_supported(), "Unsupported norm configuration"); + return kDim > 256; +} + +/** + * \brief Get the number of threads per CTA for cta norm. 
+ * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size (usually hidden size) + * \return Number of threads per CTA + */ +template +inline constexpr uint32_t get_cta_threads() { + static_assert(should_use_cta()); + return (kDim / 256) * device::kWarpThreads; +} + +} // namespace host::norm + +namespace device::norm { + +namespace details { + +template +SGL_DEVICE AlignedVector apply_norm_impl( + const AlignedVector input, + const AlignedVector weight, + const float eps, + [[maybe_unused]] float* smem_buffer, + [[maybe_unused]] uint32_t num_warps) { + float sum_of_squares = 0.0f; + +#pragma unroll + for (auto i = 0u; i < N; ++i) { + const auto fp32_input = cast(input[i]); + sum_of_squares += fp32_input.x * fp32_input.x; + sum_of_squares += fp32_input.y * fp32_input.y; + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + float norm_factor; + if constexpr (kUseCTA) { + // need to synchronize across the cta + const auto warp_id = threadIdx.x / kWarpThreads; + smem_buffer[warp_id] = sum_of_squares; + __syncthreads(); + // use the first warp to reduce + if (warp_id == 0) { + const auto tx = threadIdx.x; + const auto local_sum = tx < num_warps ? smem_buffer[tx] : 0.0f; + sum_of_squares = warp::reduce_sum(local_sum); + smem_buffer[32] = math::rsqrt(sum_of_squares / kDim + eps); + } + __syncthreads(); + norm_factor = smem_buffer[32]; + } else { + norm_factor = math::rsqrt(sum_of_squares / kDim + eps); + } + + AlignedVector output; + +#pragma unroll + for (auto i = 0u; i < N; ++i) { + const auto fp32_input = cast(input[i]); + const auto fp32_weight = cast(weight[i]); + output[i] = cast({ + fp32_input.x * norm_factor * fp32_weight.x, + fp32_input.y * norm_factor * fp32_weight.y, + }); + } + + return output; +} + +} // namespace details + +/** + * \brief Apply norm using warp-level implementation. 
+ * \tparam kDim Dimension size + * \tparam T Element type (fp16_t or bf16_t) + * \param input Input vector + * \param weight Weight vector + * \param eps Epsilon value for numerical stability + * \return Normalized output vector + */ +template +SGL_DEVICE T apply_norm_warp(const T& input, const T& weight, float eps) { + static_assert(kDim <= 256, "Warp norm only supports dim <= 256"); + return details::apply_norm_impl(input, weight, eps, nullptr, 0); +} + +/** + * \brief Apply norm using CTA-level implementation. + * \tparam kDim Dimension size + * \tparam T Element type (fp16_t or bf16_t) + * \param input Input vector + * \param weight Weight vector + * \param eps Epsilon value for numerical stability + * \param smem Shared memory buffer + * \param num_warps Number of warps in the CTA + * \return Normalized output vector + */ +template +SGL_DEVICE T apply_norm_cta( + const T& input, const T& weight, float eps, float* smem, uint32_t num_warps = blockDim.x / kWarpThreads) { + static_assert(kDim > 256, "CTA norm only supports dim > 256"); + return details::apply_norm_impl(input, weight, eps, smem, num_warps); +} + +/** + * \brief Storage type for norm operation. + * For warp norm, the storage size depends on kDim. + * For cta norm, the storage size is fixed to 16B. + * We will also pack the input 16-bit floats into 32-bit types + * for faster CUDA core operations. + * + * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size + */ +template +using StorageType = std::conditional_t< // storage type + (kDim > 256), // whether to use cta norm + AlignedVector, 4>, // cta norm storage, fixed to 16B + AlignedVector, kDim / (2 * kWarpThreads)> // warp norm storage + >; + +/** + * \brief Minimum shared memory size (in `float` elements, not bytes) required for cta norm.
+ */ +inline constexpr uint32_t kSmemBufferSize = 33; + +} // namespace device::norm diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/math.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/math.cuh new file mode 100644 index 0000000000..4f9ac48141 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/math.cuh @@ -0,0 +1,71 @@ +/// \file math.cuh +/// \brief Device-side math helper functions and constants. +/// +/// Provides type-generic wrappers around CUDA math intrinsics by +/// dispatching through `dtype_trait`. All functions are forced-inline +/// device functions. + +#pragma once +#include + +#include + +namespace device::math { + +/// \brief Constant: log2(e) +inline constexpr float log2e = 1.44269504088896340736f; +/// \brief Constant: ln(2) +inline constexpr float loge2 = 0.693147180559945309417f; +/// \brief Maximum representable value for FP8 E4M3 format. +inline constexpr float FP8_E4M3_MAX = 448.0f; +static_assert(log2e * loge2 == 1.0f, "log2e * loge2 must be 1"); + +/// \brief Returns the larger of `a` and `b`. +template +SGL_DEVICE T max(T a, T b) { + return dtype_trait::max(a, b); +} + +/// \brief Returns the smaller of `a` and `b`. +template +SGL_DEVICE T min(T a, T b) { + return dtype_trait::min(a, b); +} + +/// \brief Returns the absolute value of `a`. +template +SGL_DEVICE T abs(T a) { + return dtype_trait::abs(a); +} + +/// \brief Returns the square root of `a`. +template +SGL_DEVICE T sqrt(T a) { + return dtype_trait::sqrt(a); +} + +/// \brief Returns the reciprocal square root of `a` (i.e. 1 / sqrt(a)). +template +SGL_DEVICE T rsqrt(T a) { + return dtype_trait::rsqrt(a); +} + +/// \brief Returns e^a. +template +SGL_DEVICE T exp(T a) { + return dtype_trait::exp(a); +} + +/// \brief Returns sin(a). +template +SGL_DEVICE T sin(T a) { + return dtype_trait::sin(a); +} + +/// \brief Returns cos(a). 
+template +SGL_DEVICE T cos(T a) { + return dtype_trait::cos(a); +} + +} // namespace device::math diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/runtime.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/runtime.cuh new file mode 100644 index 0000000000..4ea722a3fe --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/runtime.cuh @@ -0,0 +1,86 @@ +/// \file runtime.cuh +/// \brief Host-side CUDA runtime query helpers. +/// +/// Thin wrappers around CUDA occupancy and device-property APIs with +/// automatic error checking via `RuntimeDeviceCheck`. + +#pragma once + +#include + +#include +#include +#ifndef USE_ROCM +#include +#else +#include +#ifndef cudaOccupancyMaxActiveBlocksPerMultiprocessor +#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor +#endif +#ifndef cudaDeviceGetAttribute +#define cudaDeviceGetAttribute hipDeviceGetAttribute +#endif +#ifndef cudaDevAttrMultiProcessorCount +#define cudaDevAttrMultiProcessorCount hipDeviceAttributeMultiprocessorCount +#endif +#ifndef cudaDevAttrComputeCapabilityMajor +#define cudaDevAttrComputeCapabilityMajor hipDeviceAttributeComputeCapabilityMajor +#endif +#ifndef cudaRuntimeGetVersion +#define cudaRuntimeGetVersion hipRuntimeGetVersion +#endif +#ifndef cudaOccupancyAvailableDynamicSMemPerBlock +inline hipError_t +cudaOccupancyAvailableDynamicSMemPerBlock(std::size_t* smem, const void* func, int num_blocks, int block_size) { + // HIP does not expose this directly; return max shared mem as conservative estimate + hipDeviceProp_t prop; + int device; + hipGetDevice(&device); + hipGetDeviceProperties(&prop, device); + *smem = prop.sharedMemPerBlock; + return hipSuccess; +} +#endif +#endif + +namespace host::runtime { + +// Return the maximum number of active blocks per SM for the given kernel +template +inline auto get_blocks_per_sm(T&& kernel, int32_t block_dim, std::size_t dynamic_smem = 0) -> uint32_t { + int num_blocks_per_sm = 0; 
+ RuntimeDeviceCheck( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, block_dim, dynamic_smem)); + return static_cast(num_blocks_per_sm); +} + +// Return the number of SMs for the given device +inline auto get_sm_count(int device_id) -> uint32_t { + int sm_count; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id)); + return static_cast(sm_count); +} + +// Return the Major compute capability for the given device +inline auto get_cc_major(int device_id) -> int { + int cc_major; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device_id)); + return cc_major; +} + +// Return the runtime version +inline auto get_runtime_version() -> int { + int runtime_version; + RuntimeDeviceCheck(cudaRuntimeGetVersion(&runtime_version)); + return runtime_version; +} + +// Return the maximum dynamic shared memory per block for the given kernel +template +inline auto get_available_dynamic_smem_per_block(T&& kernel, int num_blocks, int block_size) -> std::size_t { + std::size_t smem_size; + RuntimeDeviceCheck(cudaOccupancyAvailableDynamicSMemPerBlock(&smem_size, kernel, num_blocks, block_size)); + return smem_size; +} + +} // namespace host::runtime diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/scalar_type.hpp b/lightllm/third_party/sglang_jit/include/sgl_kernel/scalar_type.hpp new file mode 100644 index 0000000000..d229d3a975 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/scalar_type.hpp @@ -0,0 +1,334 @@ +#pragma once + +#include +#include +#ifndef __CUDACC__ +#include +#endif + +namespace host { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). 
+// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. +// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType( + uint8_t exponent, + uint8_t mantissa, + bool signed_, + int32_t bias, + bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { + assert(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { + assert(nan_repr < NAN_REPR_ID_MAX); + assert(mantissa > 0 && exponent > 0); + assert(nan_repr != NAN_IEEE_754); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. 
has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { return acc + member_id_field_width(); }, 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits}; + }; + return reduce_members(or_and_advance, 
std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { + return signed_; + } + constexpr bool is_integer() const { + return exponent == 0; + } + constexpr bool is_floating_point() const { + return exponent > 0; + } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { + return bias != 0; + } + +#ifndef __CUDACC__ + private: + double _floating_point_max() const { + assert(mantissa <= 52 && exponent <= 11); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + assert(exponent < 11); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume 
the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + assert(size_bits() < 64 || (size_bits() == 64 && is_signed())); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + assert(is_signed()); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + assert(!is_signed() || size_bits() <= 64); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); + } +#endif // __CUDACC__ + + public: + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int<size_bits>[b<bias>]` + * - if bias is not present it means it's zero + */ + if (is_floating_point()) { + auto ret = + "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE2M1f =
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +} // namespace host diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/source_location.h b/lightllm/third_party/sglang_jit/include/sgl_kernel/source_location.h new file mode 100644 index 0000000000..7c9fd52131 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/source_location.h @@ -0,0 +1,40 @@ +/// \file source_location.h +/// \brief Portable 
`source_location` wrapper. +/// +/// Uses `std::source_location` when available (C++20), otherwise falls +/// back to a minimal stub that returns empty/zero values. + +#pragma once +#include + +/// NOTE: fallback to a minimal source_location implementation +#if defined(__cpp_lib_source_location) +#include + +using source_location_t = std::source_location; + +#else + +struct source_location_fallback { + public: + static constexpr source_location_fallback current() noexcept { + return source_location_fallback{}; + } + constexpr source_location_fallback() noexcept = default; + constexpr unsigned line() const noexcept { + return 0; + } + constexpr unsigned column() const noexcept { + return 0; + } + constexpr const char* file_name() const noexcept { + return ""; + } + constexpr const char* function_name() const noexcept { + return ""; + } +}; + +using source_location_t = source_location_fallback; + +#endif diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/tensor.h b/lightllm/third_party/sglang_jit/include/sgl_kernel/tensor.h new file mode 100644 index 0000000000..1ae9233a61 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/tensor.h @@ -0,0 +1,605 @@ +/// \file tensor.h +/// \brief Tensor validation and symbolic matching utilities. +/// +/// Provides the `TensorMatcher` fluent API for validating tensor shapes, +/// strides, dtypes, and devices at kernel entry points, along with +/// `SymbolicSize`, `SymbolicDType`, and `SymbolicDevice` for capturing +/// and cross-checking tensor metadata across multiple tensors. +/// +/// See the "Tensor Checking" section in the JIT kernel dev guide for +/// usage examples. 
+ +#pragma once +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include +#elif defined(__HIPCC__) +#include +#endif + +namespace host { + +namespace details { + +inline constexpr auto kAnyDeviceID = -1; +inline constexpr auto kAnySize = static_cast(-1); +inline constexpr auto kNullSize = static_cast(-1); +inline constexpr auto kNullDType = static_cast(18u); +inline constexpr auto kNullDevice = static_cast(-1); + +struct SizeRef; +struct DTypeRef; +struct DeviceRef; + +template +struct _dtype_trait {}; + +template +struct _dtype_trait { + inline static constexpr DLDataType value = { + .code = std::is_signed_v ? DLDataTypeCode::kDLInt : DLDataTypeCode::kDLUInt, + .bits = static_cast(sizeof(T) * 8), + .lanes = 1}; +}; + +template +struct _dtype_trait { + inline static constexpr DLDataType value = { + .code = DLDataTypeCode::kDLFloat, .bits = static_cast(sizeof(T) * 8), .lanes = 1}; +}; + +#ifdef __CUDACC__ +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat, .bits = 16, .lanes = 1}; +}; +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLBfloat, .bits = 16, .lanes = 1}; +}; +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat8_e4m3fn, .bits = 8, .lanes = 1}; +}; +#elif defined(__HIPCC__) +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat, .bits = 16, .lanes = 1}; +}; +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLBfloat, .bits = 16, .lanes = 1}; +}; +#endif + +template +struct _device_trait { + inline static constexpr DLDevice value = {.device_type = Code, .device_id = kAnyDeviceID}; +}; + +template +inline 
constexpr auto kDTypeList = std::array{_dtype_trait::value...}; + +template +inline constexpr auto kDeviceList = std::array{_device_trait::value...}; + +template +struct PrintAbleSpan { + explicit PrintAbleSpan(std::span data) : data(data) {} + std::span data; +}; + +// define DLDataType comparison and printing in root namespace +inline constexpr auto kDeviceStringMap = [] { + constexpr auto map = std::array, 16>{ + std::pair{DLDeviceType::kDLCPU, "cpu"}, + std::pair{DLDeviceType::kDLCUDA, "cuda"}, + std::pair{DLDeviceType::kDLCUDAHost, "cuda_host"}, + std::pair{DLDeviceType::kDLOpenCL, "opencl"}, + std::pair{DLDeviceType::kDLVulkan, "vulkan"}, + std::pair{DLDeviceType::kDLMetal, "metal"}, + std::pair{DLDeviceType::kDLVPI, "vpi"}, + std::pair{DLDeviceType::kDLROCM, "rocm"}, + std::pair{DLDeviceType::kDLROCMHost, "rocm_host"}, + std::pair{DLDeviceType::kDLExtDev, "ext_dev"}, + std::pair{DLDeviceType::kDLCUDAManaged, "cuda_managed"}, + std::pair{DLDeviceType::kDLOneAPI, "oneapi"}, + std::pair{DLDeviceType::kDLWebGPU, "webgpu"}, + std::pair{DLDeviceType::kDLHexagon, "hexagon"}, + std::pair{DLDeviceType::kDLMAIA, "maia"}, + std::pair{DLDeviceType::kDLTrn, "trn"}, + }; + constexpr auto max_type = stdr::max(map | stdv::keys); + auto result = std::array{}; + for (const auto& [code, name] : map) { + result[static_cast(code)] = name; + } + return result; +}(); + +struct PrintableDevice { + DLDevice device; +}; + +inline auto& operator<<(std::ostream& os, DLDevice device) { + const auto& mapping = kDeviceStringMap; + const auto entry = static_cast(device.device_type); + RuntimeCheck(entry < mapping.size()); + const auto name = mapping[entry]; + RuntimeCheck(!name.empty(), "Unknown device: ", int(device.device_type)); + os << name; + if (device.device_id != kAnyDeviceID && device.device_type != DLDeviceType::kDLCPU) { + os << ":" << device.device_id; + } + return os; +} + +inline auto& operator<<(std::ostream& os, PrintableDevice pd) { + return os << pd.device; +} + +template 
+inline auto& operator<<(std::ostream& os, PrintAbleSpan span) { + os << "["; + for (const auto i : irange(span.data.size())) { + if (i > 0) { + os << ", "; + } + os << span.data[i]; + } + os << "]"; + return os; +} + +} // namespace details + +/// \brief Check whether `dtype` matches the DLDataType for C++ type `T`. +template +inline bool is_type(DLDataType dtype) { + return dtype == details::_dtype_trait::value; +} + +/** + * \brief A symbolic dimension size that can be bound once and + * verified across multiple tensors. + * + * Create with an optional annotation string for error messages: + * \code + * auto N = SymbolicSize{"num_tokens"}; + * \endcode + * + * Call `verify()` during tensor matching to either bind the first + * observed value or check subsequent values match. Call `unwrap()` + * to retrieve the bound value (panics if unset). + */ +struct SymbolicSize { + public: + SymbolicSize(std::string_view annotation = {}) : m_value(details::kNullSize), m_annotation(annotation) {} + SymbolicSize(const SymbolicSize&) = delete; + SymbolicSize& operator=(const SymbolicSize&) = delete; + + auto get_name() const -> std::string_view { + return m_annotation; + } + + auto set_value(int64_t value) -> void { + RuntimeCheck(!this->has_value(), "Size value already set"); + m_value = value; + } + + auto has_value() const -> bool { + return m_value != details::kNullSize; + } + + auto get_value() const -> std::optional { + return this->has_value() ? 
std::optional{m_value} : std::nullopt; + } + + auto unwrap(DebugInfo info = {}) const -> int64_t { + RuntimeCheck(info, this->has_value(), "Size value is not set"); + return m_value; + } + + auto verify(int64_t value, const char* prefix, int64_t dim) -> void { + if (this->has_value()) { + if (m_value != value) { + [[unlikely]]; + Panic("Size mismatch for ", m_name_str(prefix, dim), ": expected ", m_value, " but got ", value); + } + } else { + this->set_value(value); + } + } + + auto value_or_name(const char* prefix, int64_t dim) const -> std::string { + if (const auto value = this->get_value()) { + return std::to_string(*value); + } else { + return m_name_str(prefix, dim); + } + } + + private: + auto m_name_str(const char* prefix, int64_t dim) const -> std::string { + std::ostringstream os; + os << prefix << '#' << dim; + if (!m_annotation.empty()) os << "('" << m_annotation << "')"; + return std::move(os).str(); + } + + std::int64_t m_value; + std::string_view m_annotation; +}; + +inline auto operator==(DLDevice lhs, DLDevice rhs) -> bool { + return lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id; +} + +/** + * \brief A symbolic data type that can be constrained and verified. + * + * Optionally restrict allowed types via `set_options()`. + * Use `verify()` to bind/check the dtype, and `unwrap()` to retrieve it. + */ +struct SymbolicDType { + public: + SymbolicDType() : m_value({details::kNullDType, 0, 0}) {} + SymbolicDType(const SymbolicDType&) = delete; + SymbolicDType& operator=(const SymbolicDType&) = delete; + + auto set_value(DLDataType value) -> void { + RuntimeCheck(!this->has_value(), "Dtype value already set"); + RuntimeCheck( + m_check(value), "Dtype value [", value, "] not in the allowed options: ", details::PrintAbleSpan{m_options}); + m_value = value; + } + + auto has_value() const -> bool { + return m_value.code != details::kNullDType; + } + + auto get_value() const -> std::optional { + return this->has_value() ? 
std::optional{m_value} : std::nullopt; + } + + auto unwrap(DebugInfo info = {}) const -> DLDataType { + RuntimeCheck(info, this->has_value(), "Dtype value is not set"); + return m_value; + } + + auto set_options(std::span options) -> void { + m_options = options; + } + + template + auto set_options() -> void { + m_options = details::kDTypeList; + } + + auto verify(DLDataType dtype) -> void { + if (this->has_value()) { + RuntimeCheck(m_value == dtype, "DType mismatch: expected ", m_value, " but got ", dtype); + } else { + this->set_value(dtype); + } + } + + template + auto is_type() const -> bool { + return ::host::is_type(m_value); + } + + private: + auto m_check(DLDataType value) const -> bool { + return stdr::empty(m_options) || (stdr::find(m_options, value) != stdr::end(m_options)); + } + + std::span m_options; + DLDataType m_value; +}; + +/** + * \brief A symbolic device that can be constrained and verified. + * + * Optionally restrict allowed device types via + * `set_options()`. The device id can be wildcarded. + */ +struct SymbolicDevice { + public: + SymbolicDevice() : m_value({details::kNullDevice, details::kAnyDeviceID}) {} + SymbolicDevice(const SymbolicDevice&) = delete; + SymbolicDevice& operator=(const SymbolicDevice&) = delete; + + auto set_value(DLDevice value) -> void { + RuntimeCheck(!this->has_value(), "Device value already set"); + RuntimeCheck( + m_check(value), + "Device value [", + details::PrintableDevice{value}, + "] not in the allowed options: ", + details::PrintAbleSpan{m_options}); + m_value = value; + } + + auto has_value() const -> bool { + return m_value.device_type != details::kNullDevice; + } + + auto get_value() const -> std::optional { + return this->has_value() ? 
std::optional{m_value} : std::nullopt; + } + + auto unwrap(DebugInfo info = {}) const -> DLDevice { + RuntimeCheck(info, this->has_value(), "Device value is not set"); + return m_value; + } + + auto set_options(std::span options) -> void { + m_options = options; + } + + template + auto set_options() -> void { + m_options = details::kDeviceList; + } + + auto verify(DLDevice device) -> void { + if (this->has_value()) { + RuntimeCheck( + m_value == device, + "Device mismatch: expected ", + details::PrintableDevice{m_value}, + " but got ", + details::PrintableDevice{device}); + } else { + this->set_value(device); + } + } + + private: + auto m_check(DLDevice value) const -> bool { + return stdr::empty(m_options) || (stdr::any_of(m_options, [value](const DLDevice& opt) { + // device type must exactly match + if (opt.device_type != value.device_type) return false; + // device id can be wildcarded + return opt.device_id == details::kAnyDeviceID || opt.device_id == value.device_id; + })); + } + + std::span m_options; + DLDevice m_value; +}; + +namespace details { + +template +struct BaseRef { + public: + BaseRef(const BaseRef&) = delete; + BaseRef& operator=(const BaseRef&) = delete; + + auto operator->() const -> T* { + return m_ref; + } + auto operator*() const -> T& { + return *m_ref; + } + auto rebind(T& other) -> void { + m_ref = &other; + } + + explicit BaseRef() : m_ref(&m_cache), m_cache() {} + BaseRef(T& size) : m_ref(&size), m_cache() {} + + private: + T* m_ref; + T m_cache; +}; + +struct SizeRef : BaseRef { + using BaseRef::BaseRef; + SizeRef(int64_t value) { + if (value != kAnySize) { + (**this).set_value(value); + } else { + // otherwise, we can match any size + } + } +}; + +struct DTypeRef : BaseRef { + using BaseRef::BaseRef; + DTypeRef(DLDataType options) { + (**this).set_value(options); + } + DTypeRef(std::initializer_list options) { + (**this).set_options(options); + } + DTypeRef(std::span options) { + (**this).set_options(options); + } +}; + +struct 
DeviceRef : BaseRef { + using BaseRef::BaseRef; + DeviceRef(DLDevice options) { + (**this).set_value(options); + } + DeviceRef(std::initializer_list options) { + (**this).set_options(options); + } + DeviceRef(std::span options) { + (**this).set_options(options); + } +}; + +} // namespace details + +/** + * \brief Fluent API for validating tensor shape, strides, dtype, and device. + * + * Construct with the expected shape (using `SymbolicSize` or literal + * integers), chain `.with_strides()`, `.with_dtype<...>()`, and + * `.with_device<...>()`, then call `.verify(tensor)`. + * + * Example: + * \code + * auto N = SymbolicSize{"N"}; + * TensorMatcher({N, 128}) + * .with_dtype() + * .with_device() + * .verify(input_tensor); + * \endcode + * + * \note `TensorMatcher` is a move-only temporary. Do not store in a variable. + */ +struct TensorMatcher { + private: + using SizeRef = details::SizeRef; + using DTypeRef = details::DTypeRef; + using DeviceRef = details::DeviceRef; + + public: + TensorMatcher(const TensorMatcher&) = delete; + TensorMatcher& operator=(const TensorMatcher&) = delete; + + explicit TensorMatcher(std::initializer_list shape) : m_shape(shape), m_strides(), m_dtype() {} + + auto with_strides(std::initializer_list strides) && -> TensorMatcher&& { + // no partial update allowed + RuntimeCheck(m_strides.size() == 0, "Strides already specified"); + RuntimeCheck(m_shape.size() == strides.size(), "Strides size must match shape size"); + m_strides = strides; + return std::move(*this); + } + + template + auto with_dtype(DTypeRef&& dtype) && -> TensorMatcher&& { + m_init_dtype(); + m_dtype.rebind(*dtype); + m_dtype->set_options(); + return std::move(*this); + } + + template + auto with_dtype() && -> TensorMatcher&& { + static_assert(sizeof...(Ts) > 0, "At least one dtype option must be specified"); + m_init_dtype(); + m_dtype->set_options(); + return std::move(*this); + } + + template + auto with_device(DeviceRef&& device) && -> TensorMatcher&& { + 
m_init_device(); + m_device.rebind(*device); + m_device->set_options(); + return std::move(*this); + } + + template + auto with_device() && -> TensorMatcher&& { + static_assert(sizeof...(Codes) > 0, "At least one device option must be specified"); + m_init_device(); + m_device->set_options(); + return std::move(*this); + } + + // once we start verification, we cannot modify anymore + auto verify(tvm::ffi::TensorView view, DebugInfo info = {}) const&& -> const TensorMatcher&& { + try { + m_verify_impl(view); + } catch (PanicError& e) { + auto oss = std::ostringstream{}; + oss << "Tensor match failed for "; + s_print_tensor(oss, view); + oss << " at " << info.file_name() << ":" << info.line() << "\n- Root cause: " << e.root_cause(); + throw PanicError(std::move(oss).str()); + } + return std::move(*this); + } + + private: + static auto s_print_tensor(std::ostringstream& oss, tvm::ffi::TensorView view) -> void { + oss << "Tensor<"; + int64_t dim = 0; + for (const auto& size : view.shape()) { + if (dim++ > 0) oss << ", "; + oss << size; + } + oss << ">[strides=<"; + dim = 0; + for (const auto& stride : view.strides()) { + if (dim++ > 0) { + oss << ", "; + } + oss << stride; + } + oss << ">, dtype=" << view.dtype(); + oss << ", device=" << details::PrintableDevice{view.device()} << "]"; + } + + auto m_verify_impl(tvm::ffi::TensorView view) const -> void { + const auto dim = static_cast(view.dim()); + RuntimeCheck(dim == m_shape.size(), "Tensor dimension mismatch: expected ", m_shape.size(), " but got ", dim); + for (const auto i : irange(dim)) { + m_shape[i]->verify(view.size(i), "shape", i); + } + if (m_has_strides()) { + for (const auto i : irange(dim)) { + if (view.size(i) != 1 || !m_strides[i]->has_value()) { + // skip stride check for size 1 dimension + m_strides[i]->verify(view.stride(i), "stride", i); + } + } + } else { + RuntimeCheck(view.is_contiguous(), "Tensor is not contiguous as expected"); + } + // since we may double verify, we will force to check + 
m_dtype->verify(view.dtype()); + m_device->verify(view.device()); + } + + auto m_init_dtype() -> void { + RuntimeCheck(!m_has_dtype, "DType already specified"); + m_has_dtype = true; + } + + auto m_init_device() -> void { + RuntimeCheck(!m_has_device, "Device already specified"); + m_has_device = true; + } + + auto m_has_strides() const -> bool { + return !m_strides.empty(); + } + + std::span m_shape; + std::span m_strides; + DTypeRef m_dtype; + DeviceRef m_device; + bool m_has_dtype = false; + bool m_has_device = false; +}; + +} // namespace host diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/tile.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/tile.cuh new file mode 100644 index 0000000000..1adc821706 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/tile.cuh @@ -0,0 +1,62 @@ +/// \file tile.cuh +/// \brief Tiled memory access helpers for coalesced global memory I/O. +/// +/// `tile::Memory` represents a contiguous memory region where multiple +/// threads cooperatively load/store elements. The three factory methods +/// determine the thread group: +/// - `thread()` - single thread (no tiling). +/// - `warp()` - all threads in a warp cooperate. +/// - `cta()` - all threads in the CTA cooperate. + +#pragma once +#include + +#include + +namespace device::tile { + +/** + * \brief Represents a contiguous memory region for cooperative tiled access. + * + * Each instance is parameterized by an element type `T` and bound to a + * specific thread id (`tid`) within a group of `tsize` threads. + * + * \tparam T The storage element type (e.g. `AlignedVector, 4>`). + */ +template +struct Memory { + public: + SGL_DEVICE constexpr Memory(uint32_t tid, uint32_t tsize) : tid(tid), tsize(tsize) {} + /// \brief Create a Memory accessor for a single thread (no cooperation). + SGL_DEVICE static constexpr Memory thread() { + return Memory{0, 1}; + } + /// \brief Create a Memory accessor distributed across warp threads. 
+ SGL_DEVICE static Memory warp(int warp_threads = kWarpThreads) { + return Memory{static_cast(threadIdx.x % warp_threads), static_cast(warp_threads)}; + } + /// \brief Create a Memory accessor distributed across all CTA threads. + SGL_DEVICE static Memory cta(int cta_threads = blockDim.x) { + return Memory{static_cast(threadIdx.x), static_cast(cta_threads)}; + } + /// \brief Load one element from `ptr` at the position assigned to this thread. + /// \param ptr Base pointer (cast to `const T*`). + /// \param offset Optional tile offset (multiplied by `tsize`). + SGL_DEVICE T load(const void* ptr, int64_t offset = 0) const { + return static_cast(ptr)[tid + offset * tsize]; + } + /// \brief Store one element to `ptr` at the position assigned to this thread. + SGL_DEVICE void store(void* ptr, T val, int64_t offset = 0) const { + static_cast(ptr)[tid + offset * tsize] = val; + } + /// \brief Check whether this thread's element index is within bounds. + SGL_DEVICE bool in_bound(int64_t element_count, int64_t offset = 0) const { + return tid + offset * tsize < element_count; + } + + private: + uint32_t tid; + uint32_t tsize; +}; + +} // namespace device::tile diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/type.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/type.cuh new file mode 100644 index 0000000000..a7a5346196 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/type.cuh @@ -0,0 +1,120 @@ +/// \file type.cuh +/// \brief Dtype trait system for CUDA scalar/packed types. +/// +/// `dtype_trait` provides per-type metadata: packed type alias, +/// conversion functions (`from`), and unary/binary math operations. +/// Use `device::cast(from_value)` for type conversion on device. +/// +/// Registered types: +/// | Scalar | Packed (x2) | Notes | +/// |-----------|-------------|-------------------------------| +/// | `fp32_t` | `fp32x2_t` | Full math ops (abs,sqrt,...) 
| +/// | `fp16_t` | `fp16x2_t` | Conversion only | +/// | `bf16_t` | `bf16x2_t` | Conversion only | +/// | `fp32x2_t`| `fp32x4_t` | Packed float2 <-> half2/bf162 | + +#pragma once +#include + +template +struct dtype_trait {}; + +#define SGL_REGISTER_DTYPE_TRAIT(TYPE, PACK2, ...) \ + template <> \ + struct dtype_trait { \ + using self_t = TYPE; \ + using packed_t = PACK2; \ + template \ + SGL_DEVICE static self_t from(const S& value) { \ + return static_cast(value); \ + } \ + __VA_ARGS__ \ + } + +#define SGL_REGISTER_TYPE_END static_assert(true) + +#define SGL_REGISTER_FROM_FUNCTION(FROM, FN) \ + SGL_DEVICE static self_t from(const FROM& x) { \ + return FN(x); \ + } \ + static_assert(true) + +#define SGL_REGISTER_UNARY_FUNCTION(NAME, FN) \ + SGL_DEVICE static self_t NAME(const self_t& x) { \ + return FN(x); \ + } \ + static_assert(true) + +#define SGL_REGISTER_BINARY_FUNCTION(NAME, FN) \ + SGL_DEVICE static self_t NAME(const self_t& x, const self_t& y) { \ + return FN(x, y); \ + } \ + static_assert(true) + +SGL_REGISTER_DTYPE_TRAIT( + fp32_t, fp32x2_t, SGL_REGISTER_TYPE_END; // + SGL_REGISTER_FROM_FUNCTION(fp16_t, __half2float); + SGL_REGISTER_FROM_FUNCTION(bf16_t, __bfloat162float); + SGL_REGISTER_UNARY_FUNCTION(abs, fabsf); + SGL_REGISTER_UNARY_FUNCTION(sqrt, sqrtf); + SGL_REGISTER_UNARY_FUNCTION(rsqrt, rsqrtf); + SGL_REGISTER_UNARY_FUNCTION(exp, expf); + SGL_REGISTER_UNARY_FUNCTION(sin, sinf); + SGL_REGISTER_UNARY_FUNCTION(cos, cosf); + SGL_REGISTER_BINARY_FUNCTION(max, fmaxf); + SGL_REGISTER_BINARY_FUNCTION(min, fminf);); +SGL_REGISTER_DTYPE_TRAIT(fp16_t, fp16x2_t); +SGL_REGISTER_DTYPE_TRAIT(bf16_t, bf16x2_t); + +/// TODO: Add ROCM implementation +SGL_REGISTER_DTYPE_TRAIT( + fp32x2_t, fp32x4_t, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp16x2_t, __half22float2); + SGL_REGISTER_FROM_FUNCTION(bf16x2_t, __bfloat1622float2);); + +SGL_REGISTER_DTYPE_TRAIT( + fp16x2_t, void, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp32x2_t, __float22half2_rn);); + 
+SGL_REGISTER_DTYPE_TRAIT( + bf16x2_t, void, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp32x2_t, __float22bfloat162_rn);); + +#ifndef USE_ROCM +SGL_REGISTER_DTYPE_TRAIT(fp8_e4m3_t, fp8x2_e4m3_t); +#endif + +#undef SGL_REGISTER_DTYPE_TRAIT +#undef SGL_REGISTER_FROM_FUNCTION + +/// \brief Alias: the packed (x2) type for `T`. +template +using packed_t = typename dtype_trait::packed_t; + +namespace device { + +/** + * \brief Cast a value from type `From` to type `To` on device. + * + * Dispatches through `dtype_trait::from()`, which uses the appropriate + * CUDA intrinsic (e.g. `__half2float`, `__float22half2_rn`). + */ +template +SGL_DEVICE To cast(const From& value) { + return dtype_trait::from(value); +} + +} // namespace device + +// --------------------------------------------------------------------------- +// FP8 max clamp value — platform-dependent +// CUDA (e4m3fn): 448.0f +// AMD FNUZ (e4m3fnuz): 224.0f +// AMD E4M3 (e4m3fn): 448.0f +// --------------------------------------------------------------------------- +#ifndef USE_ROCM +constexpr float kFP8E4M3Max = 448.0f; +#else // USE_ROCM +#if HIP_FP8_TYPE_FNUZ +constexpr float kFP8E4M3Max = 224.0f; +#else // HIP_FP8_TYPE_E4M3 +constexpr float kFP8E4M3Max = 448.0f; +#endif // HIP_FP8_TYPE_FNUZ +#endif // USE_ROCM diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.cuh new file mode 100644 index 0000000000..2dd6f3dc93 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.cuh @@ -0,0 +1,333 @@ +/// \file utils.cuh +/// \brief Core CUDA/device utilities: type aliases, PDL helpers, +/// typed pointer access, kernel launch wrapper, and error checking. +/// +/// This header is included (directly or transitively) by nearly every +/// JIT kernel. It provides: +/// - Scalar/packed type aliases (`fp16_t`, `bf16_t`, `fp8_e4m3_t`, ...). +/// - `SGL_DEVICE` macro (forced-inline device function qualifier). 
+/// - `kWarpThreads` constant (32). +/// - PDL (Programmatic Dependent Launch) helpers for Hopper (sm_90+). +/// - Typed `load_as` / `store_as` for void-pointer access. +/// - `pointer::offset` for safe void-pointer arithmetic. +/// - `host::LaunchKernel` - kernel launcher with optional PDL. +/// - `host::RuntimeDeviceCheck` - CUDA error checking. + +#pragma once + +#include + +#include +#include + +#include +#include +#include +#ifndef USE_ROCM +#include +#include +#include +#include +#else +#include +#include +#include +#ifndef __grid_constant__ +#define __grid_constant__ +#endif +using cudaError_t = hipError_t; +using cudaStream_t = hipStream_t; +using cudaLaunchConfig_t = hipLaunchConfig_t; +using cudaLaunchAttribute = hipLaunchAttribute; +inline constexpr auto cudaSuccess = hipSuccess; +#define cudaStreamPerThread hipStreamPerThread +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaLaunchKernel hipLaunchKernel +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#endif + +#ifndef USE_ROCM +using fp32_t = float; +using fp16_t = __half; +using bf16_t = __nv_bfloat16; +using fp8_e4m3_t = __nv_fp8_e4m3; +using fp8_e5m2_t = __nv_fp8_e5m2; + +using fp32x2_t = float2; +using fp16x2_t = __half2; +using bf16x2_t = __nv_bfloat162; +using fp8x2_e4m3_t = __nv_fp8x2_e4m3; +using fp8x2_e5m2_t = __nv_fp8x2_e5m2; + +using fp32x4_t = float4; +#else +using fp32_t = float; +using fp16_t = __half; +using bf16_t = __hip_bfloat16; +using fp8_e4m3_t = uint8_t; +using fp8_e5m2_t = uint8_t; +using fp32x2_t = float2; +using fp16x2_t = half2; +using bf16x2_t = __hip_bfloat162; +using fp8x2_e4m3_t = uint16_t; +using fp8x2_e5m2_t = uint16_t; +using fp32x4_t = float4; +#endif + +/* + * LDG Support + */ +#ifndef USE_ROCM +#define SGLANG_LDG(arg) __ldg(arg) +#else +#define SGLANG_LDG(arg) *(arg) +#endif + +// DLPack device type for the current 
platform +#ifndef USE_ROCM +inline constexpr auto kDLGPU = kDLCUDA; +#else +inline constexpr auto kDLGPU = kDLROCM; +#endif + +namespace device { + +/// \brief Macro: forced-inline device function qualifier. +#define SGL_DEVICE __forceinline__ __device__ + +// Architecture detection: SGL_CUDA_ARCH is injected by load_jit() and is +// available in both host and device compilation passes, whereas __CUDA_ARCH__ +// is only defined by nvcc during the device pass. +#if !defined(USE_ROCM) +#if !defined(SGL_CUDA_ARCH) +#error "SGL_CUDA_ARCH is not defined. JIT compilation must inject -DSGL_CUDA_ARCH via load_jit()." +#endif +#if defined(__CUDA_ARCH__) +static_assert( + __CUDA_ARCH__ == SGL_CUDA_ARCH, "SGL_CUDA_ARCH mismatch: injected arch flag does not match device target"); +#endif +#define SGL_ARCH_HOPPER_OR_GREATER (SGL_CUDA_ARCH >= 900) +#define SGL_ARCH_BLACKWELL_OR_GREATER ((SGL_CUDA_ARCH >= 1000) && (CUDA_VERSION >= 12090)) +#else // USE_ROCM +#define SGL_ARCH_HOPPER_OR_GREATER 0 +#define SGL_ARCH_BLACKWELL_OR_GREATER 0 +#endif + +// Maximum vector size in bytes supported by current architecture. +// Pre-Blackwell / AMD: 128-bit (16 bytes) +// Blackwell or greater: 256-bit (32 bytes) +inline constexpr std::size_t kMaxVecBytes = SGL_ARCH_BLACKWELL_OR_GREATER ? 32 : 16; + +/// \brief Number of threads per warp (always 32 on NVIDIA/AMD GPUs). +inline constexpr auto kWarpThreads = 32u; +/// \brief Full warp active mask (all 32 lanes). +#ifndef USE_ROCM +inline constexpr auto kFullMask = 0xffffffffu; +#else +inline constexpr auto kFullMask = 0xffffffffffffffffULL; +#endif + +/** + * \brief PDL (Programmatic Dependent Launch): wait for the primary kernel. + * + * On Hopper (sm_90+), inserts a `griddepcontrol.wait` instruction to + * synchronize with a preceding kernel in the same stream. On older + * architectures or ROCm this is a no-op. 
+ */ +template +SGL_DEVICE void PDLWaitPrimary() { +#if SGL_ARCH_HOPPER_OR_GREATER + if constexpr (kUsePDL) { + asm volatile("griddepcontrol.wait;" ::: "memory"); + } +#endif +} + +/** + * \brief PDL: trigger dependent (secondary) kernel launch. + * + * On Hopper (sm_90+), inserts a `griddepcontrol.launch_dependents` + * instruction. On older architectures or ROCm this is a no-op. + */ +template +SGL_DEVICE void PDLTriggerSecondary() { +#if SGL_ARCH_HOPPER_OR_GREATER + if constexpr (kUsePDL) { + asm volatile("griddepcontrol.launch_dependents;" :::); + } +#endif +} + +template +SGL_DEVICE constexpr auto div_ceil(T a, U b) { + return (a + b - 1) / b; +} + +/** + * \brief Load data with the specified type and offset from a void pointer. + * \tparam T The type to load. + * \param ptr The base pointer. + * \param offset The offset in number of elements of type T. + */ +template +SGL_DEVICE T load_as(const void* ptr, int64_t offset = 0) { + return static_cast(ptr)[offset]; +} + +/** + * \brief Store data with the specified type and offset to a void pointer. + * \tparam T The type to store. + * \param ptr The base pointer. + * \param val The value to store. + * \param offset The offset in number of elements of type T. + * \note we use type_identity_t to force the caller to explicitly specify + * the template parameter `T`, which can avoid accidentally using the wrong type. + */ +template +SGL_DEVICE void store_as(void* ptr, std::type_identity_t val, int64_t offset = 0) { + static_cast(ptr)[offset] = val; +} + +/// \brief Safe void-pointer arithmetic (byte-level by default). +namespace pointer { + +// we only allow void * pointer arithmetic for safety + +template +SGL_DEVICE auto offset(void* ptr, U... offset) -> void* { + return static_cast(ptr) + (... + offset); +} + +template +SGL_DEVICE auto offset(const void* ptr, U... offset) -> const void* { + return static_cast(ptr) + (... 
+ offset); +} + +} // namespace pointer + +} // namespace device + +namespace host { + +/** + * \brief Check the CUDA error code and panic with location info on failure. + */ +inline void RuntimeDeviceCheck(::cudaError_t error, DebugInfo location = {}) { + if (error != ::cudaSuccess) { + [[unlikely]]; + ::host::panic(location, "CUDA error: ", ::cudaGetErrorString(error)); + } +} + +/// \brief Check the last CUDA error (calls `cudaGetLastError`). +inline void RuntimeDeviceCheck(DebugInfo location = {}) { + return RuntimeDeviceCheck(::cudaGetLastError(), location); +} + +/** + * \brief Kernel launcher with automatic stream resolution and PDL support. + * + * Usage: + * \code + * host::LaunchKernel(grid, block, device) + * .enable_pdl(true) + * (my_kernel, arg1, arg2); + * \endcode + * + * The constructor resolves the CUDA stream from a `DLDevice` (via + * `TVMFFIEnvGetStream`) or accepts a raw `cudaStream_t`. The call + * operator launches the kernel and checks for errors. + */ +struct LaunchKernel { + public: + explicit LaunchKernel( + dim3 grid_dim, + dim3 block_dim, + DLDevice device, + std::size_t dynamic_shared_mem_bytes = 0, + DebugInfo location = {}) noexcept + : m_config(s_make_config(grid_dim, block_dim, resolve_device(device), dynamic_shared_mem_bytes)), + m_location(location) {} + + explicit LaunchKernel( + dim3 grid_dim, + dim3 block_dim, + cudaStream_t stream, + std::size_t dynamic_shared_mem_bytes = 0, + DebugInfo location = {}) noexcept + : m_config(s_make_config(grid_dim, block_dim, stream, dynamic_shared_mem_bytes)), m_location(location) {} + + LaunchKernel(const LaunchKernel&) = delete; + LaunchKernel& operator=(const LaunchKernel&) = delete; + + static auto resolve_device(DLDevice device) -> cudaStream_t { + return static_cast(::TVMFFIEnvGetStream(device.device_type, device.device_id)); + } + + auto enable_pdl(bool enabled = true) -> LaunchKernel& { +#ifdef USE_ROCM + (void)enabled; + m_config.numAttrs = 0; +#else + if (enabled) { + auto& attr = 
m_attrs[m_config.numAttrs++]; + attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr.val.programmaticStreamSerializationAllowed = true; + m_config.attrs = m_attrs; + } +#endif + return *this; + } + + auto enable_cluster(dim3 cluster_dim) -> LaunchKernel& { +#ifdef USE_ROCM + (void)cluster_dim; +#else + auto& attr = m_attrs[m_config.numAttrs++]; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {cluster_dim.x, cluster_dim.y, cluster_dim.z}; + m_config.attrs = m_attrs; +#endif + return *this; + } + + template + auto operator()(T&& kernel, Args&&... args) const -> void { +#ifdef USE_ROCM + hipLaunchKernelGGL( + std::forward(kernel), + m_config.gridDim, + m_config.blockDim, + m_config.dynamicSmemBytes, + m_config.stream, + std::forward(args)...); + RuntimeDeviceCheck(m_location); +#else + RuntimeDeviceCheck(::cudaLaunchKernelEx(&m_config, kernel, std::forward(args)...), m_location); +#endif + } + + private: + static auto s_make_config( // Make a config for kernel launch + dim3 grid_dim, + dim3 block_dim, + cudaStream_t stream, + std::size_t smem) -> cudaLaunchConfig_t { + auto config = ::cudaLaunchConfig_t{}; + config.gridDim = grid_dim; + config.blockDim = block_dim; + config.dynamicSmemBytes = smem; + config.stream = stream; + config.numAttrs = 0; + return config; + } + + cudaLaunchConfig_t m_config; + const DebugInfo m_location; + cudaLaunchAttribute m_attrs[2]; +}; + +} // namespace host diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.h b/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.h new file mode 100644 index 0000000000..3226f79ddc --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.h @@ -0,0 +1,186 @@ +/// \file utils.h +/// \brief Host-side C++ utilities used by JIT kernel wrappers. +/// +/// Provides: +/// - `DebugInfo` - wraps `std::source_location` for error reporting. +/// - `RuntimeCheck` - runtime assertion with formatted error messages. 
+/// - `Panic` - unconditional abort with formatted error messages. +/// - `pointer::offset` - safe void-pointer arithmetic (host side). +/// - `div_ceil` - integer ceiling division. +/// - `dtype_bytes` - byte width of a `DLDataType`. +/// - `irange` - Python-style integer range for range-for loops. + +#pragma once + +// ref: https://forums.developer.nvidia.com/t/c-20s-source-location-compilation-error-when-using-nvcc-12-1/258026/3 +#ifdef __CUDACC__ +#include +#if CUDA_VERSION <= 12010 + +#pragma push_macro("__cpp_consteval") +#pragma push_macro("_NODISCARD") +#pragma push_macro("__builtin_LINE") + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wbuiltin-macro-redefined" +#define __cpp_consteval 201811L +#pragma clang diagnostic pop + +#ifdef _NODISCARD +#undef _NODISCARD +#define _NODISCARD +#endif + +#define consteval constexpr + +#include "source_location.h" + +#undef consteval +#pragma pop_macro("__cpp_consteval") +#pragma pop_macro("_NODISCARD") +#else // __CUDACC__ && CUDA_VERSION > 12010 +#include "source_location.h" +#endif +#else // no __CUDACC__ +#include "source_location.h" +#endif + +#include + +#include +#include +#include +#include +#include +#include + +namespace host { + +template +inline constexpr bool dependent_false_v = false; + +/// \brief Source-location wrapper for debug/error messages. +struct DebugInfo : public source_location_t { + DebugInfo(source_location_t loc = source_location_t::current()) : source_location_t(loc) {} +}; + +/// \brief Exception type thrown by `RuntimeCheck` and `Panic`. +struct PanicError : public std::runtime_error { + public: + explicit PanicError(std::string msg) : runtime_error(msg), m_message(std::move(msg)) {} + auto root_cause() const -> std::string_view { + const auto str = std::string_view{m_message}; + const auto pos = str.find(": "); + return pos == std::string_view::npos ? 
str : str.substr(pos + 2); + } + + private: + std::string m_message; +}; + +/// \brief Unconditionally abort with a formatted error message. +template +[[noreturn]] +inline auto panic(DebugInfo location, Args&&... args) -> void { + std::ostringstream os; + os << "Runtime check failed at " << location.file_name() << ":" << location.line(); + if constexpr (sizeof...(args) > 0) { + os << ": "; + (os << ... << std::forward(args)); + } else { + os << " in " << location.function_name(); + } + throw PanicError(std::move(os).str()); +} + +/** + * \brief Runtime assertion: panics with a formatted message when `condition` + * is false. Extra `args` are streamed to the error message. + * + * Example: + * \code + * RuntimeCheck(n > 0, "n must be positive, got ", n); + * \endcode + */ +template +struct RuntimeCheck { + template + explicit RuntimeCheck(Cond&& condition, Args&&... args, DebugInfo location = {}) { + if (condition) return; + [[unlikely]] ::host::panic(location, std::forward(args)...); + } + template + explicit RuntimeCheck(DebugInfo location, Cond&& condition, Args&&... args) { + if (condition) return; + [[unlikely]] ::host::panic(location, std::forward(args)...); + } +}; + +template +struct Panic { + explicit Panic(Args&&... args, DebugInfo location = {}) { + ::host::panic(location, std::forward(args)...); + } + explicit Panic(DebugInfo location, Args&&... args) { + ::host::panic(location, std::forward(args)...); + } + [[noreturn]] ~Panic() { + std::terminate(); + } +}; + +template +explicit RuntimeCheck(Cond&&, Args&&...) -> RuntimeCheck; + +template +explicit RuntimeCheck(DebugInfo, Cond&&, Args&&...) -> RuntimeCheck; + +template +explicit Panic(Args&&...) -> Panic; + +template +explicit Panic(DebugInfo, Args&&...) -> Panic; + +namespace pointer { + +// we only allow void * pointer arithmetic for safety + +template +inline auto offset(void* ptr, U... offset) -> void* { + return static_cast(ptr) + (... 
+ offset); +} + +template +inline auto offset(const void* ptr, U... offset) -> const void* { + return static_cast(ptr) + (... + offset); +} + +} // namespace pointer + +/// \brief Integer ceiling division: ceil(a / b). +template +inline constexpr auto div_ceil(T a, U b) { + return (a + b - 1) / b; +} + +/// \brief Returns the byte width of a DLPack data type. +inline auto dtype_bytes(DLDataType dtype) -> std::size_t { + return static_cast(dtype.bits / 8); +} + +namespace stdr = std::ranges; +namespace stdv = stdr::views; + +/// \brief Python-style integer range: `irange(n)` -> `[0, n)`. +template +inline auto irange(T end) { + return stdv::iota(static_cast(0), end); +} + +/// \brief Python-style integer range: `irange(start, end)` -> `[start, end)`. +template +inline auto irange(T start, T end) { + return stdv::iota(start, end); +} + +} // namespace host diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/vec.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/vec.cuh new file mode 100644 index 0000000000..67f388679f --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/vec.cuh @@ -0,0 +1,118 @@ +/// \file vec.cuh +/// \brief Aligned vector types for coalesced global memory access. +/// +/// `AlignedVector` wraps `N` elements of type `T` in a naturally +/// aligned struct so that the compiler emits wide (vectorized) load/store +/// instructions (e.g. `LDG.128`). The maximum supported vector width is +/// 256 bits (32 bytes), matching CUDA's widest vector load. + +#pragma once +#include + +#include +#include + +namespace device { + +namespace details { + +/// \brief Maps byte-width to the corresponding unsigned integer type. 
+template +struct uint_trait {}; + +template <> +struct uint_trait<1> { + using type = uint8_t; +}; + +template <> +struct uint_trait<2> { + using type = uint16_t; +}; + +template <> +struct uint_trait<4> { + using type = uint32_t; +}; + +template <> +struct uint_trait<8> { + using type = uint64_t; +}; + +/// \brief Alias: maps `sizeof(T)` to matching unsigned int type. +template +using sized_int = typename uint_trait::type; + +} // namespace details + +/// \brief Raw aligned storage for `N` elements of type `T`. +template +struct alignas(sizeof(T) * N) AlignedStorage { + T data[N]; +}; + +/** + * \brief Aligned vector for vectorized memory access on GPU. + * + * Stores `N` elements of type `T` with natural alignment so that a single + * `load`/`store` call compiles to a wide memory transaction. + * + * \tparam T Element type (e.g. `fp16_t`, `bf16_t`, `float`). + * \tparam N Number of elements. Must be a power of two and + * `sizeof(T) * N <= 32` (256 bits). + * + * Example: + * \code + * AlignedVector vec; // 16 bytes, 128-bit aligned + * vec.load(input_ptr, tid); // vectorized load + * vec[0] = vec[0] + 1; + * vec.store(output_ptr, tid); // vectorized store + * \endcode + */ +template +struct AlignedVector { + private: + static_assert( + (N > 0 && (N & (N - 1)) == 0) && sizeof(T) * N <= kMaxVecBytes, + "CUDA vector size exceeds arch limit: max 16 bytes on pre-Blackwell/AMD, " + "32 bytes on Blackwell or greater"); + using element_t = typename details::sized_int; + using storage_t = AlignedStorage; + + public: + /// \brief Vectorized load from `ptr` at the given element `offset`. + SGL_DEVICE void load(const void* ptr, int64_t offset = 0) { + m_storage = reinterpret_cast(ptr)[offset]; + } + /// \brief Vectorized store to `ptr` at the given element `offset`. + SGL_DEVICE void store(void* ptr, int64_t offset = 0) const { + reinterpret_cast(ptr)[offset] = m_storage; + } + /// \brief Fill all N elements with the same `value`. 
+ SGL_DEVICE void fill(T value) { + const auto store_value = *reinterpret_cast(&value); +#pragma unroll + for (std::size_t i = 0; i < N; ++i) { + m_storage.data[i] = store_value; + } + } + + SGL_DEVICE auto operator[](std::size_t idx) -> T& { + return reinterpret_cast(&m_storage)[idx]; + } + SGL_DEVICE auto operator[](std::size_t idx) const -> T { + return reinterpret_cast(&m_storage)[idx]; + } + SGL_DEVICE auto data() -> T* { + return reinterpret_cast(&m_storage); + } + SGL_DEVICE auto data() const -> const T* { + return reinterpret_cast(&m_storage); + } + + private: + storage_t m_storage; +}; + +} // namespace device diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/warp.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/warp.cuh new file mode 100644 index 0000000000..9d82efae1e --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/warp.cuh @@ -0,0 +1,56 @@ +/// \file warp.cuh +/// \brief Warp-level reduction primitives. + +#pragma once +#include +#include + +namespace device::warp { + +/// \brief Full warp active mask. +#ifndef USE_ROCM +static constexpr uint32_t kFullMask = 0xffffffffu; +using mask_t = uint32_t; +#else +static constexpr uint64_t kFullMask = 0xffffffffffffffffULL; +using mask_t = uint64_t; +#endif + +/** + * \brief Warp-level sum reduction. + * + * On CUDA: uses __shfl_xor_sync with width=32. + * On HIP: uses __shfl_xor with explicit width parameter (supports wave64 sub-groups). + */ +template +SGL_DEVICE T reduce_sum(T value, mask_t active_mask = kFullMask) { + static_assert(kNumThreads >= 1 && kNumThreads <= kWarpThreads); + static_assert(std::has_single_bit(kNumThreads), "must be pow of 2"); +#pragma unroll + for (int mask = kNumThreads / 2; mask > 0; mask >>= 1) +#ifndef USE_ROCM + value = value + __shfl_xor_sync(active_mask, value, mask, 32); +#else + value = value + __shfl_xor(value, mask, kNumThreads); +#endif + return value; +} + +/** + * \brief Warp-level max reduction. 
+ */ +template +SGL_DEVICE T reduce_max(T value, mask_t active_mask = kFullMask) { + static_assert(kNumThreads >= 1 && kNumThreads <= kWarpThreads); + static_assert(std::has_single_bit(kNumThreads), "must be pow of 2"); +#pragma unroll + for (int mask = kNumThreads / 2; mask > 0; mask >>= 1) +#ifndef USE_ROCM + value = math::max(value, __shfl_xor_sync(active_mask, value, mask, 32)); +#else + value = math::max(value, __shfl_xor(value, mask, kNumThreads)); +#endif + return value; +} + +} // namespace device::warp diff --git a/lightllm/third_party/sglang_jit/jit_utils.py b/lightllm/third_party/sglang_jit/jit_utils.py new file mode 100644 index 0000000000..4096c16bb4 --- /dev/null +++ b/lightllm/third_party/sglang_jit/jit_utils.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +import functools +import importlib.util +import logging +import os +import pathlib +from contextlib import contextmanager +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + TypeAlias, + TypeVar, + Union, +) + +import torch + +if TYPE_CHECKING: + from tvm_ffi import Module + +F = TypeVar("F", bound=Callable[..., Any]) +_FULL_TEST_ENV_VAR = "SGLANG_JIT_KERNEL_RUN_FULL_TESTS" + +logger = logging.getLogger(__name__) + + +def is_in_ci() -> bool: + return os.getenv("SGLANG_IS_IN_CI", "").lower() in ("1", "true", "yes", "y") + + +def should_run_full_tests() -> bool: + return os.getenv(_FULL_TEST_ENV_VAR, "false").lower() == "true" + + +def get_ci_test_range(full_range: List[Any], ci_range: List[Any]) -> List[Any]: + if should_run_full_tests(): + return full_range + return ci_range if is_in_ci() else full_range + + +def cache_once(fn: F) -> F: + """ + NOTE: `functools.lru_cache` is not compatible with `torch.compile` + So we manually implement a simple cache_once decorator to replace it. 
+ """ + result_map = {} + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + key = (args, tuple(sorted(kwargs.items()))) + if key not in result_map: + result_map[key] = fn(*args, **kwargs) + return result_map[key] + + return wrapper # type: ignore + + +def _make_wrapper(tup: Tuple[str, str]) -> str: + export_name, kernel_name = tup + return f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({export_name}, ({kernel_name}));" + + +@cache_once +def _resolve_kernel_path() -> pathlib.Path: + cur_dir = pathlib.Path(__file__).parent.resolve() + + # first, try this directory structure + def _environment_install(): + candidate = cur_dir.resolve() + if (candidate / "include").exists() and (candidate / "csrc").exists(): + return candidate + return None + + def _package_install(): + # TODO: support find path by package + return None + + path = _environment_install() or _package_install() + if path is None: + raise RuntimeError("Cannot find sglang.jit_kernel path") + return path + + +KERNEL_PATH = _resolve_kernel_path() +DEFAULT_INCLUDE = [str(KERNEL_PATH / "include")] +DEFAULT_CFLAGS = ["-std=c++20", "-O3"] +DEFAULT_LDFLAGS = [] +CPP_TEMPLATE_TYPE: TypeAlias = Union[int, float, str, bool, torch.dtype] + + +class CPPArgList(list[str]): + def __str__(self) -> str: + return ", ".join(self) + + +CPP_DTYPE_MAP = { + torch.float: "fp32_t", + torch.float16: "fp16_t", + torch.float8_e4m3fn: "fp8_e4m3_t", + torch.bfloat16: "bf16_t", + torch.int8: "int8_t", + torch.int32: "int32_t", + torch.int64: "int64_t", +} + + +# AMD/ROCm note: +@cache_once +def is_hip_runtime() -> bool: + return bool(torch.version.hip) + + +# MThreads/MUSA note: +@cache_once +def is_musa_runtime() -> bool: + return hasattr(torch.version, "musa") and torch.version.musa is not None + + +def make_cpp_args(*args: CPP_TEMPLATE_TYPE) -> CPPArgList: + def _convert(arg: CPP_TEMPLATE_TYPE) -> str: + if isinstance(arg, bool): + return "true" if arg else "false" + if isinstance(arg, (int, str, float)): + return str(arg) + if 
isinstance(arg, torch.dtype): + return CPP_DTYPE_MAP[arg] + raise TypeError(f"Unsupported argument type for cpp template: {type(arg)}") + + return CPPArgList(_convert(arg) for arg in args) + + +def load_jit( + *args: str, + cpp_files: List[str] | None = None, + cuda_files: List[str] | None = None, + cpp_wrappers: List[Tuple[str, str]] | None = None, + cuda_wrappers: List[Tuple[str, str]] | None = None, + extra_cflags: List[str] | None = None, + extra_cuda_cflags: List[str] | None = None, + extra_ldflags: List[str] | None = None, + extra_include_paths: List[str] | None = None, + extra_dependencies: List[str] | None = None, + build_directory: str | None = None, + header_only: bool = True, +) -> Module: + """ + Loading a JIT module from C++/CUDA source files. + We define a wrapper as a tuple of (export_name, kernel_name), + where `export_name` is the name used to called from Python, + and `kernel_name` is the name of the kernel class in C++/CUDA source. + + :param args: Unique marker of the JIT module. Must be distinct for different kernels. + :type args: str + :param cpp_files: A list of C++ source files. + :type cpp_files: List[str] | None + :param cuda_files: A list of CUDA source files. + :type cuda_files: List[str] | None + :param cpp_wrappers: A list of C++ wrappers, defining the export name and kernel name. + :type cpp_wrappers: List[Tuple[str, str]] | None + :param cuda_wrappers: A list of CUDA wrappers, defining the export name and kernel name. + :type cuda_wrappers: List[Tuple[str, str]] | None + :param extra_cflags: Extra C++ compiler flags. + :type extra_cflags: List[str] | None + :param extra_cuda_cflags: Extra CUDA compiler flags. + :type extra_cuda_cflags: List[str] | None + :param extra_ldflags: Extra linker flags. + :type extra_ldflags: List[str] | None + :param extra_include_paths: Extra include paths. + :type extra_include_paths: List[str] | None + :param extra_dependencies: Extra dependencies for the JIT module, e.g., cutlass. 
+ :type extra_dependencies: List[str] | None + :param build_directory: The build directory for JIT compilation. + :type build_directory: str | None + :param header_only: Whether the module is header-only. + If true, apply the wrappers to export given class/functions. + Otherwise, we must export from C++/CUDA side. + :return: A just-in-time(JIT) compiled module. + :rtype: Module + """ + + from tvm_ffi.cpp import load, load_inline + + cpp_files = cpp_files or [] + cuda_files = cuda_files or [] + extra_cflags = extra_cflags or [] + extra_cuda_cflags = extra_cuda_cflags or [] + extra_ldflags = extra_ldflags or [] + extra_include_paths = extra_include_paths or [] + + cpp_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cpp_files] + cuda_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cuda_files] + + for dep in set(extra_dependencies or []): + if dep not in _REGISTERED_DEPENDENCIES: + raise ValueError(f"Dependency {dep} is not registered.") + extra_include_paths += _REGISTERED_DEPENDENCIES[dep]() + + module_name = "sgl_kernel_jit_" + "_".join(str(arg) for arg in args) + if header_only: + cpp_wrappers = cpp_wrappers or [] + cuda_wrappers = cuda_wrappers or [] + cpp_sources = [f'#include "{path}"' for path in cpp_files] + cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers] + + # include cuda files + cuda_sources = [f'#include "{path}"' for path in cuda_files] + cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers] + with _jit_compile_context(): + return load_inline( + module_name, + cpp_sources=cpp_sources, + cuda_sources=cuda_sources, + extra_cflags=DEFAULT_CFLAGS + extra_cflags, + extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags, + extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags, + extra_include_paths=DEFAULT_INCLUDE + extra_include_paths, + build_directory=build_directory, + ) + else: + assert cpp_wrappers is None and cuda_wrappers is None + with _jit_compile_context(): + return load( + module_name, + 
cpp_files=cpp_files, + cuda_files=cuda_files, + extra_cflags=DEFAULT_CFLAGS + extra_cflags, + extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags, + extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags, + extra_include_paths=DEFAULT_INCLUDE + extra_include_paths, + build_directory=build_directory, + ) + + +@dataclass +class ArchInfo: + major: int + minor: int + suffix: str + + @property + def target_name(self) -> str: + return f"{self.major}.{self.minor}{self.suffix}" + + @property + def jit_flag(self) -> str: + return f"-DSGL_CUDA_ARCH={self.major * 100 + self.minor * 10}" + + +@cache_once +def _init_jit_cuda_arch_once(): + global _CUDA_ARCH + try: + device = torch.cuda.current_device() + major, minor = torch.cuda.get_device_capability(device) + except Exception: + logger.warning("Cannot detect CUDA architecture.") + major, minor = 0, 0 # invalid value to trigger compile error if used + _CUDA_ARCH = ArchInfo(major, minor, "") + + +@contextmanager +def _jit_compile_context(): + if is_hip_runtime(): + yield # TODO: support ROCm `TVM_FFI_ROCM_ARCH_LIST` if needed + return + env_key = "TVM_FFI_CUDA_ARCH_LIST" + old_value = os.environ.get(env_key, None) + os.environ[env_key] = get_jit_cuda_arch().target_name + try: + yield + finally: + if old_value is None: + os.environ.pop(env_key, None) + else: + os.environ[env_key] = old_value + + +# NOTE: this might also be used in __main__.py for compile flags export +def _get_default_target_flags() -> List[str]: + if is_hip_runtime(): + flags = ["-DUSE_ROCM", "-std=c++20", "-O3"] + # Detect FP8 type based on GPU architecture + try: + device = torch.cuda.current_device() + gcn_arch = torch.cuda.get_device_properties(device).gcnArchName + if "gfx942" in gcn_arch: + flags.append("-DHIP_FP8_TYPE_FNUZ=1") + else: + flags.append("-DHIP_FP8_TYPE_E4M3=1") + except Exception: + flags.append("-DHIP_FP8_TYPE_E4M3=1") + return flags + else: + return [ + get_jit_cuda_arch().jit_flag, + "-std=c++20", + "-O3", + 
"--expt-relaxed-constexpr", + ] + + +@contextmanager +def override_jit_cuda_arch(major: int, minor: int, suffix: str = ""): + """A context manager to temporarily override CUDA architecture.""" + global _CUDA_ARCH + old_value = get_jit_cuda_arch() + _CUDA_ARCH = ArchInfo(major, minor, suffix) + try: + yield + finally: + _CUDA_ARCH = old_value + + +def get_jit_cuda_arch() -> ArchInfo: + """Get the current CUDA architecture info.""" + _init_jit_cuda_arch_once() + return _CUDA_ARCH + + +@cache_once +def is_arch_support_pdl() -> bool: + if is_hip_runtime() or is_musa_runtime(): + return False + return get_jit_cuda_arch().major >= 9 + + +def _find_package_root(package: str) -> Optional[pathlib.Path]: + spec = importlib.util.find_spec(package) + if spec is None or spec.origin is None: + return None + return pathlib.Path(spec.origin).resolve().parent + + +# NOTE: this might also be used in __main__.py for compile flags export +_REGISTERED_DEPENDENCIES: Dict[str, Callable[[], List[str]]] = {} + + +def register_dependency(name: str): + def decorator(f: Callable[[], List[str]]) -> Callable[[], List[str]]: + if name in _REGISTERED_DEPENDENCIES: + raise ValueError(f"Dependency {name} already registered") + _REGISTERED_DEPENDENCIES[name] = f + return f + + return decorator + + +@register_dependency("flashinfer") +def get_flashinfer_include_paths() -> List[str]: + include_paths: List[str] = [] + flashinfer_root = _find_package_root("flashinfer") + if flashinfer_root is None: + raise RuntimeError( + "Cannot find flashinfer package. Please install flashinfer to get" + "the required headers for JIT compilation." 
+ ) + + flashinfer_data = flashinfer_root / "data" + candidates = [ + flashinfer_data / "include", + flashinfer_data / "csrc", + flashinfer_data / "cutlass" / "include", + flashinfer_data / "cutlass" / "tools" / "util" / "include", + flashinfer_data / "spdlog" / "include", + ] + + for path in candidates: + if not path.exists(): + raise RuntimeError( + f"Required header path {path} for flashinfer dependency not found." + " Please check your flashinfer installation." + ) + include_paths.append(str(path)) + return include_paths + + +@register_dependency("cutlass") +def get_cutlass_include_paths() -> List[str]: + include_paths: List[str] = [] + + flashinfer_root = _find_package_root("flashinfer") + if flashinfer_root is not None: + candidates = [ + flashinfer_root / "data" / "cutlass" / "include", + flashinfer_root / "data" / "cutlass" / "tools" / "util" / "include", + ] + for path in candidates: + if path.exists(): + include_paths.append(str(path)) + + deep_gemm_root = _find_package_root("deep_gemm") + if deep_gemm_root is not None: + candidate = deep_gemm_root / "include" + if candidate.exists(): + include_paths.append(str(candidate)) + + # De-duplicate while preserving order. + unique_paths = [] + seen = set() + for path in include_paths: + if path in seen: + continue + seen.add(path) + unique_paths.append(path) + + if not unique_paths: + raise RuntimeError( + "Cannot find CUTLASS headers required for JIT compilation. " + "Please install flashinfer or deep_gemm with CUTLASS headers." 
+ ) + return unique_paths + + +__all__ = [ + "should_run_full_tests", + "get_ci_test_range", + "cache_once", + "is_hip_runtime", + "make_cpp_args", + "load_jit", + "override_jit_cuda_arch", + "get_jit_cuda_arch", + "is_arch_support_pdl", + "register_dependency", +] diff --git a/lightllm/third_party/sglang_jit/runtime_utils.py b/lightllm/third_party/sglang_jit/runtime_utils.py new file mode 100644 index 0000000000..d322498ca4 --- /dev/null +++ b/lightllm/third_party/sglang_jit/runtime_utils.py @@ -0,0 +1,5 @@ +import torch + + +def is_hip() -> bool: + return torch.version.hip is not None From e8c49d101b31ffe07f67ae9c5738f3e42a0ca810 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 18 Jun 2026 05:44:58 +0000 Subject: [PATCH 30/30] fix tpsp --- .../layer_infer/transformer_layer_infer.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py index 617d0dcd85..57b171a8d1 100644 --- a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -166,7 +166,7 @@ def _get_qkv( freqs_cis=self.freqs_cis, positions=infer_state.position_ids, ) - return q, qa + return q, qa, input def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight): # o: [T, tp_q_head_num_, head_dim_] after inverse rope -> grouped low-rank O -> [T, embed_dim_] @@ -185,8 +185,8 @@ def context_attention_forward( # _get_qkv writes the chunk's packed latent into the swa pool (fused kernel) before # attention reads it back via full_to_swa indices (this custom forward bypasses the # tpl _post_cache_kv path). 
- q, q_lora = self._get_qkv(x, infer_state, layer_weight) - o = self._context_attention_wrapper_run(q, q_lora, x, infer_state, layer_weight) + q, q_lora, full_x = self._get_qkv(x, infer_state, layer_weight) + o = self._context_attention_wrapper_run(q, q_lora, full_x, infer_state, layer_weight) return self._get_o(o, infer_state, layer_weight) def _context_attention_wrapper_run( @@ -262,8 +262,8 @@ def _context_attention_kernel( def token_attention_forward( self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight ): - q, q_lora = self._get_qkv(x, infer_state, layer_weight) - o = self._token_attention_kernel(q, q_lora, x, infer_state, layer_weight) + q, q_lora, full_x = self._get_qkv(x, infer_state, layer_weight) + o = self._token_attention_kernel(q, q_lora, full_x, infer_state, layer_weight) return self._get_o(o, infer_state, layer_weight) def _token_attention_kernel( @@ -552,7 +552,12 @@ def _indexer_q_weight(self, x, q_lora, infer_state: DeepseekV4InferStateInfo, la cos_tok = infer_state.position_cos_compress sin_tok = infer_state.position_sin_compress - idx_q = layer_weight.idx_wq_b_.mm(q_lora).view(x.shape[0], self.index_n_heads, self.index_head_dim) + token_num = q_lora.shape[0] + if x.shape[0] != token_num: + raise RuntimeError( + f"DeepSeek-V4 indexer expects full-token hidden states, got x={x.shape[0]} q_lora={token_num}" + ) + idx_q = layer_weight.idx_wq_b_.mm(q_lora).view(token_num, self.index_n_heads, self.index_head_dim) rotary_emb_fwd(idx_q[..., -self.qk_rope_head_dim :], None, cos_tok, sin_tok) idx_q = hadamard_transform(idx_q, scale=self.index_head_dim ** -0.5) idx_q_fp8, q_scale = act_quant(idx_q, self.index_head_dim, None) # fp8 [T,H,d], scale [T,H,1]