diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
index 539ade769e..dc18ecf4ba 100644
--- a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
+++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
@@ -9,6 +9,60 @@
from lightllm.common.basemodel.infer_struct import InferStateInfo
+# this flash_mla extra-cache fork only instantiates h_q in {64, 128}; pad TP-split q heads up
+# to the nearest supported count (zero heads are discarded from the output slice).
+# Keep the entries in ascending order: _pad_q_heads takes the FIRST entry that is >= the real
+# head count, which is only the nearest one while the tuple stays sorted.
+FLASHMLA_SUPPORTED_HEADS = (64, 128)
+
+
+def _pad_q_heads(q_4d: torch.Tensor, attn_sink: torch.Tensor):
+ """Pad the head axis of ``q_4d`` (dim 2) and ``attn_sink`` up to a supported head count.
+
+ Returns a ``(q, attn_sink, h_q)`` triple in which ``h_q`` is always the ORIGINAL head count,
+ so the caller can slice the padded heads off the kernel output again. Inputs whose head
+ count is already in FLASHMLA_SUPPORTED_HEADS are returned untouched.
+
+ Raises:
+ AssertionError: if the head count is larger than every supported value.
+ """
+ h_q = q_4d.shape[2]
+ if h_q in FLASHMLA_SUPPORTED_HEADS:
+ return q_4d, attn_sink, h_q
+ # first entry (in tuple order) that can hold h_q; None when h_q is above all of them
+ target = next((h for h in FLASHMLA_SUPPORTED_HEADS if h >= h_q), None)
+ assert target is not None, f"num q heads {h_q} exceeds flash_mla support {FLASHMLA_SUPPORTED_HEADS}"
+ # F.pad pairs run from the last dim backwards: (0, 0) leaves the last dim alone and
+ # (0, target - h_q) appends zeros at the end of the second-to-last dim. That is the head
+ # axis only if q_4d is 4-D -- assumed from the name and the unsqueeze(1) call sites.
+ q_pad = torch.nn.functional.pad(q_4d, (0, 0, 0, target - h_q))
+ # same number of trailing zeros on attn_sink's last dim (presumably one value per head)
+ sink_pad = torch.nn.functional.pad(attn_sink, (0, target - h_q))
+ return q_pad, sink_pad, h_q
+
+
+def _view_dsv4_flashmla_cache(layer_buffer: torch.Tensor, page_size: int) -> torch.Tensor:
+ """Reinterpret a cache buffer as ``(rows, page_size, 1, DSV4_MLA_BYTES_PER_TOKEN)``.
+
+ Only the first ``page_size * DSV4_MLA_BYTES_PER_TOKEN`` columns of every row are kept;
+ columns past that are left out of the returned tensor. The result is produced by slicing
+ and ``.view`` only, so it shares storage with ``layer_buffer``.
+ NOTE(review): presumably ``layer_buffer`` is 2-D with one page per row and a 1-byte
+ dtype -- confirm against the mem manager.
+ """
+ # imported inside the function rather than at module top
+ from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_MLA_BYTES_PER_TOKEN
+
+ # number of leading columns per row that make up page_size packed tokens
+ usable = page_size * DSV4_MLA_BYTES_PER_TOKEN
+ return layer_buffer[:, :usable].view(layer_buffer.shape[0], page_size, 1, DSV4_MLA_BYTES_PER_TOKEN)
+
+
+@dataclasses.dataclass
+class _Dsv4Metadata:
+ """Tensors that the _flashmla_kvcache_att methods forward to flash_mla_with_kvcache."""
+
+ # forwarded as the kernel's `indices` argument
+ swa_indices: torch.Tensor
+ # forwarded as `topk_length`
+ swa_lengths: torch.Tensor
+ # forwarded as `extra_k_cache`; _metadata_from_dict leaves it None when compress_ratio is falsy
+ extra_cache: torch.Tensor = None
+ # forwarded as `extra_indices_in_kvcache`; None when the dict has no "extra_indices"
+ extra_indices: torch.Tensor = None
+ # forwarded as `extra_topk_length`; None when the dict has no "extra_lengths"
+ extra_lengths: torch.Tensor = None
+
+
+def _metadata_from_dict(infer_state, nsa_dict: dict) -> "_Dsv4Metadata":
+ """Bundle the model-built FINAL index tensors (carried in nsa_dict by DeepseekV4IndexInfer) with
+ the layer-keyed fp8 extra-cache byte view. The cache view is data-independent (a fixed per-layer
+ buffer slice), so it is built here -- a genuine flash_mla ABI concern -- rather than on the model
+ side; only the index/length tensors cross the att_control boundary.
+
+ Args:
+ infer_state: object whose ``mem_manager.get_compressed_kv_buffer(layer_index)`` supplies
+ the extra-cache buffer; only consulted when ``compress_ratio`` is truthy.
+ nsa_dict: must contain "compress_ratio", "swa_indices" and "swa_lengths", plus
+ "layer_index" when ``compress_ratio`` is truthy; "extra_indices" and "extra_lengths"
+ are optional and default to None.
+
+ Returns:
+ A _Dsv4Metadata whose ``extra_cache`` is None when ``compress_ratio`` is falsy.
+ """
+ from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_C128_PAGE_SIZE, DSV4_C4_PAGE_SIZE
+
+ ratio = nsa_dict["compress_ratio"]
+ extra_cache = None
+ if ratio:
+ # ratio 4 uses the c4 page size; every other truthy ratio falls through to the c128 one
+ page = DSV4_C4_PAGE_SIZE if ratio == 4 else DSV4_C128_PAGE_SIZE
+ extra_buffer = infer_state.mem_manager.get_compressed_kv_buffer(nsa_dict["layer_index"])
+ extra_cache = _view_dsv4_flashmla_cache(extra_buffer, page)
+ return _Dsv4Metadata(
+ swa_indices=nsa_dict["swa_indices"],
+ swa_lengths=nsa_dict["swa_lengths"],
+ extra_cache=extra_cache,
+ extra_indices=nsa_dict.get("extra_indices"),
+ extra_lengths=nsa_dict.get("extra_lengths"),
+ )
+
+
class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend):
def __init__(self, model):
super().__init__(model=model)
@@ -62,6 +116,12 @@ def prefill_att(
) -> torch.Tensor:
assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention"
assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required"
+ if att_control.nsa_prefill_dict.get("flashmla_kvcache"):
+ return self._flashmla_kvcache_prefill_att(
+ q=q,
+ packed_kv=k,
+ nsa_dict=att_control.nsa_prefill_dict,
+ )
return self._nsa_prefill_att(q=q, packed_kv=k, att_control=att_control)
def _nsa_prefill_att(
@@ -78,6 +138,8 @@ def _nsa_prefill_att(
kv_lora_rank = nsa_dict["kv_lora_rank"]
topk_mem_indices = nsa_dict["topk_mem_indices"]
prefill_cache_kv = nsa_dict["prefill_cache_kv"]
+ attn_sink = nsa_dict.get("attn_sink")
+ topk_length = nsa_dict.get("topk_length")
if self.infer_state.prefix_total_token_num > 0:
# 当前推理生成的token kv部分从 prefill_cache_kv 中获取,历史
@@ -101,9 +163,51 @@ def _nsa_prefill_att(
indices=topk_indices,
sm_scale=softmax_scale,
d_v=kv_lora_rank,
+ attn_sink=attn_sink,
+ topk_length=topk_length,
)
return mla_out
+ def _flashmla_kvcache_prefill_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor:
+ """Prefill entry for the flashmla kv-cache path: prepare attn_sink and metadata, then delegate.
+
+ ``nsa_dict`` must carry "attn_sink" in addition to the keys read by _metadata_from_dict
+ and _flashmla_kvcache_att.
+ """
+ # the sink is converted to a contiguous float32 tensor before it reaches the kernel
+ attn_sink = nsa_dict["attn_sink"].to(torch.float32).contiguous()
+ metadata = _metadata_from_dict(self.infer_state, nsa_dict)
+ return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict)
+
+ def _flashmla_kvcache_att(
+ self,
+ q: torch.Tensor,
+ packed_kv: torch.Tensor,
+ metadata: _Dsv4Metadata,
+ attn_sink: torch.Tensor,
+ nsa_dict: dict,
+ ) -> torch.Tensor:
+ """Run flash_mla.flash_mla_with_kvcache for prefill and return the un-padded output.
+
+ Reads "head_dim_v" and "softmax_scale" from ``nsa_dict``; the index/length tensors and
+ the optional extra cache come from ``metadata``. The scheduler metadata is created anew
+ on every call.
+ NOTE(review): assumes ``q`` is (tokens, heads, dim) so that unsqueeze(1) puts heads on
+ dim 2 -- confirm against the caller.
+ """
+ import flash_mla
+ from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE
+
+ # insert a singleton dim at position 1, then pad the head axis to a supported count
+ q_4d = q.unsqueeze(1).contiguous()
+ q_4d, attn_sink, num_real_heads = _pad_q_heads(q_4d, attn_sink)
+ k_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE)
+ # only the first element of the returned pair is used
+ sched_meta, _ = flash_mla.get_mla_metadata()
+ out, _ = flash_mla.flash_mla_with_kvcache(
+ q=q_4d,
+ k_cache=k_cache,
+ block_table=None,
+ cache_seqlens=None,
+ head_dim_v=nsa_dict["head_dim_v"],
+ tile_scheduler_metadata=sched_meta,
+ num_splits=None,
+ softmax_scale=nsa_dict["softmax_scale"],
+ causal=False,
+ is_fp8_kvcache=True,
+ indices=metadata.swa_indices,
+ attn_sink=attn_sink,
+ topk_length=metadata.swa_lengths,
+ extra_k_cache=metadata.extra_cache,
+ extra_indices_in_kvcache=metadata.extra_indices,
+ extra_topk_length=metadata.extra_lengths,
+ )
+ # drop the singleton dim added above and the heads that _pad_q_heads appended
+ return out[:, 0, :num_real_heads].contiguous()
+
@dataclasses.dataclass
class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState):
@@ -143,7 +247,18 @@ def init_state(self):
)
import flash_mla
- self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata()
+ # one sched_meta per layer type: the lazy config locks extra-cache geometry (page size,
+ # presence) on first invocation, so swa-only/c4/c128 layers must not share one object.
+ self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)}
+ return
+
+ def reset_sched_meta_for_capture(self):
+ """Replace ``self.flashmla_sched_meta`` with new metadata objects, one per key in (0, 4, 128).
+
+ Builds the same dict shape that init_state assigns, so lookups by compress ratio keep
+ working after the reset.
+ """
+ # cuda-graph capture hook: the warmup pass already locked/stored sched meta on this
+ # (shared) state object; reset so the capture pass re-plans INSIDE the graph and every
+ # replay re-plans from the live tensors instead of binding warmup leftovers.
+ import flash_mla
+
+ # get_mla_metadata() is called once per key; only element [0] of each result is kept
+ self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)}
+ return
def decode_att(
@@ -156,6 +271,12 @@ def decode_att(
) -> torch.Tensor:
assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention"
assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required"
+ if att_control.nsa_decode_dict.get("flashmla_kvcache"):
+ return self._flashmla_kvcache_decode_att(
+ q=q,
+ packed_kv=k,
+ nsa_dict=att_control.nsa_decode_dict,
+ )
return self._nsa_decode_att(q=q, packed_kv=k, att_control=att_control)
def _nsa_decode_att(
@@ -170,6 +291,11 @@ def _nsa_decode_att(
topk_mem_indices = nsa_dict["topk_mem_indices"]
softmax_scale = nsa_dict["softmax_scale"]
kv_lora_rank = nsa_dict["kv_lora_rank"]
+ attn_sink = nsa_dict.get("attn_sink")
+ topk_length = nsa_dict.get("topk_length")
+ extra_k_cache = nsa_dict.get("extra_k_cache")
+ extra_indices = nsa_dict.get("extra_indices_in_kvcache")
+ extra_topk_length = nsa_dict.get("extra_topk_length")
if topk_mem_indices.ndim == 2:
topk_mem_indices = topk_mem_indices.unsqueeze(1)
@@ -189,10 +315,54 @@ def _nsa_decode_att(
block_table=None,
cache_seqlens=None,
head_dim_v=kv_lora_rank,
- tile_scheduler_metadata=self.flashmla_sched_meta,
+ tile_scheduler_metadata=self.flashmla_sched_meta[0],
softmax_scale=softmax_scale,
causal=False,
is_fp8_kvcache=True,
indices=topk_mem_indices,
+ attn_sink=attn_sink,
+ topk_length=topk_length,
+ extra_k_cache=extra_k_cache,
+ extra_indices_in_kvcache=extra_indices,
+ extra_topk_length=extra_topk_length,
)
return o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d]
+
+ def _flashmla_kvcache_decode_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor:
+ """Decode entry for the flashmla kv-cache path: prepare attn_sink and metadata, then delegate.
+
+ ``nsa_dict`` must carry "attn_sink" in addition to the keys read by _metadata_from_dict
+ and _flashmla_kvcache_att.
+ """
+ # the sink is converted to a contiguous float32 tensor before it reaches the kernel
+ attn_sink = nsa_dict["attn_sink"].to(torch.float32).contiguous()
+ metadata = _metadata_from_dict(self.infer_state, nsa_dict)
+ return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict)
+
+ def _flashmla_kvcache_att(
+ self,
+ q: torch.Tensor,
+ packed_kv: torch.Tensor,
+ metadata: _Dsv4Metadata,
+ attn_sink: torch.Tensor,
+ nsa_dict: dict,
+ ) -> torch.Tensor:
+ """Run flash_mla.flash_mla_with_kvcache for decode and return the un-padded output.
+
+ Reads "head_dim_v", "softmax_scale" and "compress_ratio" from ``nsa_dict``. Unlike the
+ prefill variant, the scheduler metadata is not created here: it is looked up in
+ ``self.flashmla_sched_meta`` by compress ratio, so the ratio must be one of that dict's
+ keys (init_state builds it with 0, 4 and 128) or the lookup raises KeyError.
+ NOTE(review): assumes ``q`` is (batch, heads, dim) so that unsqueeze(1) puts heads on
+ dim 2 -- confirm against the caller.
+ """
+ import flash_mla
+ from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE
+
+ # insert a singleton dim at position 1, then pad the head axis to a supported count
+ q_4d = q.unsqueeze(1).contiguous()
+ q_4d, attn_sink, num_real_heads = _pad_q_heads(q_4d, attn_sink)
+ k_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE)
+ out, _ = flash_mla.flash_mla_with_kvcache(
+ q=q_4d,
+ k_cache=k_cache,
+ block_table=None,
+ cache_seqlens=None,
+ head_dim_v=nsa_dict["head_dim_v"],
+ tile_scheduler_metadata=self.flashmla_sched_meta[nsa_dict["compress_ratio"]],
+ num_splits=None,
+ softmax_scale=nsa_dict["softmax_scale"],
+ causal=False,
+ is_fp8_kvcache=True,
+ indices=metadata.swa_indices,
+ attn_sink=attn_sink,
+ topk_length=metadata.swa_lengths,
+ extra_k_cache=metadata.extra_cache,
+ extra_indices_in_kvcache=metadata.extra_indices,
+ extra_topk_length=metadata.extra_lengths,
+ )
+ # drop the singleton dim added above and the heads that _pad_q_heads appended
+ return out[:, 0, :num_real_heads].contiguous()
diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 94f9d4c1a2..5825d9b45f 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -291,6 +291,11 @@ def _init_custom(self):
@torch.no_grad()
def forward(self, model_input: ModelInput):
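+ # decode slot prep: runs before to_cuda (b_req_idx/b_seq_len/mem_indexes_cpu are still native CPU tensors),
+ # and we are already on forward's CUDA stream here -> same stream as the later attention, so no cross-stream race and no D2H copy.
+ # Skipped when mem_indexes_cpu is None: cudagraph warmup inputs are all on CUDA and b_req_idx is all HOLD, so prep would be a no-op anyway.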
+ if not model_input.is_prefill and model_input.mem_indexes_cpu is not None:
+ self.req_manager.prepare_decode(model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes_cpu)
model_input.to_cuda()
assert model_input.mem_indexes.is_cuda
@@ -521,6 +526,18 @@ def _prefill(
alloc_mem_index=infer_state.mem_index,
max_q_seq_len=infer_state.max_q_seq_len,
)
+ if hasattr(self.req_manager, "prepare_prefill_swa"):
+ self.req_manager.prepare_prefill_swa(
+ b_req_idx=infer_state.b_req_idx,
+ b_ready_cache_len=infer_state.b_ready_cache_len,
+ b_seq_len=infer_state.b_seq_len,
+ )
+ if hasattr(self.req_manager, "prepare_prefill_compress_slots"):
+ self.req_manager.prepare_prefill_compress_slots(
+ b_req_idx=infer_state.b_req_idx,
+ b_ready_cache_len=infer_state.b_ready_cache_len,
+ b_seq_len=infer_state.b_seq_len,
+ )
prefill_mem_indexes_ready_event = torch.cuda.Event()
prefill_mem_indexes_ready_event.record()
@@ -741,6 +758,18 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
alloc_mem_index=infer_state0.mem_index,
max_q_seq_len=infer_state0.max_q_seq_len,
)
+ if hasattr(self.req_manager, "prepare_prefill_swa"):
+ self.req_manager.prepare_prefill_swa(
+ b_req_idx=infer_state0.b_req_idx,
+ b_ready_cache_len=infer_state0.b_ready_cache_len,
+ b_seq_len=infer_state0.b_seq_len,
+ )
+ if hasattr(self.req_manager, "prepare_prefill_compress_slots"):
+ self.req_manager.prepare_prefill_compress_slots(
+ b_req_idx=infer_state0.b_req_idx,
+ b_ready_cache_len=infer_state0.b_ready_cache_len,
+ b_seq_len=infer_state0.b_seq_len,
+ )
infer_state0.init_some_extra_state(self)
infer_state0.init_att_state()
@@ -754,6 +783,18 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
alloc_mem_index=infer_state1.mem_index,
max_q_seq_len=infer_state1.max_q_seq_len,
)
+ if hasattr(self.req_manager, "prepare_prefill_swa"):
+ self.req_manager.prepare_prefill_swa(
+ b_req_idx=infer_state1.b_req_idx,
+ b_ready_cache_len=infer_state1.b_ready_cache_len,
+ b_seq_len=infer_state1.b_seq_len,
+ )
+ if hasattr(self.req_manager, "prepare_prefill_compress_slots"):
+ self.req_manager.prepare_prefill_compress_slots(
+ b_req_idx=infer_state1.b_req_idx,
+ b_ready_cache_len=infer_state1.b_ready_cache_len,
+ b_seq_len=infer_state1.b_seq_len,
+ )
infer_state1.init_some_extra_state(self)
infer_state1.init_att_state()
@@ -781,6 +822,10 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
@torch.no_grad()
def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
+ # decode slot prep: runs before to_cuda (native CPU tensors) and already on forward's CUDA stream (see the comment in forward).
+ for mi in (model_input0, model_input1):
+ if mi.mem_indexes_cpu is not None:
+ self.req_manager.prepare_decode(mi.b_req_idx, mi.b_seq_len, mi.mem_indexes_cpu)
model_input0.to_cuda()
model_input1.to_cuda()
assert self.args.enable_tpsp_mix_mode
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index 782150661e..e1d96b744e 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -14,6 +14,20 @@
logger = init_logger(__name__)
+def _reset_att_state_sched_meta(infer_state: InferStateInfo):
+ """Call ``reset_sched_meta_for_capture()`` on each decode att state that provides it."""
+ # Called before capture: the warmup pass shares decode_att_state through a copy.copy shallow
+ # copy, so the lazily initialised scheduling object inside it (e.g. FlashMLASchedMeta, which
+ # plans from the data seen at the first kernel call and writes the plan back) gets locked by
+ # the warmup's dummy workload. Without a reset the capture pass binds scheduling tensors that
+ # were planned for the dummy data, and every replay runs the wrong tile schedule (measured on
+ # DSV4: gsm8k 0.96 -> 0.74). After the reset, planning happens inside the captured region and
+ # is recomputed on every replay.
+ for att_state in (infer_state.decode_att_state, infer_state.decode_att_state1):
+ if att_state is None:
+ continue
+ # states without the hook are left untouched
+ reset_fn = getattr(att_state, "reset_sched_meta_for_capture", None)
+ if reset_fn is not None:
+ reset_fn()
+ return
+
+
class CudaGraph:
# CudaGraph forward pass for the decoding stage.
@@ -94,6 +108,8 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo):
if param_name not in pure_para_set:
delattr(infer_state, param_name)
+ _reset_att_state_sched_meta(infer_state)
+
with torch.cuda.graph(graph_obj, pool=self.mempool):
model_output = decode_func(infer_state)
self.graph[batch_size] = (graph_obj, infer_state, model_output)
@@ -128,6 +144,9 @@ def _capture_decode_overlap(
if para_name not in pure_para_set1:
delattr(infer_state1, para_name)
+ _reset_att_state_sched_meta(infer_state)
+ _reset_att_state_sched_meta(infer_state1)
+
with torch.cuda.graph(graph_obj, pool=self.mempool):
model_output, model_output1 = decode_func(infer_state, infer_state1)
self.graph[batch_size] = (
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
index fca9b80fcf..24842ed383 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
@@ -68,12 +68,13 @@ def __init__(
auto_update_redundancy_expert=self.auto_update_redundancy_expert,
)
self.lock = threading.Lock()
+ self._moe_weight_finalized = False
self._create_weight()
def _init_config(self, network_config: Dict[str, Any]):
self.n_group = network_config.get("n_group", 0)
self.use_grouped_topk = self.n_group > 0
- self.norm_topk_prob = network_config["norm_topk_prob"]
+ self.norm_topk_prob = network_config.get("norm_topk_prob", False)
self.topk_group = network_config.get("topk_group", 0)
self.num_experts_per_tok = network_config["num_experts_per_tok"]
self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0)
@@ -136,6 +137,7 @@ def experts(
is_prefill: Optional[bool] = None,
) -> torch.Tensor:
"""Backward compatible method that routes to platform-specific implementation."""
+ self._finalize_moe_weight()
return self.fuse_moe_impl(
input_tensor=input_tensor,
router_logits=router_logits,
@@ -152,6 +154,25 @@ def experts(
per_expert_scale=self.per_expert_scale,
)
+ def experts_with_preselected(
+ self,
+ input_tensor: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_prefill: Optional[bool] = None,
+ clamp_limit: Optional[float] = None,
+ ) -> torch.Tensor:
+ """Run the fused experts with routing chosen by the caller.
+
+ No router logits are taken: ``topk_weights`` and ``topk_ids`` are passed straight to
+ ``self.fuse_moe_impl.fused_experts_with_topk`` together with this layer's ``w13`` and
+ ``w2``. ``is_prefill`` and ``clamp_limit`` are forwarded unchanged.
+ """
+ # same finalize call that experts() makes before dispatching
+ self._finalize_moe_weight()
+ return self.fuse_moe_impl.fused_experts_with_topk(
+ input_tensor=input_tensor,
+ w13=self.w13,
+ w2=self.w2,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ is_prefill=is_prefill,
+ clamp_limit=clamp_limit,
+ )
+
def low_latency_dispatch(
self,
hidden_states: torch.Tensor,
@@ -280,7 +301,18 @@ def verify_load(self):
e_score_correction_bias_load_ok = (
True if self.e_score_correction_bias is None else getattr(self.e_score_correction_bias, "load_ok", False)
)
- return weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok
+ load_ok = weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok
+ if load_ok:
+ self._finalize_moe_weight()
+ return load_ok
+
+ def _finalize_moe_weight(self):
+ """Invoke the quant method's optional ``finalize_moe_weight(self)`` hook, guarded by a flag.
+
+ Returns immediately once ``self._moe_weight_finalized`` is set. Quant methods that do
+ not define the hook are skipped.
+ """
+ if self._moe_weight_finalized:
+ return
+ finalize = getattr(self.quant_method, "finalize_moe_weight", None)
+ if finalize is not None:
+ finalize(self)
+ # NOTE(review): presumably set whether or not the hook exists; the collapsed indentation
+ # of this patch does not show which scope the assignment belongs to -- confirm.
+ self._moe_weight_finalized = True
def _create_weight(self):
intermediate_size = self.split_inter_size
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py
index 67bb90e4ef..282c0abdce 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py
@@ -2,9 +2,15 @@
from .triton_impl import FuseMoeTriton
from .marlin_impl import FuseMoeMarlin
from .deepgemm_impl import FuseMoeDeepGEMM
+from .mxfp4_impl import FuseMoeMXFP4
def select_fuse_moe_impl(quant_method: QuantizationMethod, enable_ep_moe: bool):
+ if quant_method.method_name == "marlin-mxfp4w4a16-b32":
+ if enable_ep_moe:
+ raise RuntimeError("marlin-mxfp4w4a16-b32 does not support enable_ep_moe yet")
+ return FuseMoeMXFP4
+
if enable_ep_moe:
return FuseMoeDeepGEMM
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py
index 4d4614c007..72acf2430a 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py
@@ -76,7 +76,9 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = None,
+ clamp_limit: Optional[float] = None,
):
+ assert clamp_limit is None, "EP deepgemm fused MoE does not support clamp_limit yet"
output = fused_experts(
hidden_states=input_tensor,
w13=w13,
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py
index 0094b09b1c..1fdfd94d0d 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py
@@ -30,7 +30,9 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = None,
+ clamp_limit: Optional[float] = None,
):
+ assert clamp_limit is None, "awq_marlin fused MoE does not support clamp_limit yet"
w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point
w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py
new file mode 100644
index 0000000000..97cf238116
--- /dev/null
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py
@@ -0,0 +1,44 @@
+import torch
+from typing import Optional
+
+from lightllm.common.quantization.quantize_method import WeightPack
+from .triton_impl import FuseMoeTriton
+
+
+class FuseMoeMXFP4(FuseMoeTriton):
+ """FuseMoeTriton subclass whose expert computation goes through vLLM's fused_marlin_moe.
+
+ Returned by select_fuse_moe_impl for the "marlin-mxfp4w4a16-b32" quant method.
+ """
+
+ def create_workspace(self):
+ """Return None; this implementation creates no workspace."""
+ return None
+
+ def _fused_experts(
+ self,
+ input_tensor: torch.Tensor,
+ w13: WeightPack,
+ w2: WeightPack,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ router_logits: Optional[torch.Tensor] = None,
+ is_prefill: Optional[bool] = None,
+ clamp_limit: Optional[float] = None,
+ ):
+ """Call fused_marlin_moe with this layer's packed weights and return its result.
+
+ ``router_logits`` and ``is_prefill`` are accepted for signature compatibility but are
+ not read here. ``clamp_limit`` is forwarded to the kernel as-is.
+
+ Raises:
+ RuntimeError: if any of the vLLM imports below fails (original error is chained).
+ """
+ # vLLM is imported on each call so the dependency is only needed when this path runs
+ try:
+ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
+ from vllm.model_executor.layers.fused_moe.experts.marlin_moe import fused_marlin_moe
+ from vllm.scalar_type import scalar_types
+ except Exception as e:
+ raise RuntimeError(f"MXFP4 fused MoE requires vLLM fused kernels, error={repr(e)}") from e
+
+ # inputs are made contiguous; topk weights are cast to float32 and topk ids to int64
+ return fused_marlin_moe(
+ hidden_states=input_tensor.contiguous(),
+ w1=w13.weight,
+ w2=w2.weight,
+ bias1=None,
+ bias2=None,
+ w1_scale=w13.weight_scale,
+ w2_scale=w2.weight_scale,
+ topk_weights=topk_weights.to(torch.float32).contiguous(),
+ topk_ids=topk_ids.to(torch.long).contiguous(),
+ quant_type_id=scalar_types.float4_e2m1f.id,
+ global_num_experts=self.n_routed_experts,
+ activation=MoEActivation.SILU,
+ clamp_limit=clamp_limit,
+ )
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
index a0d30547a3..8967dda34e 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
@@ -94,6 +94,7 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: bool = False,
+ clamp_limit: Optional[float] = None,
):
w13_weight, w13_scale = w13.weight, w13.weight_scale
w2_weight, w2_scale = w2.weight, w2.weight_scale
@@ -111,9 +112,30 @@ def _fused_experts(
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w13_scale,
w2_scale=w2_scale,
+ limit=clamp_limit,
)
return input_tensor
+ def fused_experts_with_topk(
+ self,
+ input_tensor: torch.Tensor,
+ w13: WeightPack,
+ w2: WeightPack,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_prefill: Optional[bool] = None,
+ clamp_limit: Optional[float] = None,
+ ):
+ """Run ``self._fused_experts`` with top-k weights and ids supplied by the caller.
+
+ ``router_logits`` is not passed, so ``_fused_experts`` sees its default (None). All
+ other arguments are forwarded by keyword unchanged, and the return value is whatever
+ ``_fused_experts`` returns.
+ """
+ return self._fused_experts(
+ input_tensor=input_tensor,
+ w13=w13,
+ w2=w2,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ is_prefill=is_prefill,
+ clamp_limit=clamp_limit,
+ )
+
def __call__(
self,
input_tensor: torch.Tensor,
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py
index cb2e370cb9..28fe6e4304 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py
@@ -49,7 +49,7 @@ def check_ep_expert_dtype(quant_method: Any):
"EP MoE requires --expert_dtype to be one of ['fp8', 'fp4'], "
f"but the resolved fused_moe quant method is `{expert_dtype}`. "
"Please start with --expert_dtype fp8 or --expert_dtype fp4. "
- "Note that --expert_dtype fp4 is only supported on SM100 GPUs."
+ "Note that --expert_dtype fp4 with EP MoE is only supported on SM100 GPUs."
)
if expert_dtype == "deepgemm-fp4fp8-b32" and not is_sm100_gpu():
raise RuntimeError(
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py
index 45c7ea73c6..82fc9131c1 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py
@@ -24,6 +24,7 @@ def _silu_and_mul_kernel_fast(
NEED_MASK: tl.constexpr,
layout: tl.constexpr = "blocked", # "blocked" or "interleaved"
USE_LIMIT_AND_ALPHA: tl.constexpr = False,
+ USE_LIMIT_ONLY: tl.constexpr = False,
USE_TANH_APPROXIMATE_GELU: tl.constexpr = False,
):
stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
@@ -76,6 +77,11 @@ def _silu_and_mul_kernel_fast(
mask=mask,
)
else:
+ if USE_LIMIT_ONLY:
+ # clamped swiglu (DeepSeek-V4 swiglu_limit): clamp first, then apply the standard silu,
+ # without gpt-oss's alpha scaling or its (up+1) term.
+ gate = tl.minimum(gate, limit)
+ up = tl.minimum(tl.maximum(up, -limit), limit)
if USE_TANH_APPROXIMATE_GELU:
# tanh-approx GELU, matching Gemma's gelu_pytorch_tanh MLP.
gate_cubed = gate * gate * gate
@@ -124,7 +130,8 @@ def silu_and_mul_fwd(
):
assert input.is_contiguous()
assert output.is_contiguous()
- assert (limit is None and alpha is None) or (limit is not None and alpha is not None)
+ # limit+alpha: gpt-oss semantics (up+1)*silu(alpha*gate); limit only: clamp, then standard silu (DeepSeek-V4)
+ assert alpha is None or limit is not None
stride_input_m = input.stride(0)
stride_input_n = input.stride(1)
@@ -147,6 +154,7 @@ def silu_and_mul_fwd(
while triton.cdiv(size_m, BLOCK_M) > 8192:
BLOCK_M *= 2
USE_LIMIT_AND_ALPHA = limit is not None and alpha is not None
+ USE_LIMIT_ONLY = limit is not None and alpha is None
grid = (
triton.cdiv(size_n, BLOCK_N),
@@ -171,6 +179,7 @@ def silu_and_mul_fwd(
num_warps=num_warps,
layout=layout,
USE_LIMIT_AND_ALPHA=USE_LIMIT_AND_ALPHA,
+ USE_LIMIT_ONLY=USE_LIMIT_ONLY,
USE_TANH_APPROXIMATE_GELU=ffn_use_tanh_approximate_gelu(),
)
return
diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py
index 05544e149a..95f7e8ab76 100644
--- a/lightllm/common/kv_cache_mem_manager/__init__.py
+++ b/lightllm/common/kv_cache_mem_manager/__init__.py
@@ -4,6 +4,7 @@
from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
from .deepseek2_mem_manager import Deepseek2MemoryManager
from .deepseek3_2mem_manager import Deepseek3_2MemoryManager
+from .deepseek4_mem_manager import DeepseekV4MemoryManager
from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager
from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager
from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager
@@ -17,6 +18,7 @@
"PPLINT8KVMemoryManager",
"Deepseek2MemoryManager",
"Deepseek3_2MemoryManager",
+ "DeepseekV4MemoryManager",
"FP8PerTokenGroupQuantDeepseek3_2MemoryManager",
"FP8StaticPerHeadQuantMemManager",
"FP8StaticPerTensorQuantMemManager",
diff --git a/lightllm/common/kv_cache_mem_manager/allocator.py b/lightllm/common/kv_cache_mem_manager/allocator.py
index 850c158778..0179ed2714 100644
--- a/lightllm/common/kv_cache_mem_manager/allocator.py
+++ b/lightllm/common/kv_cache_mem_manager/allocator.py
@@ -3,13 +3,13 @@
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.log_utils import init_logger
-from typing import Union, List
+from typing import Union, List, Optional
logger = init_logger(__name__)
class KvCacheAllocator:
- def __init__(self, size: int) -> None:
+ def __init__(self, size: int, shared_name: Optional[str] = None) -> None:
self.size = size
self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
@@ -26,9 +26,11 @@ def __init__(self, size: int) -> None:
rank_in_node = get_current_rank_in_node()
# 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。
- self.shared_can_use_token_num = SharedInt(
- f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
- )
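+ # When shared_name is None, use the main kv pool's default name (the router's scheduling estimate reads it); DeepSeek-V4's compressed sub-pools etc.
+ # need their own independent counters, so pass a unique name distinct from the main pool's to keep several allocators from writing the same shared counter.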
+ if shared_name is None:
+ shared_name = f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
+ self.shared_can_use_token_num = SharedInt(shared_name)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
return
diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py
new file mode 100644
index 0000000000..cfb149dcec
--- /dev/null
+++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py
@@ -0,0 +1,872 @@
+import torch
+import torch.distributed as dist
+from typing import List, Optional, Union
+from .mem_manager import MemoryManager
+from .operator import DeepseekV4MemOperator
+from .allocator import KvCacheAllocator
+from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node
+from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
+
+logger = init_logger(__name__)
+
+
+# fp8_ds_mla packed-latent byte layout (ABI shared with the flash_mla extra-cache fork and
+# sglang/vllm): 448B NoPE fp8 + 64*2B RoPE bf16 + 7B ue8m0 scale + 1B pad = 584B per token,
+# stored in page slabs whose tail carries the per-token scale bytes.
+DSV4_MLA_NOPE_DIM = 448
+DSV4_MLA_ROPE_DIM = 64
+DSV4_MLA_HEAD_DIM = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM  # 448 + 64 = 512 elements per latent
+DSV4_MLA_QUANT_GROUP_SIZE = 64
+DSV4_MLA_SCALE_BYTES = DSV4_MLA_NOPE_DIM // DSV4_MLA_QUANT_GROUP_SIZE + 1  # 7 group scale bytes + 1 pad byte = 8
+DSV4_MLA_BYTES_PER_TOKEN = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM * 2 + DSV4_MLA_SCALE_BYTES  # 448 + 128 + 8 = 584
+DSV4_MLA_DATA_BYTES_PER_TOKEN = DSV4_MLA_NOPE_DIM + DSV4_MLA_ROPE_DIM * 2  # 576: NoPE + RoPE bytes, scale excluded
+DSV4_MLA_PAGE_ALIGN_BYTES = DSV4_MLA_DATA_BYTES_PER_TOKEN  # passed as align_bytes when the latent pools are built
+DSV4_INDEXER_HEAD_DIM = 128
+DSV4_INDEXER_SCALE_BYTES = 4
+DSV4_INDEXER_BYTES_PER_TOKEN = DSV4_INDEXER_HEAD_DIM + DSV4_INDEXER_SCALE_BYTES  # 128 + 4 = 132
+DSV4_FP8_E4M3_MAX = 448.0
+DSV4_FP8_SCALE_MIN = 1e-4
+DSV4_SWA_PAGE_SIZE = 128
+DSV4_C4_PAGE_SIZE = 64
+DSV4_C128_PAGE_SIZE = 2
+DSV4_PROMPT_CACHE_PAGE_SIZE = DSV4_C4_PAGE_SIZE * 4  # 64 * 4 = 256
+# compressor state ring: the c4 overlap pair keeps 2 group slots x ratio 4 rows per page; c128 aggregates offline, 1 group per page.
+DSV4_C4_STATE_RING = 8
+DSV4_C128_STATE_RING = 128
+# Lower bound on the swa pool size as a fraction of the full token space (same value as sglang swa_full_tokens_ratio=0.1).
+# The v5 swa pressure valve (page borrowing / eviction) already covers transient growth from the radix tree and admission
+# waves; the structural budget (max_req x window + batch_max_tokens headroom) is added separately, so 0.1 is only the floor relative to the full pool.
+DSV4_SWA_FULL_TOKENS_RATIO = 0.1
+
+
+def _ceil_div(a: int, b: int) -> int:  # integer division of a by b, rounded up
+ return (a + b - 1) // b
+
+
+class PackedPagePool:
+ """fp8_ds_mla-style page-slab storage: each page holds the tokens' data bytes contiguously at the front and the per-token scale bytes at the tail.
+
+ Addressing is by plain token slot (page = slot // page_size); a page is only a physical packing device for the scale tail and alignment,
+ and there is no page-granular allocation here. ``write``/``read`` are torch reference implementations (unit-test oracle); production writes go through
+ the triton packed writers (destindex_copy_kv_flashmla_dsv4 etc.), and kernels consume ``buffer`` directly.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ page_size: int,
+ layer_num: int,
+ data_bytes: int,
+ scale_bytes: int,
+ align_bytes: int = 1,
+ device: str = "cuda",
+ ):
+ self.size = size
+ self.page_size = page_size
+ self.layer_num = layer_num
+ self.data_bytes_per_token = data_bytes
+ self.scale_bytes_per_token = scale_bytes
+ self.bytes_per_token = data_bytes + scale_bytes
+ self.num_pages = _ceil_div(size + 1, page_size)  # size + 1 slots: slot index `size` is HOLD_TOKEN_MEMINDEX
+ self.bytes_per_page = _ceil_div(page_size * self.bytes_per_token, align_bytes) * align_bytes  # rounded up to a multiple of align_bytes
+ self.scale_offset_in_page = page_size * data_bytes  # scale bytes start after all data bytes of the page
+ self.buffer = torch.zeros((layer_num, self.num_pages, self.bytes_per_page), dtype=torch.uint8, device=device)
+ self.HOLD_TOKEN_MEMINDEX = size
+
+ def get_layer_buffer(self, layer_index: int) -> torch.Tensor:  # [num_pages, bytes_per_page] uint8 slice of one layer
+ return self.buffer[layer_index]
+
+ def _loc_offsets(self, loc: torch.Tensor):  # token slots -> (data byte offset, scale byte offset) within a flattened layer
+ loc = loc.long()
+ page = torch.div(loc, self.page_size, rounding_mode="floor")
+ token = loc % self.page_size
+ page_base = page * self.bytes_per_page
+ data_offsets = page_base + token * self.data_bytes_per_token
+ scale_offsets = page_base + self.scale_offset_in_page + token * self.scale_bytes_per_token
+ return data_offsets, scale_offsets
+
+ def write(self, layer_index: int, loc: torch.Tensor, packed: torch.Tensor) -> None:  # scatter packed rows (data bytes then scale bytes) into slots `loc`
+ if loc.numel() == 0:
+ return
+ loc = loc.reshape(-1)
+ packed = packed.reshape(-1, self.bytes_per_token).contiguous()
+ flat = self.buffer[layer_index].view(-1)
+ data_offsets, scale_offsets = self._loc_offsets(loc)
+ data_range = torch.arange(self.data_bytes_per_token, device=loc.device)
+ scale_range = torch.arange(self.scale_bytes_per_token, device=loc.device)
+ flat[data_offsets.unsqueeze(1) + data_range.unsqueeze(0)] = packed[:, : self.data_bytes_per_token]
+ flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)] = packed[:, self.data_bytes_per_token :]
+ return
+
+ def read(self, layer_index: int, loc: torch.Tensor) -> torch.Tensor:  # gather slots `loc` back into [n, bytes_per_token] rows (data then scale)
+ loc = loc.reshape(-1)
+ if loc.numel() == 0:
+ return torch.empty((0, self.bytes_per_token), dtype=torch.uint8, device=self.buffer.device)
+ flat = self.buffer[layer_index].view(-1)
+ data_offsets, scale_offsets = self._loc_offsets(loc)
+ data_range = torch.arange(self.data_bytes_per_token, device=loc.device)
+ scale_range = torch.arange(self.scale_bytes_per_token, device=loc.device)
+ data = flat[data_offsets.unsqueeze(1) + data_range.unsqueeze(0)]
+ scale = flat[scale_offsets.unsqueeze(1) + scale_range.unsqueeze(0)]
+ return torch.cat([data, scale], dim=1).contiguous()
+
+
+class DeepseekV4MemoryManager(MemoryManager):
+ """DeepSeek-V4 KV cache: window latent (all layers) + c4/c128 compressed latent (compacted layers) + c4 indexer-K.
+
+ Same token-slot design as the sibling managers; all request-indexed tables live in DeepseekV4ReqManager.
+
+ - ``swa_pool``: 584B packed latent, all layers. The pool is smaller than the full token space; in the prep phase
+ ``alloc_swa_prefill/decode`` allocate by **page** (128 slots, position-aligned: slot(p)=page_base+p%128),
+ recording the mapping in ``full_to_swa_indexs`` (keyed by full token slot). Out-of-window slots are reclaimed by DeepseekV4ReqManager
+ lazily and in batches during prep (``evict_swa``; a page is returned only when its live count reaches 0); when full slots are released,
+ ``free`` cascades to the matching swa slots, so radix eviction / request release / pause need no extra protocol.
+ When the page allocator runs dry, the pressure valve (radix reclaiming ref==0 nodes) runs first, then the assert.
+ There is no ring buffer, so the prefill chunk size is not limited by sliding_window.
+ - ``c4_pool``/``c128_pool``: compressed latent; layers are built only for compressed layers, using qwen3next's layer-index compaction trick;
+ c4 additionally has a packed indexer-K pool. The slot mappings (``full_to_c4/c128_indexs``) are keyed by the full
+ slot of each group's last token (allocated/scattered during prep); ``free`` cascades to them, exactly mirroring swa.
+ - Writes go through the standard operator path (``pack_mla_kv_to_cache``), which uses the triton packed writer internally;
+ the torch codecs are kept as the executable spec of the ABI (unit-test oracle).
+ """
+
+ operator_class = DeepseekV4MemOperator
+
+ mla_nope_dim = DSV4_MLA_NOPE_DIM
+ mla_rope_dim = DSV4_MLA_ROPE_DIM
+ mla_head_dim = DSV4_MLA_HEAD_DIM
+ mla_quant_group_size = DSV4_MLA_QUANT_GROUP_SIZE
+ mla_scale_bytes = DSV4_MLA_SCALE_BYTES
+ mla_bytes_per_token = DSV4_MLA_BYTES_PER_TOKEN
+ indexer_head_dim_default = DSV4_INDEXER_HEAD_DIM
+ indexer_bytes_per_token = DSV4_INDEXER_BYTES_PER_TOKEN
+
+ def __init__(  # validates the V4 layout arguments, builds the layer compaction maps, then defers to MemoryManager.__init__
+ self,
+ size,
+ dtype,
+ head_num,
+ head_dim,
+ layer_num,
+ compress_rates: List[int],
+ indexer_head_dim: int = 128,
+ max_request_num: Optional[int] = None,
+ sliding_window: Optional[int] = None,
+ swa_extra_token_num: int = 0,
+ swa_full_tokens_ratio: float = DSV4_SWA_FULL_TOKENS_RATIO,
+ always_copy=False,
+ mem_fraction=0.9,
+ ):
+ assert head_num == 1, "DeepSeek-V4 是 MLA(MQA),dense latent 的 head_num 必须为 1"
+ assert head_dim == self.mla_head_dim, f"DeepSeek-V4 packed KV 期望 head_dim={self.mla_head_dim}"
+ assert (
+ indexer_head_dim == self.indexer_head_dim_default
+ ), f"DeepSeek-V4 packed indexer-K 期望 indexer_head_dim={self.indexer_head_dim_default}"
+ assert len(compress_rates) == layer_num, f"compress_rates 长度 {len(compress_rates)} 必须等于 layer_num {layer_num}"
+ assert all(r in (0, 4, 128) for r in compress_rates), "compress_rates 取值只能是 0/4/128"
+
+ self.compress_rates = list(compress_rates)
+ self.n_c4 = sum(1 for r in self.compress_rates if r == 4)  # number of layers with compress rate 4
+ self.n_c128 = sum(1 for r in self.compress_rates if r == 128)  # number of layers with compress rate 128
+ self.indexer_head_dim = indexer_head_dim
+ self.max_request_num = max_request_num
+ self.sliding_window = sliding_window
+ # Headroom beyond the active windows (max_request_num * sliding_window): transient usage of in-flight prefill chunks
+ # (out-of-window slots are only reclaimed at the next prep) + window tails held by the radix cache.
+ self.swa_extra_token_num = int(swa_extra_token_num)
+ self.swa_full_tokens_ratio = float(swa_full_tokens_ratio)
+
+ # Global layer id -> compacted layer id inside each compressed pool (same layer-index compaction as qwen3next)
+ self.layer_to_c4_idx = {}
+ self.layer_to_c128_idx = {}
+ c4 = c128 = 0
+ for lid, r in enumerate(self.compress_rates):
+ if r == 4:
+ self.layer_to_c4_idx[lid] = c4
+ c4 += 1
+ elif r == 128:
+ self.layer_to_c128_idx[lid] = c128
+ c128 += 1
+
+ super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)  # attributes above are set first; presumably the base init calls profile_size/_init_buffers — TODO confirm
+
+ # ------------------------------------------------------------------ sizing
+ def _swa_per_req_budget(self) -> int:
+ # An active request keeps window + one radix page (req_manager._swa_retain_len: keeps the tail page of the most recently
+ # completed prompt-cache boundary resident). V4 uses a 256-token prompt-cache boundary
+ # so that a radix-shared prefix never ends in the middle of a c4 physical page (64 c4 entries = 256 tokens).
+ return int(self.sliding_window) + DSV4_PROMPT_CACHE_PAGE_SIZE
+
+ def _planned_swa_size(self, full_size: int) -> int:
+ # The swa pool is allocated by page (page = 128 = sliding_window); capacity is rounded up to whole pages.
+ if self.max_request_num is None or self.sliding_window is None:
+ return _ceil_div(full_size, DSV4_SWA_PAGE_SIZE) * DSV4_SWA_PAGE_SIZE  # no request budget known: one swa slot per full token
+ cap = int(self.max_request_num) * self._swa_per_req_budget() + self.swa_extra_token_num
+ cap = max(cap, int(full_size * self.swa_full_tokens_ratio))  # apply the ratio floor
+ cap = max(1, min(full_size, cap))  # clamp to [1, full_size] before page rounding
+ return _ceil_div(cap, DSV4_SWA_PAGE_SIZE) * DSV4_SWA_PAGE_SIZE
+
+ @staticmethod
+ def _slab_bytes_per_slot(page_size: int, data_bytes: int, scale_bytes: int, align_bytes: int = 1) -> float:  # average bytes per slot after padding the page to align_bytes
+ bytes_per_page = _ceil_div(page_size * (data_bytes + scale_bytes), align_bytes) * align_bytes
+ return bytes_per_page / page_size
+
+ @staticmethod
+ def _paged_state_rows(num_swa_pages: int, ring: int, ratio: int) -> int:  # state row count, rounded up to a multiple of ratio
+ rows = num_swa_pages * ring + ring + 1  # ring rows per page, plus one extra ring, plus 1 trailing row
+ return _ceil_div(rows, ratio) * ratio
+
+ @staticmethod
+ def _init_state_sentinel(buffer: torch.Tensor) -> None:  # initializes, in place, the last row along dim 1 for every index of dim 0
+ half = buffer.shape[-1] // 2
+ buffer[:, -1, :half].zero_()  # first half of the last dim -> 0
+ buffer[:, -1, half:].fill_(float("-inf"))  # second half of the last dim -> -inf
+ return
+
+ def _paged_state_bytes_per_swa_slot(self) -> float:
+ """Bytes of c4/c128 compressor state (addressed via swa pages) amortized over each swa slot."""
+ per_page = 0.0
+ if self.n_c4 > 0:
+ per_page += DSV4_C4_STATE_RING * (4 * self.head_dim + 4 * self.indexer_head_dim) * 4 * self.n_c4  # ring rows x (attention + indexer row width) x 4 bytes (float32) x c4 layers
+ if self.n_c128 > 0:
+ per_page += DSV4_C128_STATE_RING * (2 * self.head_dim) * 4 * self.n_c128  # ring rows x row width x 4 bytes (float32) x c128 layers
+ return per_page / DSV4_SWA_PAGE_SIZE
+
+ def _swa_slot_bytes(self) -> float:  # bytes per swa slot: packed latent over all layers plus amortized compressor state
+ per_layer = self._slab_bytes_per_slot(
+ DSV4_SWA_PAGE_SIZE, DSV4_MLA_DATA_BYTES_PER_TOKEN, self.mla_scale_bytes, DSV4_MLA_PAGE_ALIGN_BYTES
+ )
+ return per_layer * self.layer_num + self._paged_state_bytes_per_swa_slot()
+
+ def _compressed_cell_size(self) -> float:
+ """Exact bytes per full token amortized over the compressed pools (after page-slab alignment)."""
+ c4_latent = self._slab_bytes_per_slot(
+ DSV4_C4_PAGE_SIZE, DSV4_MLA_DATA_BYTES_PER_TOKEN, self.mla_scale_bytes, DSV4_MLA_PAGE_ALIGN_BYTES
+ )
+ c128_latent = self._slab_bytes_per_slot(
+ DSV4_C128_PAGE_SIZE, DSV4_MLA_DATA_BYTES_PER_TOKEN, self.mla_scale_bytes, DSV4_MLA_PAGE_ALIGN_BYTES
+ )
+ c4_indexer = self._slab_bytes_per_slot(DSV4_C4_PAGE_SIZE, self.indexer_head_dim, DSV4_INDEXER_SCALE_BYTES)
+ return (c4_latent + c4_indexer) * self.n_c4 / 4 + c128_latent * self.n_c128 / 128  # one c4 slot per 4 tokens, one c128 slot per 128 tokens
+
+ def get_cell_size(self):  # estimated bytes per full token across the swa and compressed pools
+ compressed = self._compressed_cell_size()
+ if self.size is None:
+ return self._swa_slot_bytes() + compressed  # size not known yet: assume one swa slot per full token
+ swa_ratio = self._planned_swa_size(self.size) / max(1, self.size)  # swa slots per full token for the planned pool
+ return self._swa_slot_bytes() * swa_ratio + compressed
+
+ def profile_size(self, mem_fraction):  # derives self.size from available GPU memory; no-op when self.size is already set
+ if self.size is not None:
+ return
+
+ torch.cuda.empty_cache()
+ world_size = dist.get_world_size()
+ available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction)
+ available_bytes = available_memory * 1024 ** 3  # the log message below reports available_memory in GB
+ swa_slot_bytes = self._swa_slot_bytes()
+ compressed_cell = self._compressed_cell_size()
+
+ if self.max_request_num is not None and self.sliding_window is not None and compressed_cell > 0:
+ swa_budget = int(self.max_request_num) * self._swa_per_req_budget() + self.swa_extra_token_num
+ full_cell = swa_slot_bytes + compressed_cell
+ if available_bytes <= full_cell * swa_budget:
+ # Small GPU memory: the full token count cannot even reach the swa budget, so the swa pool tracks the full token count (one swa slot per token).
+ self.size = max(1, int(available_bytes / full_cell))
+ else:
+ size_budget = max(1, int((available_bytes - swa_slot_bytes * swa_budget) / compressed_cell))
+ if size_budget * self.swa_full_tokens_ratio > swa_budget:
+ # The ratio floor applies (_planned_swa_size will pick ratio*full); solve for full under that rule.
+ self.size = max(
+ 1, int(available_bytes / (swa_slot_bytes * self.swa_full_tokens_ratio + compressed_cell))
+ )
+ else:
+ self.size = size_budget
+ else:
+ self.size = max(1, int(available_bytes / (swa_slot_bytes + compressed_cell)))
+
+ if world_size > 1:
+ tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}")
+ dist.all_reduce(tensor, op=dist.ReduceOp.MIN)  # every rank adopts the smallest profiled size
+ self.size = tensor.item()
+
+ logger.info(
+ f"{str(available_memory)} GB space is available after load the model weight\n"
+ f"{str(self.get_cell_size() / 1024 ** 2)} MB is the conservative size of one token kv cache\n"
+ f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n"
+ )
+ return
+
+ # ------------------------------------------------------------------ buffers
+ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):  # builds the swa pool, the optional c4/c128 pools, their allocators, mappings and state buffers
+ rank_in_node = get_current_rank_in_node()
+ server = get_unique_server_name()
+
+ self.swa_size = self._planned_swa_size(size)
+ assert self.swa_size % DSV4_SWA_PAGE_SIZE == 0
+ self.swa_pool = PackedPagePool(
+ size=self.swa_size,
+ page_size=DSV4_SWA_PAGE_SIZE,
+ layer_num=layer_num,
+ data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN,
+ scale_bytes=self.mla_scale_bytes,
+ align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES,
+ )
+ # NOTE: this alias is page-indexed ([layer, num_pages, bytes_per_page]), not token-indexed;
+ # only consumers of get_att_input_params may use it; inherited interfaces with token-index semantics are explicitly fenced.
+ self.kv_buffer = self.swa_pool.buffer
+ # Page-granular allocation (page = 128 slots, position-aligned): slot invariant slot(p) = page_base + p%128.
+ # swa_size is page-aligned => the HOLD slot (swa_size) owns the pool's last physical page and is never allocated.
+ self.swa_num_pages = self.swa_size // DSV4_SWA_PAGE_SIZE
+ self.swa_page_allocator = KvCacheAllocator(
+ self.swa_num_pages, shared_name=f"{server}_dsv4_swa_can_use_page_num_{rank_in_node}"
+ )
+ # Page live count = number of valid full_to_swa rows pointing into the page; at 0 the page returns to the allocator (with per-token
+ # out-of-window reclaim, a partially expired page has count > 0 and stays protected). Indices include the HOLD page (read only, never counted).
+ self.swa_page_live_count = torch.zeros((self.swa_pool.num_pages,), dtype=torch.int32, device="cuda")
+ # swa pressure valve (optional): callback used when the page allocator runs dry (radix reclaims swa pages of ref==0 nodes);
+ # injected by the backend after the radix cache is created; the assert remains the last line of defense.
+ self._swa_pressure_valve = None
+ self.full_to_swa_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda")  # -1 = unmapped; row `size` is the HOLD row
+ self.full_to_swa_indexs[size] = self.swa_pool.HOLD_TOKEN_MEMINDEX
+
+ self.c4_size = _ceil_div(size, 4)
+ self.c128_size = _ceil_div(size, 128)
+ self.c4_pool: Optional[PackedPagePool] = None
+ self.c4_indexer_pool: Optional[PackedPagePool] = None
+ self.c4_allocator: Optional[KvCacheAllocator] = None
+ self.c4_page_allocator: Optional[KvCacheAllocator] = None
+ self.c4_page_live_count: Optional[torch.Tensor] = None
+ self.c128_pool: Optional[PackedPagePool] = None
+ self.c128_allocator: Optional[KvCacheAllocator] = None
+ self.c4_state_buffer: Optional[torch.Tensor] = None
+ self.c4_indexer_state_buffer: Optional[torch.Tensor] = None
+ self.c128_state_buffer: Optional[torch.Tensor] = None
+ # Compressed-slot mapping: key = full slot of a group's last token (position (g+1)%ratio==0), value = slot in the compressed pool.
+ # Mirrors full_to_swa_indexs: radix holds the full slot => the mapping row stays alive; free cascades the release.
+ self.full_to_c4_indexs: Optional[torch.Tensor] = None
+ self.full_to_c128_indexs: Optional[torch.Tensor] = None
+ if self.n_c4 > 0:
+ self.c4_pool = PackedPagePool(
+ size=self.c4_size,
+ page_size=DSV4_C4_PAGE_SIZE,
+ layer_num=self.n_c4,
+ data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN,
+ scale_bytes=self.mla_scale_bytes,
+ align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES,
+ )
+ self.c4_indexer_pool = PackedPagePool(
+ size=self.c4_size,
+ page_size=DSV4_C4_PAGE_SIZE,
+ layer_num=self.n_c4,
+ data_bytes=self.indexer_head_dim,
+ scale_bytes=DSV4_INDEXER_SCALE_BYTES,
+ )
+ self.c4_num_pages = self.c4_size // DSV4_C4_PAGE_SIZE  # floor: only whole pages are handed to the page allocator
+ assert self.c4_num_pages > 0, "DeepSeek-V4 c4 pool must have at least one usable full page"
+ self.c4_page_allocator = KvCacheAllocator(
+ self.c4_num_pages, shared_name=f"{server}_dsv4_c4_can_use_page_num_{rank_in_node}"
+ )
+ self.c4_page_live_count = torch.zeros((self.c4_pool.num_pages,), dtype=torch.int32, device="cuda")
+ self.full_to_c4_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda")
+ self.full_to_c4_indexs[size] = self.c4_pool.HOLD_TOKEN_MEMINDEX
+ # c4 compressor in-flight state (attention + indexer): addressed via swa pages (translation ③), living and dying with the swa
+ # page -> radix hits resume with zero copy. rows = pages*ring + ring (HOLD page) + 1 (sentinel),
+ # rounded up to ratio; the last row is the sentinel kv=0/score=-inf (KVAndScore.clear semantics), other rows are overwritten by the kernel
+ # at group start, so no per-page zeroing is needed. last_dim = 2*coff*head_dim (overlap coff=2).
+ state_rows = self._paged_state_rows(self.swa_num_pages, DSV4_C4_STATE_RING, 4)
+ self.c4_state_buffer = torch.zeros(
+ (self.n_c4, state_rows, 4 * self.head_dim), dtype=torch.float32, device="cuda"
+ )
+ self.c4_indexer_state_buffer = torch.zeros(
+ (self.n_c4, state_rows, 4 * self.indexer_head_dim), dtype=torch.float32, device="cuda"
+ )
+ for buf in (self.c4_state_buffer, self.c4_indexer_state_buffer):
+ self._init_state_sentinel(buf)
+ if self.n_c128 > 0:
+ self.c128_pool = PackedPagePool(
+ size=self.c128_size,
+ page_size=DSV4_C128_PAGE_SIZE,
+ layer_num=self.n_c128,
+ data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN,
+ scale_bytes=self.mla_scale_bytes,
+ align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES,
+ )
+ self.c128_allocator = KvCacheAllocator(
+ self.c128_size, shared_name=f"{server}_dsv4_c128_can_use_token_num_{rank_in_node}"
+ )
+ self.full_to_c128_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda")
+ self.full_to_c128_indexs[size] = self.c128_pool.HOLD_TOKEN_MEMINDEX
+ # c128 compressor in-flight state: row ids are derived from full->swa just like c4, but ring=128 and no overlap.
+ # last_dim = 2*head_dim; the last row is the sentinel read when the swa slot is missing/out of window.
+ state_rows = self._paged_state_rows(self.swa_num_pages, DSV4_C128_STATE_RING, 128)
+ self.c128_state_buffer = torch.zeros(
+ (self.n_c128, state_rows, 2 * self.head_dim), dtype=torch.float32, device="cuda"
+ )
+ self._init_state_sentinel(self.c128_state_buffer)
+
+ logger.info(
+ f"DeepseekV4MemoryManager pools: full_tokens={size} swa={self.swa_size}({self.swa_num_pages}p) "
+ f"c4={self.c4_size}(L={self.n_c4}) c128={self.c128_size}(L={self.n_c128}) "
+ f"packed_kv_bytes={self.mla_bytes_per_token} indexer_bytes={self.indexer_bytes_per_token}"
+ )
+
+ # ------------------------------------------------------------------ buffer accessors
+ def get_att_input_params(self, layer_index: int):  # returns the swa pool's page-indexed byte buffer for this layer
+ return self.swa_pool.get_layer_buffer(layer_index)
+
+ def _pool_and_local_layer(self, layer_index: int):  # global layer id -> (compressed pool, compacted layer id); raises AssertionError for rate-0 layers
+ r = self.compress_rates[layer_index]
+ if r == 4:
+ return self.c4_pool, self.layer_to_c4_idx[layer_index]
+ if r == 128:
+ return self.c128_pool, self.layer_to_c128_idx[layer_index]
+ raise AssertionError(f"layer {layer_index} (rate {r}) 不是压缩层,没有压缩池")
+
+ def get_compressed_kv_buffer(self, layer_index: int) -> torch.Tensor:  # byte buffer of the c4/c128 pool layer backing this global layer
+ pool, local_layer = self._pool_and_local_layer(layer_index)
+ return pool.get_layer_buffer(local_layer)
+
+ def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor:  # packed indexer-K byte buffer; only valid for rate-4 layers
+ assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 indexer-K"
+ return self.c4_indexer_pool.get_layer_buffer(self.layer_to_c4_idx[layer_index])
+
+ def get_c4_state_buffer(self, layer_index: int) -> torch.Tensor:  # float32 c4 compressor state rows of this layer; only valid for rate-4 layers
+ assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 paged compressor state"
+ return self.c4_state_buffer[self.layer_to_c4_idx[layer_index]]
+
+ def get_c4_indexer_state_buffer(self, layer_index: int) -> torch.Tensor:  # float32 c4 indexer state rows of this layer; only valid for rate-4 layers
+ assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 paged indexer state"
+ return self.c4_indexer_state_buffer[self.layer_to_c4_idx[layer_index]]
+
+ def get_c128_state_buffer(self, layer_index: int) -> torch.Tensor:  # float32 c128 compressor state rows of this layer; only valid for rate-128 layers
+ assert self.compress_rates[layer_index] == 128, "只有 c128(HCA) 层有 paged compressor state"
+ return self.c128_state_buffer[self.layer_to_c128_idx[layer_index]]
+
+ # ------------------------------------------------------------------ swa slot lifecycle
+ def set_swa_pressure_valve(self, valve) -> None:
+ """valve(need_pages): tries to free pages when the page allocator is short (radix reclaims swa of ref==0 nodes)."""
+ self._swa_pressure_valve = valve
+ return
+
+ def _alloc_swa_pages(self, need_pages: int) -> torch.Tensor:  # allocates swa pages, invoking the pressure valve first when the allocator is short
+ if need_pages > self.swa_page_allocator.can_use_mem_size and self._swa_pressure_valve is not None:
+ self._swa_pressure_valve(need_pages - self.swa_page_allocator.can_use_mem_size)  # ask the valve for exactly the shortfall
+ return self.swa_page_allocator.alloc(need_pages)
+
+ def _count_swa_pages(self, swa_slots: torch.Tensor, delta: int) -> torch.Tensor:
+ """Update the live count of the page each slot belongs to; return the touched pages (deduplicated)."""
+ pages = torch.div(swa_slots.long(), DSV4_SWA_PAGE_SIZE, rounding_mode="floor")
+ ones = torch.full(pages.shape, delta, dtype=torch.int32, device=pages.device)  # one `delta` per slot, so repeated pages accumulate
+ self.swa_page_live_count.index_add_(0, pages, ones)
+ return torch.unique(pages)
+
+ def alloc_swa_prefill(
+ self,
+ b_req_idx: torch.Tensor,
+ b_ready_cache_len: torch.Tensor,
+ b_seq_len: torch.Tensor,
+ req_to_token_indexs: torch.Tensor,
+ ) -> None:
+ """prefill prep: allocate position-aligned swa slots for each request's new tokens at positions [ready, seq).
+
+ Slot invariant: slot(p) = page_base(page containing p) + p%128, page_base % 128 == 0.
+ The base of a continued page (start not page-aligned; can only be the first page) is derived from the previous token's mapping
+ (full_to_swa[req_to_token[req, start-1]]; that token must be inside the retained window); all other pages are freshly allocated.
+ A borrower of a radix hit (ready is always 128-aligned) starts from fresh pages, naturally disjoint from pages held by the node.
+ Must be called after init_req_to_token_indexes (scatter targets go through the req_to_token rows).
+ """
+ page = DSV4_SWA_PAGE_SIZE
+ hold_req_id = self.max_request_num # request id of padding rows (req_manager.HOLD_REQUEST_ID)
+ req_list = b_req_idx.detach().cpu().tolist()
+ ready_list = b_ready_cache_len.detach().cpu().tolist()
+ seq_list = b_seq_len.detach().cpu().tolist()
+
+ segs = [] # (req_idx, start, end, n_new_pages, has_cont_page)
+ total_new_pages = 0
+ for req_idx, start, end in zip(req_list, ready_list, seq_list):
+ req_idx, start, end = int(req_idx), int(start), int(end)
+ if req_idx == hold_req_id or end <= start:
+ continue
+ first_new_page = _ceil_div(start, page)  # first page index that starts at or after `start`
+ n_new = max(0, (end - 1) // page - first_new_page + 1)
+ segs.append((req_idx, start, end, n_new, start % page != 0))
+ total_new_pages += n_new
+ if not segs:
+ return
+
+ new_pages = self._alloc_swa_pages(total_new_pages).cuda(non_blocking=True).long() if total_new_pages else None  # one allocation for the whole batch
+ page_cursor = 0
+ for req_idx, start, end, n_new, has_cont in segs:
+ positions = torch.arange(start, end, dtype=torch.long, device="cuda")
+ page_local = torch.div(positions, page, rounding_mode="floor") - start // page  # page index relative to the first touched page
+ bases = torch.empty(((end - 1) // page - start // page + 1,), dtype=torch.long, device="cuda")  # one base slot per touched page
+ if has_cont:
+ prev_slot = int(self.full_to_swa_indexs[req_to_token_indexs[req_idx, start - 1].long()].item())
+ # Continued-page invariant: the previous token must be resident (retain >= 2) and position-aligned (tripwire for future resume/MTP changes).
+ assert prev_slot >= 0 and prev_slot % page == (start - 1) % page
+ bases[0] = prev_slot - (start - 1) % page
+ if n_new:
+ bases[1 if has_cont else 0 :] = new_pages[page_cursor : page_cursor + n_new] * page
+ page_cursor += n_new
+ slots = (bases[page_local] + positions % page).to(torch.int32)
+ self.full_to_swa_indexs[req_to_token_indexs[req_idx, start:end].long()] = slots
+ self._count_swa_pages(slots, 1)
+ return
+
+ def alloc_swa_decode(
+ self,
+ b_req_idx_cpu: torch.Tensor,
+ b_seq_len_cpu: torch.Tensor,
+ mem_indexes: torch.Tensor,
+ req_to_token_indexs: torch.Tensor,
+ ) -> None:
+ """decode prep: swa slot for this step's token (position seq-1). A page-aligned position opens a new page; otherwise previous token's slot + 1
+ (the position-alignment invariant guarantees contiguity within a page). Scatter targets use mem_indexes (req_to_token has not been written for this step yet).
+
+ NOTE: a continued slot is derived from the previous position's mapping, so multiple rows of one request (MTP, several tokens per step) in the same batch are unsupported
+ (DSV4 launch arguments already reject MTP; supporting it requires deriving slots segment by segment in in-step order)."""
+ page = DSV4_SWA_PAGE_SIZE
+ hold_req_id = self.max_request_num
+ req_list = b_req_idx_cpu.tolist()
+ seq_list = b_seq_len_cpu.tolist()
+ cont_rows, cont_prev_pos, new_rows = [], [], []
+ for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)):
+ req_idx, seq_len = int(req_idx), int(seq_len)
+ if req_idx == hold_req_id or seq_len <= 0:
+ continue
+ if (seq_len - 1) % page == 0:
+ new_rows.append(i)
+ else:
+ cont_rows.append(i)
+ cont_prev_pos.append(seq_len - 2)  # position of the previous token
+ mem_indexes = mem_indexes.cuda().long().reshape(-1)
+ if cont_rows:
+ req_rows = torch.tensor([req_list[i] for i in cont_rows], dtype=torch.long, device="cuda")
+ prev_full = req_to_token_indexs[req_rows, torch.tensor(cont_prev_pos, device="cuda")].long()
+ prev_slots = self.full_to_swa_indexs[prev_full]
+ # Continued-slot invariant tripwire: the previous position must be resident (covered by retain). prep already synchronizes, so the cost is negligible.
+ assert bool((prev_slots >= 0).all())
+ slots = prev_slots + 1
+ self.full_to_swa_indexs[mem_indexes[cont_rows]] = slots
+ self._count_swa_pages(slots, 1)
+ if new_rows:
+ pages = self._alloc_swa_pages(len(new_rows)).cuda(non_blocking=True).long()
+ slots = (pages * page).to(torch.int32)  # first slot of each freshly allocated page
+ self.full_to_swa_indexs[mem_indexes[new_rows]] = slots
+ self._count_swa_pages(slots, 1)
+ return
+
+ def evict_swa(self, full_slots: torch.Tensor) -> None:
+ """Reclaim the swa slots mapped from full slots (shared by lazy out-of-window reclaim / free cascade / pressure valve).
+ Unmapped (-1) slots are skipped; a page is returned to the allocator once its count drops to 0."""
+ if full_slots.numel() == 0:
+ return
+ full_slots = full_slots.cuda().long().reshape(-1)
+ full_slots = torch.unique(full_slots[full_slots != self.HOLD_TOKEN_MEMINDEX])  # drop the HOLD slot and duplicates
+ if full_slots.numel() == 0:
+ return
+ swa_slots = self.full_to_swa_indexs[full_slots]
+ valid = swa_slots >= 0
+ valid_slots = swa_slots[valid]
+ if valid_slots.numel() == 0:
+ return
+ self.full_to_swa_indexs[full_slots[valid]] = -1  # unmap before touching the page counts
+ touched = self._count_swa_pages(valid_slots, -1)
+ empty = touched[self.swa_page_live_count[touched] == 0]
+ if empty.numel() > 0:
+ self.swa_page_allocator.free(empty.to(torch.int32))
+ return
+
+ def _evict_compress(self, full_slots: torch.Tensor, mapping: torch.Tensor, allocator: KvCacheAllocator) -> None:  # slot-granular release: frees mapped compressed slots and resets their mapping rows to -1
+ full_slots = full_slots.cuda().long().reshape(-1)
+ # Dedup: duplicate slots in one batch would gather duplicate compressed slots -> allocator double free (free already dedups; this guards direct callers).
+ full_slots = torch.unique(full_slots[full_slots != self.HOLD_TOKEN_MEMINDEX])
+ if full_slots.numel() == 0:
+ return
+ slots = mapping[full_slots]
+ valid = slots >= 0
+ valid_slots = slots[valid]
+ if valid_slots.numel() == 0:
+ return
+ allocator.free(valid_slots)
+ mapping[full_slots[valid]] = -1
+ return
+
+ def alloc_c4_pages(self, need_pages: int) -> torch.Tensor:  # allocates need_pages pages from the c4 page allocator (no pressure valve on this path)
+ assert self.c4_page_allocator is not None, "DeepSeek-V4 c4 page allocator is not initialized"
+ return self.c4_page_allocator.alloc(need_pages)
+
+ def count_c4_slots(self, c4_slots: torch.Tensor, delta: int) -> torch.Tensor:
+ """Update the live count of the page each c4 slot belongs to; return the touched pages (deduplicated)."""
+ assert self.c4_page_live_count is not None, "DeepSeek-V4 c4 page live count is not initialized"
+ pages = torch.div(c4_slots.long(), DSV4_C4_PAGE_SIZE, rounding_mode="floor")
+ ones = torch.full(pages.shape, delta, dtype=torch.int32, device=pages.device)  # one `delta` per slot, so repeated pages accumulate
+ self.c4_page_live_count.index_add_(0, pages, ones)
+ return torch.unique(pages)
+
+ def evict_c4(self, full_slots: torch.Tensor) -> None:
+ """Reclaim the c4 slots mapped from full slots (group-last tokens). Non-group-last / unmapped (-1) slots are skipped."""
+ if self.c4_page_allocator is None or full_slots.numel() == 0:
+ return
+ full_slots = full_slots.cuda().long().reshape(-1)
+ full_slots = torch.unique(full_slots[full_slots != self.HOLD_TOKEN_MEMINDEX])  # drop the HOLD slot and duplicates
+ if full_slots.numel() == 0:
+ return
+ slots = self.full_to_c4_indexs[full_slots]
+ valid = slots >= 0
+ valid_slots = slots[valid]
+ if valid_slots.numel() == 0:
+ return
+ self.full_to_c4_indexs[full_slots[valid]] = -1
+ touched = self.count_c4_slots(valid_slots, -1)
+ empty = touched[self.c4_page_live_count[touched] == 0]  # pages whose live count reached 0 go back to the allocator
+ if empty.numel() > 0:
+ self.c4_page_allocator.free(empty.to(torch.int32))
+ return
+
+ def evict_c128(self, full_slots: torch.Tensor) -> None:
+ """Reclaim the c128 slots mapped from full slots (group-last tokens). Non-group-last / unmapped (-1) slots are skipped."""
+ if self.c128_allocator is None or full_slots.numel() == 0:
+ return
+ self._evict_compress(full_slots, self.full_to_c128_indexs, self.c128_allocator)
+ return
+
+ # ------------------------------------------------------------------ alloc/free (cascade)
+ def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
+ """Release full token slots, cascading to their swa slots and c4/c128 compressed slots. Radix eviction and request release/pause all go through here.
+
+ Full slots are deduplicated first: duplicates in one batch would make the mappings gather duplicate compressed/swa slots and double free in the allocator."""
+ if isinstance(free_index, list):
+ free_index = torch.tensor(free_index, dtype=torch.int64)
+ if free_index.numel() > 0:
+ free_index = torch.unique(free_index)
+ self.evict_swa(free_index)
+ self.evict_c4(free_index)
+ self.evict_c128(free_index)
+ super().free(free_index)  # finally release the full slots through the base manager
+ return
+
+ def free_all(self):  # resets the base manager plus every sub-pool allocator, live count and mapping table
+ super().free_all()
+ self.swa_page_allocator.free_all()
+ self.swa_page_live_count.zero_()
+ self.full_to_swa_indexs.fill_(-1)
+ self.full_to_swa_indexs[self.HOLD_TOKEN_MEMINDEX] = self.swa_pool.HOLD_TOKEN_MEMINDEX  # restore the HOLD row wiped by fill_(-1)
+ if self.c4_page_allocator is not None:
+ self.c4_page_allocator.free_all()
+ self.c4_page_live_count.zero_()
+ self.full_to_c4_indexs.fill_(-1)
+ self.full_to_c4_indexs[self.HOLD_TOKEN_MEMINDEX] = self.c4_pool.HOLD_TOKEN_MEMINDEX
+ if self.c128_allocator is not None:
+ self.c128_allocator.free_all()
+ self.full_to_c128_indexs.fill_(-1)
+ self.full_to_c128_indexs[self.HOLD_TOKEN_MEMINDEX] = self.c128_pool.HOLD_TOKEN_MEMINDEX
+ return
+
+ def alloc_c4(self, need_size) -> torch.Tensor:  # deliberately disabled: always raises, c4 is allocated by page via alloc_c4_pages
+ raise AssertionError("DeepSeek-V4 c4 uses page-safe allocation; call alloc_c4_pages instead")
+
+ def alloc_c128(self, need_size) -> torch.Tensor:  # slot-granular allocation from the c128 allocator; assumes a c128 pool exists (c128_allocator is None otherwise)
+ return self.c128_allocator.alloc(need_size)
+
+ def free_c4(self, free_index) -> None:  # deliberately disabled: always raises, c4 is released through evict_c4 and page live counts
+ raise AssertionError("DeepSeek-V4 c4 uses page live-count release; call evict_c4 instead")
+
+ def free_c128(self, free_index) -> None:  # returns c128 slots straight to the allocator; does not touch full_to_c128_indexs
+ self.c128_allocator.free(free_index)
+
+ # ------------------------------------------------------------------ packed codecs (torch reference)
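+ # Bit-exact with the fp8_ds_mla byte layout of sglang/vllm (ue8m0 power-of-two scales). These torch implementations are the
+ # executable spec of that ABI (unit-test oracle; the triton writer is checked byte for byte against them) and must not be removed.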
+ def _pack_mla_kv(self, kv: torch.Tensor) -> torch.Tensor:  # torch reference packer: [n, mla_head_dim] latent -> [n, mla_bytes_per_token] uint8
+ kv = kv.reshape(-1, self.mla_head_dim)
+ out = torch.empty((kv.shape[0], self.mla_bytes_per_token), dtype=torch.uint8, device=kv.device)
+ nope = kv[:, : self.mla_nope_dim].float().reshape(-1, self.mla_scale_bytes - 1, self.mla_quant_group_size)  # [n, groups, group_size]
+ scale = torch.clamp(nope.abs().amax(dim=-1) / DSV4_FP8_E4M3_MAX, min=DSV4_FP8_SCALE_MIN)  # per-group amax scale, floored at SCALE_MIN
+ scale_exp = torch.ceil(torch.log2(scale)).to(torch.int32)  # round the scale up to a power of two
+ scale = torch.exp2(scale_exp.float())
+ nope_fp8 = torch.clamp(nope / scale.unsqueeze(-1), -DSV4_FP8_E4M3_MAX, DSV4_FP8_E4M3_MAX).to(
+ torch.float8_e4m3fn
+ )
+ out[:, : self.mla_nope_dim].copy_(nope_fp8.reshape(-1, self.mla_nope_dim).view(dtype=torch.uint8))
+ rope_start = self.mla_nope_dim
+ rope_end = rope_start + self.mla_rope_dim * 2  # bf16 = 2 bytes per RoPE element
+ rope = kv[:, self.mla_nope_dim : self.mla_head_dim].contiguous().to(torch.bfloat16)
+ out[:, rope_start:rope_end].copy_(rope.view(dtype=torch.uint8).reshape(-1, self.mla_rope_dim * 2))
+ scale_start = rope_end
+ scale_end = scale_start + self.mla_scale_bytes - 1
+ out[:, scale_start:scale_end].copy_((scale_exp + 127).to(torch.uint8))  # exponent stored with a +127 bias, one byte per group
+ out[:, scale_end].zero_()  # trailing pad byte
+ return out
+
+ def _unpack_mla_kv(self, packed: torch.Tensor) -> torch.Tensor:  # torch reference unpacker: [n, mla_bytes_per_token] uint8 -> [n, mla_head_dim] in self.dtype
+ packed = packed.reshape(-1, self.mla_bytes_per_token)
+ if packed.shape[0] == 0:
+ return torch.empty((0, self.mla_head_dim), dtype=self.dtype, device=packed.device)
+ nope_fp8 = packed[:, : self.mla_nope_dim].view(dtype=torch.float8_e4m3fn).float()
+ nope_fp8 = nope_fp8.reshape(-1, self.mla_scale_bytes - 1, self.mla_quant_group_size)
+ rope_start = self.mla_nope_dim
+ rope_end = rope_start + self.mla_rope_dim * 2
+ scale_start = rope_end
+ scale_end = scale_start + self.mla_scale_bytes - 1
+ scale_exp = packed[:, scale_start:scale_end].to(torch.int32) - 127  # undo the +127 bias applied by _pack_mla_kv
+ scale = torch.exp2(scale_exp.float())
+ nope = (nope_fp8 * scale.reshape(-1, self.mla_scale_bytes - 1, 1)).reshape(-1, self.mla_nope_dim)
+ rope = packed[:, rope_start:rope_end].view(dtype=torch.bfloat16)
+ return torch.cat([nope.to(self.dtype), rope.to(self.dtype)], dim=-1)
+
+ def _pack_indexer_k(self, indexer_k: torch.Tensor) -> torch.Tensor:  # torch reference packer: [n, indexer_head_dim] -> [n, indexer_bytes_per_token] uint8
+ indexer_k = indexer_k.reshape(-1, self.indexer_head_dim)
+ out = torch.empty(
+ (indexer_k.shape[0], self.indexer_bytes_per_token),
+ dtype=torch.uint8,
+ device=indexer_k.device,
+ )
+ k_float = indexer_k.float()
+ scale = torch.clamp(  # one float32 scale per row, floored at SCALE_MIN (not rounded to a power of two here)
+ k_float.abs().amax(dim=-1, keepdim=True) / DSV4_FP8_E4M3_MAX,
+ min=DSV4_FP8_SCALE_MIN,
+ )
+ k_fp8 = torch.clamp(k_float / scale, -DSV4_FP8_E4M3_MAX, DSV4_FP8_E4M3_MAX).to(torch.float8_e4m3fn)
+ out[:, : self.indexer_head_dim].copy_(k_fp8.view(dtype=torch.uint8))
+ out[:, self.indexer_head_dim :].copy_(scale.view(dtype=torch.uint8).reshape(-1, DSV4_INDEXER_SCALE_BYTES))  # raw float32 bytes of the scale
+ return out
+
+ def _unpack_indexer_k(self, packed: torch.Tensor) -> torch.Tensor:
+     """Dequantize packed indexer-K rows back into ``self.dtype`` vectors.
+
+     Each ``self.indexer_bytes_per_token``-byte row is read as ``indexer_head_dim``
+     fp8 (e4m3) bytes followed by the bytes of a float32 scale.
+
+     Args:
+         packed: uint8 tensor reshapeable to ``[-1, indexer_bytes_per_token]``.
+
+     Returns:
+         ``[N, indexer_head_dim]`` tensor in ``self.dtype``; an empty input yields an
+         empty tensor of that width on the same device.
+     """
+     packed = packed.reshape(-1, self.indexer_bytes_per_token)
+     if packed.shape[0] == 0:
+         return torch.empty((0, self.indexer_head_dim), dtype=self.dtype, device=packed.device)
+     k_fp8 = packed[:, : self.indexer_head_dim].view(dtype=torch.float8_e4m3fn).float()
+     # Viewing uint8 as float32 needs the storage offset and row stride to be multiples
+     # of 4; the tail slice does not guarantee that, so copy it into a dense buffer
+     # first (gather_indexer_k does the same).
+     scale = packed[:, self.indexer_head_dim :].contiguous().view(dtype=torch.float32)
+     return (k_fp8 * scale).to(self.dtype)
+
+ # ------------------------------------------------------------------ cache write paths
+ def pack_mla_kv_to_cache(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
+     """Standard operator write path for MLA KV into the swa pool.
+
+     Requires that ``alloc_swa`` was already called for ``mem_index`` in this step
+     (prep phase). HOLD/padding slots map to the swa HOLD slot, so writing them is
+     harmless. Does nothing when ``kv`` has no rows.
+     """
+     if kv.shape[0] != 0:
+         from lightllm.models.deepseek_v4.triton_kernel.destindex_copy_kv_flashmla_dsv4 import (
+             destindex_copy_kv_flashmla_dsv4,
+         )
+
+         # Translate full-token slots into swa slots before handing them to the kernel.
+         full_slots = mem_index.cuda().long().reshape(-1)
+         swa_slots = self.full_to_swa_indexs[full_slots]
+         pool = self.swa_pool
+         destindex_copy_kv_flashmla_dsv4(
+             kv.reshape(-1, self.mla_head_dim),
+             swa_slots,
+             pool.get_layer_buffer(layer_index),
+             pool.page_size,
+         )
+     return
+
+ def pack_mla_kv_to_cache_fused_norm_rope(
+     self,
+     layer_index: int,
+     mem_index: torch.Tensor,
+     kv: torch.Tensor,
+     kv_weight: torch.Tensor,
+     eps: float,
+     freqs_cis: torch.Tensor,
+     positions: torch.Tensor,
+ ):
+     """Like ``pack_mla_kv_to_cache``, with rmsnorm and tail-interleaved rope fused in.
+
+     The write goes through sglang's ``fused_k_norm_rope_flashmla`` (the pool side of
+     sglang's ``_compute_kv_to_cache``), which avoids the intermediate bf16 kv. ``kv``
+     is the raw wkv projection output ``[T, head_dim + rope_dim]``. Does nothing when
+     ``kv`` has no rows.
+     """
+     if kv.shape[0] == 0:
+         return
+     from sglang.jit_kernel.dsv4 import fused_k_norm_rope_flashmla
+
+     pool = self.swa_pool
+     swa_slots = self.full_to_swa_indexs[mem_index.cuda().long().reshape(-1)]
+     # Unmapped slots (-1, e.g. HOLD rows during decode-graph warmup, where prep skips
+     # alloc_swa) are an explicit no-op for the older triton write kernel. The sglang
+     # fused kernel has no guard for negative slots (a negative page offset is an
+     # illegal access), so they are redirected to the swa HOLD slot, which acts as a
+     # trash bin just like padding-row writes.
+     unmapped = swa_slots < 0
+     hold_slots = torch.full_like(swa_slots, pool.HOLD_TOKEN_MEMINDEX)
+     swa_slots = torch.where(unmapped, hold_slots, swa_slots)
+     fused_k_norm_rope_flashmla(
+         kv=kv,
+         kv_weight=kv_weight,
+         eps=eps,
+         freqs_cis=freqs_cis,
+         positions=positions,
+         out_loc=swa_slots,
+         kvcache=pool.get_layer_buffer(layer_index),
+         page_size=pool.page_size,
+     )
+     return
+
+ def pack_compressed_kv_to_cache(self, layer_index: int, slots: torch.Tensor, comp: torch.Tensor):
+     """Write compressed KV rows into the pool that backs ``layer_index``.
+
+     The pool and its pool-local layer index come from ``_pool_and_local_layer``.
+     Does nothing when ``comp`` has no rows.
+     """
+     if comp.shape[0] != 0:
+         from lightllm.models.deepseek_v4.triton_kernel.destindex_copy_kv_flashmla_dsv4 import (
+             destindex_copy_kv_flashmla_dsv4,
+         )
+
+         target_pool, pool_layer = self._pool_and_local_layer(layer_index)
+         rows = comp.reshape(-1, self.mla_head_dim)
+         # Slots are moved to the device of the data being written.
+         dest_slots = slots.to(comp.device)
+         destindex_copy_kv_flashmla_dsv4(
+             rows,
+             dest_slots,
+             target_pool.get_layer_buffer(pool_layer),
+             target_pool.page_size,
+         )
+
+ def pack_indexer_k_to_cache(self, layer_index: int, slots: torch.Tensor, indexer_k: torch.Tensor):
+     """Write indexer-K rows into the c4 indexer pool for a c4 layer.
+
+     Does nothing when ``indexer_k`` has no rows; the c4-layer assertion is only
+     reached for non-empty input.
+     """
+     if indexer_k.shape[0] == 0:
+         return
+     assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 indexer-K"
+     from lightllm.models.deepseek_v4.triton_kernel.destindex_copy_indexer_k_dsv4 import (
+         destindex_copy_indexer_k_dsv4,
+     )
+
+     indexer_pool = self.c4_indexer_pool
+     rows = indexer_k.reshape(-1, self.indexer_head_dim)
+     dest_slots = slots.to(indexer_k.device)
+     # The indexer pool is indexed by the layer's position among c4 layers.
+     layer_buffer = indexer_pool.get_layer_buffer(self.layer_to_c4_idx[layer_index])
+     destindex_copy_indexer_k_dsv4(rows, dest_slots, layer_buffer, indexer_pool.page_size)
+
+ def gather_indexer_k(self, layer_index: int, slots: torch.Tensor) -> torch.Tensor:
+     """Gather and dequantize c4 indexer-K rows.
+
+     ``slots`` is ``[N]`` c4 slots (HOLD is legal); the result is
+     ``[N, indexer_head_dim]`` bf16, used for indexer top-k scoring. Pure tensor
+     ops, cuda-graph safe.
+     """
+     assert self.compress_rates[layer_index] == 4, "只有 c4(CSA) 层有 indexer-K"
+     indexer_pool = self.c4_indexer_pool
+     layer_bytes = indexer_pool.get_layer_buffer(self.layer_to_c4_idx[layer_index]).view(-1)
+     data_base, scale_base = indexer_pool._loc_offsets(slots.reshape(-1))
+
+     # Per-slot byte indices: base offset plus a column range, broadcast to [N, width].
+     data_cols = torch.arange(indexer_pool.data_bytes_per_token, device=layer_bytes.device)
+     scale_cols = torch.arange(indexer_pool.scale_bytes_per_token, device=layer_bytes.device)
+     data_index = data_base.unsqueeze(1) + data_cols.unsqueeze(0)
+     scale_index = scale_base.unsqueeze(1) + scale_cols.unsqueeze(0)
+
+     quantized = layer_bytes[data_index].view(torch.float8_e4m3fn)
+     scales = layer_bytes[scale_index].contiguous().view(torch.float32)
+     return (quantized.float() * scales).to(torch.bfloat16)
+
+ # ------------------------------------------------------------------ fenced inherited APIs
+ # kv_buffer is a page-indexed uint8 slab; the base-class interfaces that read/write it by token index would silently corrupt data, so they are explicitly blocked.
+ def get_index_kv_buffer(self, index):
+     """Fenced: token-indexed reads of kv_buffer are rejected for this cache."""
+     reason = "DeepSeek-V4 packed page-slab cache does not support token-indexed kv_buffer io"
+     raise NotImplementedError(reason)
+
+ def load_index_kv_buffer(self, index, load_tensor_dict):
+     """Fenced: token-indexed writes of kv_buffer are rejected for this cache."""
+     reason = "DeepSeek-V4 packed page-slab cache does not support token-indexed kv_buffer io"
+     raise NotImplementedError(reason)
+
+ def alloc_kv_move_buffer(self, max_req_total_len):
+     """Fenced: KV transfer is not implemented for the packed/composite cache."""
+     reason = "DeepSeek-V4 packed/composite KV transfer is not implemented"
+     raise NotImplementedError(reason)
+
+ def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
+     """Fenced: KV transfer is not implemented for the packed/composite cache."""
+     reason = "DeepSeek-V4 packed/composite KV transfer is not implemented"
+     raise NotImplementedError(reason)
+
+ def write_mem_to_page_kv_move_buffer(self, *args, **kwargs):
+     """Fenced: KV transfer is not implemented for the packed/composite cache."""
+     reason = "DeepSeek-V4 packed/composite KV transfer is not implemented"
+     raise NotImplementedError(reason)
+
+ def read_page_kv_move_buffer_to_mem(self, *args, **kwargs):
+     """Fenced: KV transfer is not implemented for the packed/composite cache."""
+     reason = "DeepSeek-V4 packed/composite KV transfer is not implemented"
+     raise NotImplementedError(reason)
+
+ def send_to_decode_node(self, *args, **kwargs):
+     """Fenced: KV transfer is not implemented for the packed/composite cache."""
+     reason = "DeepSeek-V4 packed/composite KV transfer is not implemented"
+     raise NotImplementedError(reason)
+
+ def receive_from_prefill_node(self, *args, **kwargs):
+     """Fenced: KV transfer is not implemented for the packed/composite cache."""
+     reason = "DeepSeek-V4 packed/composite KV transfer is not implemented"
+     raise NotImplementedError(reason)
+
+ def send_to_decode_node_p2p(self, *args, **kwargs):
+     """Fenced: KV transfer is not implemented for the packed/composite cache."""
+     reason = "DeepSeek-V4 packed/composite KV transfer is not implemented"
+     raise NotImplementedError(reason)
+
+ def receive_from_prefill_node_p2p(self, *args, **kwargs):
+     """Fenced: KV transfer is not implemented for the packed/composite cache."""
+     reason = "DeepSeek-V4 packed/composite KV transfer is not implemented"
+     raise NotImplementedError(reason)
diff --git a/lightllm/common/kv_cache_mem_manager/operator/__init__.py b/lightllm/common/kv_cache_mem_manager/operator/__init__.py
index 85c37ad39b..442c2e300e 100644
--- a/lightllm/common/kv_cache_mem_manager/operator/__init__.py
+++ b/lightllm/common/kv_cache_mem_manager/operator/__init__.py
@@ -5,6 +5,7 @@
from .deepseek import (
Deepseek2MemOperator,
Deepseek3_2MemOperator,
+ DeepseekV4MemOperator,
FP8PerTokenGroupQuantDeepseek3_2MemOperator,
)
from .fp8_quant import (
diff --git a/lightllm/common/kv_cache_mem_manager/operator/deepseek.py b/lightllm/common/kv_cache_mem_manager/operator/deepseek.py
index 6e05b96e10..0725ce9b93 100644
--- a/lightllm/common/kv_cache_mem_manager/operator/deepseek.py
+++ b/lightllm/common/kv_cache_mem_manager/operator/deepseek.py
@@ -8,7 +8,9 @@
class Deepseek2MemOperator(NormalMemOperator):
def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
- from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager
+ from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import (
+ Deepseek2MemoryManager,
+ )
mem_manager: Deepseek2MemoryManager = self.mem_manager
@@ -30,7 +32,9 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv:
class Deepseek3_2MemOperator(Deepseek2MemOperator):
def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
- from lightllm.common.kv_cache_mem_manager.deepseek3_2mem_manager import Deepseek3_2MemoryManager
+ from lightllm.common.kv_cache_mem_manager.deepseek3_2mem_manager import (
+ Deepseek3_2MemoryManager,
+ )
mem_manager: Deepseek3_2MemoryManager = self.mem_manager
from ...basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv
@@ -78,3 +82,14 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv:
o_rope,
)
return
+
+
+class DeepseekV4MemOperator(BaseMemManagerOperator):
+    """Memory operator that forwards KV writes to the DeepSeek-V4 memory manager."""
+
+    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
+        """Write ``kv`` for ``mem_index`` via the manager's ``pack_mla_kv_to_cache``."""
+        # Function-local import, as in the sibling operators in this file.
+        from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import (
+            DeepseekV4MemoryManager,
+        )
+
+        manager: DeepseekV4MemoryManager = self.mem_manager
+        manager.pack_mla_kv_to_cache(layer_index, mem_index, kv)
diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py
index cd534d53ec..4db55e1555 100644
--- a/lightllm/common/quantization/__init__.py
+++ b/lightllm/common/quantization/__init__.py
@@ -14,6 +14,7 @@
EXPERT_DTYPE_TO_QUANT_TYPE = {
"fp8": "deepgemm-fp8w8a8-b128",
"fp4": "deepgemm-fp4fp8-b32",
+ "mxfp4": "marlin-mxfp4w4a16-b32",
}
SUPPORTED_EXPERT_DTYPES = tuple(EXPERT_DTYPE_TO_QUANT_TYPE)
@@ -64,10 +65,13 @@ def _mapping_quant_method(self):
logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}")
# fp8 量化下,部分 MoE 模型(如 DeepSeek-V4),可以单独声明 expert 权重精度,
- # 按其值给 fused_moe 选用对应的 deepgemm 量化方法。
+ # 按其值给 fused_moe 选用对应的量化方法。
expert_dtype = self.expert_dtype or self.network_config_.get("expert_dtype", None)
if expert_dtype is None:
return
+ # DeepSeek-V4 的 fp4 发布版自带预打包 MXFP4 专家。
+ if expert_dtype == "fp4" and self.network_config_.get("model_type") == "deepseek_v4":
+ expert_dtype = "mxfp4"
target = self._get_expert_quant_type(expert_dtype)
for layer_num in range(self.layer_num):
if self.expert_dtype is not None:
diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py
index ec1ee90fd4..bedf22ee95 100644
--- a/lightllm/common/quantization/deepgemm.py
+++ b/lightllm/common/quantization/deepgemm.py
@@ -198,6 +198,144 @@ def _create_weight(
return mm_param, mm_param_list
+@QUANTMETHODS.register(["marlin-mxfp4w4a16-b32"], platform="cuda")
+class MXFP4MoEQuantizationMethod(QuantizationMethod):
+ """Quantization method "marlin-mxfp4w4a16-b32" for pre-packed MXFP4 MoE expert weights.
+
+ Raw packed weights and scales are staged on CPU by ``_create_weight`` and repacked
+ into pre-allocated CUDA marlin buffers by ``finalize_moe_weight``. ``quantize`` and
+ ``apply`` are not implemented and always raise.
+ """
+
+ def __init__(self):
+ super().__init__()
+ # Number of weight elements that share one scale (used for the scale shapes below).
+ self.block_size = 32
+ self.weight_suffix = "weight"
+ self.weight_zero_point_suffix = None
+ self.weight_scale_suffix = "scale"
+ self.has_weight_scale = True
+ self.has_weight_zero_point = False
+
+ @property
+ def method_name(self):
+ """Name under which this method is registered."""
+ return "marlin-mxfp4w4a16-b32"
+
+ def quantize(self, weight: torch.Tensor, output: WeightPack):
+ """Always raises: this method only loads weights that are already MXFP4-packed."""
+ raise NotImplementedError("marlin-mxfp4w4a16-b32 only loads pre-packed MXFP4 expert weights")
+
+ def apply(
+ self,
+ input_tensor: torch.Tensor,
+ weight_pack: "WeightPack",
+ out: Optional[torch.Tensor] = None,
+ workspace: Optional[torch.Tensor] = None,
+ use_custom_tensor_mananger: bool = True,
+ bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Always raises: only the fused MoE expert path is supported."""
+ raise NotImplementedError("marlin-mxfp4w4a16-b32 is only implemented for fused MoE expert weights")
+
+ def _probe_marlin_layout(self, size_n: int, size_k: int, dtype: torch.dtype, device_id: int):
+ """Run the real per-expert repack path once on zero inputs to discover the shape and
+ dtype of the final marlin layout. Only the same vllm functions that finalize relies on
+ are called, instead of re-deriving their internal formulas, to rule out shape drift.
+ Results are cached by dimension (all MoE layers share dimensions, so only two probes
+ happen overall: one for w13, one for w2).
+
+ Returns ``((weight_shape, weight_dtype), (scale_shape, scale_dtype))``."""
+ # NOTE(review): the cache key omits device_id; presumably the layout does not depend
+ # on the device. Confirm.
+ cache_key = (size_n, size_k, dtype)
+ # The cache dict is created lazily on first use rather than in __init__.
+ cache = getattr(self, "_marlin_layout_cache", None)
+ if cache is None:
+ cache = self._marlin_layout_cache = {}
+ if cache_key in cache:
+ return cache[cache_key]
+
+ import vllm._custom_ops as ops
+ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+ get_marlin_input_dtype,
+ marlin_permute_scales,
+ )
+ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+ mxfp4_marlin_process_scales,
+ )
+
+ input_dtype = get_marlin_input_dtype()
+ # True when the marlin input dtype is one byte wide.
+ is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
+ device = f"cuda:{device_id}"
+ # Zero-filled stand-in for one expert's packed weight (size_k // 2 int8 per output row),
+ # reinterpreted as int32 and transposed before the repack call.
+ qweight = torch.zeros((size_n, size_k // 2), dtype=torch.int8, device=device).view(torch.int32).T.contiguous()
+ marlin_qweight = ops.gptq_marlin_repack(
+ b_q_weight=qweight,
+ perm=torch.empty(0, dtype=torch.int, device=device),
+ size_k=size_k,
+ size_n=size_n,
+ num_bits=4,
+ is_a_8bit=is_a_8bit,
+ )
+ # Zero-filled stand-in for the scales: one scale per block_size elements along k.
+ scale = torch.zeros((size_k // self.block_size, size_n), dtype=dtype, device=device)
+ marlin_scale = marlin_permute_scales(
+ s=scale, size_k=size_k, size_n=size_n, group_size=self.block_size, is_a_8bit=is_a_8bit
+ )
+ marlin_scale = mxfp4_marlin_process_scales(marlin_scale, input_dtype=input_dtype)
+ # Only shapes and dtypes are kept; the probe tensors themselves are discarded.
+ layout = (
+ (tuple(marlin_qweight.shape), marlin_qweight.dtype),
+ (tuple(marlin_scale.shape), marlin_scale.dtype),
+ )
+ cache[cache_key] = layout
+ return layout
+
+ def _create_weight(
+ self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1
+ ) -> Tuple[WeightPack, List[WeightPack]]:
+ """Allocate the CPU staging tensors and the CUDA marlin buffers for one weight.
+
+ Returns the combined ``WeightPack`` and the list produced by ``_split_weight_pack``."""
+ out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims
+ assert in_dim % self.block_size == 0, "MXFP4 scale dimension must be divisible by block_size"
+ # NOTE(review): the staging tensors drop the expert dimension when num_experts == 1,
+ # while the marlin buffers below always carry it. Confirm this is intended.
+ expert_prefix = (num_experts,) if num_experts > 1 else ()
+ # CPU staging area: load_hf_weights fills in the raw pre-packed MXFP4 data, and finalize
+ # repacks it into the final CUDA layout.
+ weight = torch.empty(expert_prefix + (out_dim, in_dim // 2), dtype=torch.int8, device="cpu")
+ weight_scale = torch.empty(
+ expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.float8_e8m0fnu, device="cpu"
+ )
+ mm_param = WeightPack(weight=weight, weight_scale=weight_scale)
+ # The final CUDA (marlin) layout is materialized at construction time so that the mem
+ # manager's profiling sees the real weight footprint (the framework contract is "allocate
+ # at construction, load only fills data", consistent with the other quant methods;
+ # deferring the GPU allocation to finalize would let the empty-GPU profile grow the kv
+ # pool until weight loading no longer fits). finalize copies the repack result in.
+ (w_shape, w_dtype), (s_shape, s_dtype) = self._probe_marlin_layout(out_dim, in_dim, dtype, device_id)
+ mm_param.marlin_weight = torch.empty((num_experts,) + w_shape, dtype=w_dtype, device=f"cuda:{device_id}")
+ mm_param.marlin_weight_scale = torch.empty((num_experts,) + s_shape, dtype=s_dtype, device=f"cuda:{device_id}")
+ mm_param_list = self._split_weight_pack(
+ mm_param,
+ weight_out_dims=out_dims,
+ weight_split_dim=-2,
+ weight_scale_out_dims=out_dims,
+ weight_scale_split_dim=-2,
+ )
+ return mm_param, mm_param_list
+
+ def finalize_moe_weight(self, moe_weight):
+ """Repack the staged w13/w2 weights and scales into the pre-allocated marlin buffers.
+
+ Raises ``RuntimeError`` when the vLLM MXFP4 packing utility cannot be imported."""
+ try:
+ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+ prepare_moe_mxfp4_layer_for_marlin,
+ )
+ except Exception as e:
+ raise RuntimeError(f"marlin-mxfp4w4a16-b32 requires vLLM MXFP4 packing utilities, error={repr(e)}") from e
+
+ # Minimal attribute holder passed as the "layer" argument; only params_dtype is set here.
+ class _MXFP4Layer:
+ pass
+
+ device = torch.device("cuda", moe_weight.device_id_)
+ layer = _MXFP4Layer()
+ layer.params_dtype = moe_weight.data_type_
+ # Move the staged tensors to the target GPU; weights are reinterpreted as uint8 first.
+ w13 = moe_weight.w13.weight.view(torch.uint8).to(device=device, non_blocking=True).contiguous()
+ w2 = moe_weight.w2.weight.view(torch.uint8).to(device=device, non_blocking=True).contiguous()
+ w13_scale = moe_weight.w13.weight_scale.to(device=device, non_blocking=True).contiguous()
+ w2_scale = moe_weight.w2.weight_scale.to(device=device, non_blocking=True).contiguous()
+ # The last two returned values are ignored; None is passed for the last two inputs.
+ (
+ w13_new,
+ w2_new,
+ w13_scale_new,
+ w2_scale_new,
+ _,
+ _,
+ ) = prepare_moe_mxfp4_layer_for_marlin(layer, w13, w2, w13_scale, w2_scale, None, None)
+ # Copy the repack results into the marlin buffers pre-allocated at construction (same
+ # form as the AWQ marlin path). The CPU staging tensors and repack temporaries are released
+ # with their references; a shape mismatch fails explicitly at copy_ (the probe is meant to
+ # guarantee they match).
+ moe_weight.w13.marlin_weight.copy_(w13_new)
+ moe_weight.w13.marlin_weight_scale.copy_(w13_scale_new)
+ moe_weight.w2.marlin_weight.copy_(w2_new)
+ moe_weight.w2.marlin_weight_scale.copy_(w2_scale_new)
+ # Re-point weight / weight_scale at the marlin buffers so the packs expose the final layout.
+ moe_weight.w13.weight = moe_weight.w13.marlin_weight
+ moe_weight.w13.weight_scale = moe_weight.w13.marlin_weight_scale
+ moe_weight.w2.weight = moe_weight.w2.marlin_weight
+ moe_weight.w2.weight_scale = moe_weight.w2.marlin_weight_scale
+
+
def _deepgemm_fp8_nt(a_tuple, b_tuple, out):
if HAS_DEEPGEMM:
if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"):
diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py
index 01e9c4ad35..5e7c4f96dd 100644
--- a/lightllm/common/req_manager.py
+++ b/lightllm/common/req_manager.py
@@ -1,17 +1,26 @@
import torch
import collections
+from dataclasses import dataclass
from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig
from lightllm.utils.log_utils import init_logger
-from .kv_cache_mem_manager import MemoryManager
+from .kv_cache_mem_manager import MemoryManager, DeepseekV4MemoryManager
from typing import List, Optional, TYPE_CHECKING
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter
-from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
+from lightllm.common.basemodel.triton_kernel.gen_sampling_params import (
+ update_req_to_token_id_counter,
+)
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
from lightllm.utils.config_utils import get_vocab_size
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
from lightllm.common.linear_att_cache_manager.layer_cache import LayerCache
-from lightllm.common.linear_att_cache_manager.linear_att_buffer_manager import LinearAttCacheManager
+from lightllm.common.linear_att_cache_manager.linear_att_buffer_manager import (
+ LinearAttCacheManager,
+)
+from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import (
+ DSV4_C4_PAGE_SIZE,
+ DSV4_PROMPT_CACHE_PAGE_SIZE,
+)
if TYPE_CHECKING:
from lightllm.server.router.model_infer.infer_batch import InferReq
@@ -19,6 +28,58 @@
logger = init_logger(__name__)
+@dataclass
+class DeepseekV4PromptCachePayload:
+    """Prompt-cache payload: only the per-page swa validity bitmap remains.
+
+    Neither slots nor compressor state are stored here. full_to_swa / full_to_c4 /
+    full_to_c128 are keyed by the full token slot (radix holding the full slot keeps
+    the mapping row alive; free cascades the reclaim). c4/c128 compressor state is
+    addressed by derivation from swa pages (it lives and dies with the swa page, and
+    a hit resumes with zero copy). The prompt cache is aligned to 256 tokens so that
+    a shared prefix never stops in the middle of a c4 physical page.
+
+    ``swa_page_valid`` is a cpu bool tensor ``[cache_len // page]``, fixed at insertion
+    time from the then-current full_to_swa mapping (True only when every token mapping
+    in the page is valid). The matching layer uses it to trim a hit to a page boundary
+    whose final page is valid; it is cleared when the swa pressure valve reclaims the
+    node's pages.
+    """
+
+    # Number of cached tokens this payload describes.
+    cache_len: int
+    # Per-page validity bitmap; None when no bitmap is attached.
+    swa_page_valid: Optional[torch.Tensor] = None
+
+
+class DeepseekV4PromptCacheValueOps:
+    """Payload operations for the prompt cache, backed by a ``DeepseekV4ReqManager``."""
+
+    def __init__(self, req_manager: "DeepseekV4ReqManager"):
+        self.req_manager = req_manager
+
+    def slice(self, payload: DeepseekV4PromptCachePayload, start: int, end: int):
+        """Delegate to ``req_manager.slice_prompt_cache_payload``."""
+        manager = self.req_manager
+        return manager.slice_prompt_cache_payload(payload, start, end)
+
+    def concat(self, payloads: List[DeepseekV4PromptCachePayload]):
+        """Delegate to ``req_manager.concat_prompt_cache_payloads``."""
+        manager = self.req_manager
+        return manager.concat_prompt_cache_payloads(payloads)
+
+    def free(self, payload: DeepseekV4PromptCachePayload):
+        """No-op: slot resources are reclaimed in cascade by mem_manager.free(full_slots),
+        so the payload itself owns nothing that needs releasing."""
+        return None
+
+    def invalidate_swa_pages(self, payload: DeepseekV4PromptCachePayload) -> None:
+        """Clear the bitmap after the swa pressure valve reclaimed this node's swa pages,
+        so later hits are trimmed under the shortened semantics and are never revived."""
+        if payload is None:
+            return
+        bitmap = payload.swa_page_valid
+        if bitmap is not None:
+            bitmap.fill_(False)
+
+    def valid_match_length(self, payload: Optional[DeepseekV4PromptCachePayload], natural_len: int) -> int:
+        """Trim a radix match: return the largest prompt-cache boundary L' <= natural_len
+        whose final page is valid, or 0 when there is none.
+
+        Validity may be non-monotonic (the owner evicts from the left while alive, the
+        valve later reclaims from the tail), so the bitmap is consulted per candidate
+        boundary. An invalid page in the middle does not block a valid hit further along
+        (attention only looks back over the last window)."""
+        page = self.req_manager.get_prompt_cache_page_size()
+        bitmap = None if payload is None else payload.swa_page_valid
+        if bitmap is None:
+            return 0
+        # Only whole pages inside natural_len, and only pages the bitmap covers, count.
+        page_limit = min(natural_len // page, int(bitmap.numel()))
+        valid_pages = torch.nonzero(bitmap[:page_limit])
+        if valid_pages.numel() == 0:
+            return 0
+        last_valid = int(valid_pages[-1])
+        return (last_valid + 1) * page
+
+
class _ReqNode:
def __init__(self, index):
self.index = index
@@ -100,6 +161,11 @@ def free_all(self):
self.req_list = _ReqLinkedList(self.max_request_num)
return
+ def prepare_decode(self, b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu):
+     """Hook invoked on every decode step before ``to_cuda``.
+
+     The arguments are native CPU tensors and the call already runs on the forward
+     CUDA stream. The base implementation does nothing; models that need per-step
+     KV slot prep (DeepSeek-V4) override it.
+     """
+     return None
+
class ReqSamplingParamsManager:
"""
@@ -299,3 +365,469 @@ def copy_small_page_buffer_to_linear_att_state(
self.req_to_conv_state.buffer[:, dest_req_idx, ...] = conv_state
self.req_to_ssm_state.buffer[:, dest_req_idx, ...] = ssm_state
return
+
+
+class DeepseekV4ReqManager(ReqManager):
+ """DeepSeek-V4 的请求级管理。
+
+ 在基类 ReqManager 之上补 V4 专有的 per-request 结构。该对象在 mem manager profile 前创建,
+ 所以初始化只依赖 config 派生出的 compress_rates/head_dim/indexer_head_dim/sliding_window;
+ 真实 mem_manager 会在 `_init_mem_manager()` 后通过 `bind_mem_manager()` 接入。
+
+ * 压缩槽位不在本类: ``full_to_c4/c128_indexs``(mem manager)以组末 token 的 full 槽位为键。
+ 本类只负责 prep 阶段的分配与 scatter(``prepare_prefill_compress_slots`` /
+ ``prepare_decode_compress_slots``)——必须先于 attention metadata 构建/图捕获;
+ 条目内容由 layer-infer 的 compressor 前向写入。
+ * compressor 在途状态不在本类: c4/c128 都在 mem manager 的 swa 页派生池,
+ 随页生灭,命中零拷贝续算。
+ * SWA 槽位分配/出窗回收(``prepare_prefill_swa`` / ``prepare_decode_swa``): 每步 prep 阶段
+ 为新 token 调 mem_manager.alloc_swa,并按 per-req 水位线(``_swa_evict_marks``)惰性回收
+ 已出窗位置的 swa 槽。水位线首次置为该请求首个 chunk 的 ready_cache_len(radix 共享前缀
+ 的边界),因此共享前缀的 swa 槽永远不会被本请求回收(归 radix 经 mem_manager.free 级联释放)。
+ """
+
+ def __init__(
+     self,
+     max_request_num,
+     max_sequence_length,
+     mem_manager: Optional[DeepseekV4MemoryManager] = None,
+     compress_rates: Optional[List[int]] = None,
+     head_dim: Optional[int] = None,
+     indexer_head_dim: Optional[int] = None,
+     sliding_window: Optional[int] = None,
+ ):
+     """Set up V4 per-request state on top of the base request manager.
+
+     Any of ``compress_rates`` / ``head_dim`` / ``indexer_head_dim`` /
+     ``sliding_window`` left as None is taken from ``mem_manager`` when one is given.
+     """
+     super().__init__(max_request_num, max_sequence_length, mem_manager)
+
+     self.mem_manager = mem_manager
+     if mem_manager is not None:
+         # Explicit arguments win; the manager only fills in what was not supplied.
+         compress_rates = mem_manager.compress_rates if compress_rates is None else compress_rates
+         head_dim = mem_manager.head_dim if head_dim is None else head_dim
+         indexer_head_dim = mem_manager.indexer_head_dim if indexer_head_dim is None else indexer_head_dim
+         sliding_window = mem_manager.sliding_window if sliding_window is None else sliding_window
+     self.sliding_window = sliding_window
+     # Out-of-window eviction watermark per request. -1 means the request has not seen a
+     # prefill chunk yet (the first chunk's ready_cache_len is the shared-prefix boundary
+     # and becomes the lower bound that eviction never goes below).
+     self._swa_evict_marks = [-1] * (max_request_num + 1)
+     self.compress_rates = list(compress_rates)
+     self.n_c4 = self.compress_rates.count(4)
+     self.n_c128 = self.compress_rates.count(128)
+     self.head_dim = head_dim
+     self.indexer_head_dim = indexer_head_dim
+     # Map each layer id to its ordinal among the layers sharing its compress rate.
+     c4_layers = [lid for lid, rate in enumerate(self.compress_rates) if rate == 4]
+     c128_layers = [lid for lid, rate in enumerate(self.compress_rates) if rate == 128]
+     self.layer_to_c4_idx = {lid: pos for pos, lid in enumerate(c4_layers)}
+     self.layer_to_c128_idx = {lid: pos for pos, lid in enumerate(c128_layers)}
+
+     return
+
+ def bind_mem_manager(self, mem_manager: DeepseekV4MemoryManager):
+     """Attach the real memory manager after checking it matches this manager's config.
+
+     ``compress_rates``, ``head_dim`` and ``indexer_head_dim`` must be equal. The
+     sliding window is adopted from ``mem_manager`` when unset here; otherwise it
+     must match unless the manager's own value is None.
+     """
+     assert isinstance(mem_manager, DeepseekV4MemoryManager)
+     for attr in ("compress_rates", "head_dim", "indexer_head_dim"):
+         assert getattr(self, attr) == getattr(mem_manager, attr)
+     manager_window = mem_manager.sliding_window
+     if self.sliding_window is None:
+         self.sliding_window = manager_window
+     else:
+         assert manager_window is None or self.sliding_window == manager_window
+     self.mem_manager = mem_manager
+     return
+
+ # ------------------------------------------------------------------ swa slot prep (per step)
+ def _swa_retain_len(self) -> int:
+     """Number of positions kept by out-of-window eviction: window + one radix page.
+
+     Keeping one extra page makes the final page of the most recently completed
+     prompt-cache boundary stay resident. If only the window were kept, that final
+     page would already be partly evicted at any non-aligned moment and the insert
+     gate would trim every insert to 0. The V4 prompt-cache page is 256 tokens, which
+     exactly covers the token range of one c4 physical page.
+     """
+     window = int(self.sliding_window)
+     extra_page = self.get_prompt_cache_page_size()
+     return window + extra_page
+
+ def prepare_prefill_swa(
+     self,
+     b_req_idx: torch.Tensor,
+     b_ready_cache_len: torch.Tensor,
+     b_seq_len: torch.Tensor,
+ ) -> None:
+     """Prefill prep: evict out-of-window swa slots, then allocate slots for the chunk.
+
+     Position-aligned swa slots are allocated for every new token of the chunk
+     (positions [ready, seq)). With chunk start L = ready_cache_len, the first new
+     token's window is [L-W+1, L]; eviction keeps one extra radix page
+     (``_swa_retain_len``), so positions < L-retain+1 are evicted. Eviction runs
+     before allocation. Must be called after init_req_to_token_indexes, because the
+     position-aligned allocation derives/scatters through the req_to_token rows.
+     """
+     manager = self.mem_manager
+     assert manager is not None
+     if self.sliding_window is not None:
+         keep = self._swa_retain_len()
+         marks = self._swa_evict_marks
+         reclaim = []
+         req_ids = b_req_idx.detach().cpu().tolist()
+         ready_lens = b_ready_cache_len.detach().cpu().tolist()
+         for rid, ready in zip(req_ids, ready_lens):
+             rid = int(rid)
+             if rid == self.HOLD_REQUEST_ID:
+                 continue
+             ready = int(ready)
+             watermark = marks[rid]
+             if watermark < 0:
+                 # First chunk: [0, ready) is the radix shared prefix; its swa slots
+                 # belong to radix and must not be evicted by this request.
+                 marks[rid] = ready
+             else:
+                 new_mark = ready - keep + 1
+                 if new_mark > watermark:
+                     reclaim.append(self.req_to_token_indexs[rid, watermark:new_mark])
+                     marks[rid] = new_mark
+         if reclaim:
+             manager.evict_swa(torch.cat(reclaim))
+     manager.alloc_swa_prefill(b_req_idx, b_ready_cache_len, b_seq_len, self.req_to_token_indexs)
+     return
+
+ def prepare_decode(self, b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu):
+     """Per-step decode slot prep: swa first, then compress.
+
+     Called by BaseModel.forward / microbatch_overlap_decode before ``to_cuda``
+     (CPU data, on the forward CUDA stream); it no longer lives in ``_decode``.
+     """
+     step_args = (b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu)
+     self.prepare_decode_swa(*step_args)
+     self.prepare_decode_compress_slots(*step_args)
+     return
+
+ def prepare_decode_swa(
+     self,
+     b_req_idx_cpu: torch.Tensor,
+     b_seq_len_cpu: torch.Tensor,
+     mem_indexes: torch.Tensor,
+ ) -> None:
+     """Decode prep: evict out-of-window swa slots, then allocate this step's slots.
+
+     The current query position seq_len-1 has window [seq_len-W, seq_len-1];
+     eviction keeps one extra radix page (``_swa_retain_len``), so positions
+     < seq_len-retain are evicted. Eviction runs before allocation. seq_len/req_idx
+     are read from the CPU mirrors (host arithmetic, no D2H); the watermark
+     ``_swa_evict_marks`` remains host state.
+     """
+     manager = self.mem_manager
+     assert manager is not None
+     if self.sliding_window is not None:
+         keep = self._swa_retain_len()
+         marks = self._swa_evict_marks
+         reclaim = []
+         for rid, seq in zip(b_req_idx_cpu.tolist(), b_seq_len_cpu.tolist()):
+             if rid == self.HOLD_REQUEST_ID:
+                 continue
+             watermark = marks[rid]
+             new_mark = seq - keep
+             if watermark < 0:
+                 # Conservative path for a request that never went through prefill prep:
+                 # old positions are not evicted, the watermark is only advanced.
+                 marks[rid] = max(0, new_mark)
+             elif new_mark > watermark:
+                 reclaim.append(self.req_to_token_indexs[rid, watermark:new_mark])
+                 marks[rid] = new_mark
+         if reclaim:
+             manager.evict_swa(torch.cat(reclaim))
+     manager.alloc_swa_decode(b_req_idx_cpu, b_seq_len_cpu, mem_indexes, self.req_to_token_indexs)
+     return
+
+ def init_compress_state(self, req_idx: int):
+     """Reset the runtime watermark when a new request starts.
+
+     This is the counterpart of the call site of mamba's init_linear_att_state.
+     c4/c128 compressor state is addressed through swa pages and overwritten per
+     group by the kernels, so nothing is zeroed per request on reuse.
+     """
+     self.clear_runtime_state(req_idx)
+     return None
+
+ # ------------------------------------------------------------------ compress slot prep (per step)
+ def _compress_mapping_alloc(self, ratio: int):
+     """Return ``(mapping, alloc)`` for a compress ratio handled by the generic scatter.
+
+     Only ratio 128 is served. Ratio 4 is rejected because c4 uses page-safe
+     allocation, and any other ratio is invalid; both raise ``AssertionError``, as
+     does an unbound mem manager.
+     """
+     manager = self.mem_manager
+     assert manager is not None, "DeepSeek-V4 mem manager is not bound yet"
+     if ratio == 128:
+         return manager.full_to_c128_indexs, manager.alloc_c128
+     if ratio == 4:
+         raise AssertionError("DeepSeek-V4 c4 uses page-safe allocation")
+     raise AssertionError(f"invalid DeepSeek-V4 compress ratio {ratio}")
+
+ def _scatter_c4_prefill_slots_slow(self, req_idx: int, first: int, last: int) -> None:
+ """Idempotence fallback for overlapped/repeated c4 prep.
+
+ Walks logical c4 entries [first, last) one logical page at a time and fills in only
+ the entries whose full-slot mapping is still negative, keeping each page's entries on
+ a single physical page base."""
+ page = DSV4_C4_PAGE_SIZE
+ mapping = self.mem_manager.full_to_c4_indexs
+ # page_base is the first logical entry of each page that overlaps [first, last).
+ for page_base in range((first // page) * page, last, page):
+ # [e0, e1) is the part of this logical page that falls inside [first, last).
+ e0 = max(first, page_base)
+ e1 = min(last, page_base + page)
+ entries = torch.arange(e0, e1, dtype=torch.long, device="cuda")
+ # Entry e is keyed by the full slot of token position 4*e + 3 (last token of its group).
+ full_slots = self.req_to_token_indexs[req_idx, entries * 4 + 3].long()
+ existing = mapping[full_slots]
+ missing = existing < 0
+ if not bool(missing.any()):
+ continue
+
+ # Choose the physical base for this page: from an already-mapped entry in the range,
+ # else from the entry just before e0 in the same page, else from a fresh page.
+ mapped = torch.nonzero(existing >= 0, as_tuple=False)
+ if mapped.numel() > 0:
+ j = int(mapped[0].item())
+ base = int(existing[j].item()) - ((e0 + j) % page)
+ elif e0 > page_base:
+ prev_full = self.req_to_token_indexs[req_idx, e0 * 4 - 1].long()
+ prev_slot = int(mapping[prev_full].item())
+ assert prev_slot >= 0 and prev_slot % page == (e0 - 1) % page
+ base = prev_slot - ((e0 - 1) % page)
+ else:
+ base = int(self.mem_manager.alloc_c4_pages(1)[0].item()) * page
+
+ slots = (base + entries % page).to(torch.int32)
+ if mapped.numel() > 0:
+ # Entries that were already mapped must agree with the derived base.
+ assert bool((existing[existing >= 0] == slots[existing >= 0]).all())
+ # Only the previously unmapped entries are written and counted.
+ mapping[full_slots[missing]] = slots[missing]
+ self.mem_manager.count_c4_slots(slots[missing], 1)
+ return
+
+ def _scatter_c4_prefill_slots(self, req_idx: int, first: int, last: int) -> None:
+ """Allocate page-safe c4 slots for logical c4 entries [first, last).
+
+ Invariant: logical entry e maps to physical_page * 64 + e % 64, and entries within the
+ same logical page share one physical_page. This is the precondition for DeepGEMM paged
+ MQA logits to consume the page table directly.
+ """
+ if last <= first:
+ return
+ page = DSV4_C4_PAGE_SIZE
+ mapping = self.mem_manager.full_to_c4_indexs
+ entries = torch.arange(first, last, dtype=torch.long, device="cuda")
+ # Entry e is keyed by the full slot of token position 4*e + 3 (last token of its group).
+ full_slots = self.req_to_token_indexs[req_idx, entries * 4 + 3].long()
+ need = mapping[full_slots] < 0
+ # Everything already mapped: nothing to do.
+ if not bool(need.any()):
+ return
+ # Partially mapped range: hand over to the per-page fallback.
+ if not bool(need.all()):
+ self._scatter_c4_prefill_slots_slow(req_idx, first, last)
+ return
+
+ first_page = first // page
+ last_page = (last - 1) // page
+ n_pages = last_page - first_page + 1
+ # bases[i] is the physical slot base of the i-th logical page touched by the range.
+ bases = torch.empty((n_pages,), dtype=torch.long, device="cuda")
+
+ base_start = 0
+ if first % page != 0:
+ # The range starts mid-page: reuse the page of the preceding entry, which must be mapped.
+ prev_full = self.req_to_token_indexs[req_idx, first * 4 - 1].long()
+ prev_slot = int(mapping[prev_full].item())
+ assert prev_slot >= 0 and prev_slot % page == (first - 1) % page
+ bases[0] = prev_slot - ((first - 1) % page)
+ base_start = 1
+
+ # All remaining logical pages get freshly allocated physical pages.
+ new_page_count = n_pages - base_start
+ if new_page_count > 0:
+ new_pages = self.mem_manager.alloc_c4_pages(new_page_count).cuda(non_blocking=True).long()
+ bases[base_start:] = new_pages * page
+
+ # Index of each entry's logical page relative to the first page in the range.
+ page_local = torch.div(entries, page, rounding_mode="floor") - first_page
+ slots = (bases[page_local] + entries % page).to(torch.int32)
+ mapping[full_slots] = slots
+ self.mem_manager.count_c4_slots(slots, 1)
+ return
+
+ def _scatter_c4_decode_slots(
+ self,
+ b_req_idx_cpu: torch.Tensor,
+ b_seq_len_cpu: torch.Tensor,
+ mem_indexes: torch.Tensor,
+ ) -> None:
+ """Decode-step c4 slot assignment for rows whose seq_len is a positive multiple of 4.
+
+ A row whose new entry starts a logical page gets the first slot of a freshly allocated
+ page; any other row gets the slot right after the previous entry's slot. The mapping is
+ written at the row's ``mem_indexes`` value."""
+ page = DSV4_C4_PAGE_SIZE
+ mapping = self.mem_manager.full_to_c4_indexs
+ req_list = b_req_idx_cpu.tolist()
+ seq_list = b_seq_len_cpu.tolist()
+ mem_indexes = mem_indexes.cuda().long().reshape(-1)
+
+ # cont_* describe rows that continue an existing page; new_rows start a new page.
+ cont_rows, cont_prev_pos, cont_offsets = [], [], []
+ new_rows = []
+ for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)):
+ req_idx, seq_len = int(req_idx), int(seq_len)
+ if req_idx == self.HOLD_REQUEST_ID or seq_len <= 0 or seq_len % 4 != 0:
+ continue
+ # Logical c4 entry closed by this step, and its offset inside its logical page.
+ entry = seq_len // 4 - 1
+ offset = entry % page
+ if offset == 0:
+ new_rows.append(i)
+ else:
+ cont_rows.append(i)
+ # Token position that keys the previous entry (last token of the previous group).
+ cont_prev_pos.append(entry * 4 - 1)
+ cont_offsets.append(offset)
+
+ if cont_rows:
+ req_rows = torch.tensor([req_list[i] for i in cont_rows], dtype=torch.long, device="cuda")
+ prev_pos = torch.tensor(cont_prev_pos, dtype=torch.long, device="cuda")
+ prev_full = self.req_to_token_indexs[req_rows, prev_pos].long()
+ prev_slots = mapping[prev_full]
+ offsets = torch.tensor(cont_offsets, dtype=torch.int32, device="cuda")
+ # The previous entry must be mapped and sit at the preceding in-page offset.
+ assert bool((prev_slots >= 0).all())
+ assert bool(((prev_slots % page) == (offsets - 1)).all())
+ slots = (prev_slots + 1).to(torch.int32)
+ mapping[mem_indexes[cont_rows]] = slots
+ self.mem_manager.count_c4_slots(slots, 1)
+
+ if new_rows:
+ # One fresh page per row; the entry takes the page's first slot.
+ pages = self.mem_manager.alloc_c4_pages(len(new_rows)).cuda(non_blocking=True).long()
+ slots = (pages * page).to(torch.int32)
+ mapping[mem_indexes[new_rows]] = slots
+ self.mem_manager.count_c4_slots(slots, 1)
+ return
+
+ def _scatter_compress_slots(self, ratio: int, full_slots: torch.Tensor) -> None:
+     """Allocate compress slots for group-end full slots and record them in the mapping.
+
+     Rows that are already mapped (>= 0) are skipped, so repeating the prep is
+     idempotent. Does nothing for an empty ``full_slots``.
+     """
+     if full_slots.numel() == 0:
+         return
+     mapping, alloc = self._compress_mapping_alloc(ratio)
+     keys = full_slots.cuda().long().reshape(-1)
+     unmapped = mapping[keys] < 0
+     # De-duplicate: with repeated keys in one batch the later write would overwrite the
+     # earlier one and orphan the first allocated compress slot (an allocator leak).
+     need = torch.unique(keys[unmapped])
+     count = need.numel()
+     if count == 0:
+         return
+     mapping[need] = alloc(count).cuda(non_blocking=True).to(torch.int32)
+     return
+
+ def prepare_prefill_compress_slots(
+     self,
+     b_req_idx: torch.Tensor,
+     b_ready_cache_len: torch.Tensor,
+     b_seq_len: torch.Tensor,
+ ) -> None:
+     """Prefill prep: allocate compress slots for the group-end tokens of this chunk.
+
+     Group-end tokens are positions (g+1)*ratio-1 inside [ready, seq); their slots
+     are scattered into full_to_c4/c128_indexs. Must be called after
+     init_req_to_token_indexes (group-end full slots are read from
+     req_to_token_indexs) and before attention metadata is built.
+     """
+     if self.n_c4 == 0 and self.n_c128 == 0:
+         return
+     req_ids = b_req_idx.detach().cpu().tolist()
+     ready_lens = b_ready_cache_len.detach().cpu().tolist()
+     seq_lens = b_seq_len.detach().cpu().tolist()
+     # Real (non-HOLD) rows as plain ints, shared by both ratios below.
+     rows = [
+         (int(rid), int(ready), int(seq))
+         for rid, ready, seq in zip(req_ids, ready_lens, seq_lens)
+         if int(rid) != self.HOLD_REQUEST_ID
+     ]
+
+     if self.n_c4 > 0:
+         for rid, ready, seq in rows:
+             self._scatter_c4_prefill_slots(rid, ready // 4, seq // 4)
+
+     if self.n_c128 > 0:
+         ratio = 128
+         group_ends = []
+         for rid, ready, seq in rows:
+             first, last = ready // ratio, seq // ratio
+             if last <= first:
+                 continue
+             # Full slots of the last token of groups 0..last-1; keep groups from `first` on.
+             ends = self.req_to_token_indexs[rid, ratio - 1 : last * ratio : ratio]
+             group_ends.append(ends[first:])
+         if group_ends:
+             self._scatter_compress_slots(ratio, torch.cat(group_ends))
+     return
+
+ def prepare_decode_compress_slots(
+ self,
+ b_req_idx_cpu: torch.Tensor,
+ b_seq_len_cpu: torch.Tensor,
+ mem_indexes: torch.Tensor,
+ ) -> None:
+ """Decode prep: when this step's token closes a group (seq_len % ratio == 0), allocate a
+ compressed slot for it and scatter it into the mapping.
+ The group-end full slot is this step's mem_index (req_to_token_indexs has not been written
+ with this step's slot yet).
+ seq_len/req_idx are read from the CPU mirrors (host arithmetic, no D2H); on steps that close
+ no group `rows` is empty => _scatter is not called, zero sync."""
+ if self.n_c4 == 0 and self.n_c128 == 0:
+ return
+ req_list = b_req_idx_cpu.tolist()
+ seq_list = b_seq_len_cpu.tolist()
+ if self.n_c4 > 0:
+ self._scatter_c4_decode_slots(b_req_idx_cpu, b_seq_len_cpu, mem_indexes)
+
+ if self.n_c128 > 0:
+ ratio = 128
+ # Batch rows whose sequence length just reached a multiple of 128 (HOLD rows excluded).
+ rows = [
+ i
+ for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list))
+ if req_idx != self.HOLD_REQUEST_ID and seq_len > 0 and seq_len % ratio == 0
+ ]
+ if rows:
+ self._scatter_compress_slots(ratio, mem_indexes.reshape(-1)[rows])
+ return
+
+ def alloc(self):
+ """Allocate a request index through the parent class and, when one is returned,
+ initialise its compress state. Returns the parent's result unchanged (may be None)."""
+ req_idx = super().alloc()
+ if req_idx is not None:
+ self.init_compress_state(req_idx)
+ return req_idx
+
+ def clear_runtime_state(self, req_idx: int):
+ # The swa slots themselves are reclaimed in cascade by mem_manager.free (together with the
+ # full slots); here we only reset the out-of-window watermark.
+ self._swa_evict_marks[req_idx] = -1
+ return
+
+ def get_prompt_cache_value_ops(self):
+ """Return a new DeepseekV4PromptCacheValueOps bound to this request manager."""
+ return DeepseekV4PromptCacheValueOps(self)
+
+ def get_prompt_cache_page_size(self):
+ """Return the prompt-cache page size constant (DSV4_PROMPT_CACHE_PAGE_SIZE)."""
+ return DSV4_PROMPT_CACHE_PAGE_SIZE
+
+ def compute_swa_page_valid(self, full_slots: torch.Tensor) -> torch.Tensor:
+ """Per-page validity under the current full_to_swa mapping: full_slots [L] (L a multiple of
+ the page size) -> cpu bool [L/page]; a page is True only when every mapping inside it is
+ valid (>= 0). GPU gather + sync, intended for tests/validation; the insertion hot path uses
+ swa_page_valid_from_watermark (pure CPU, no sync)."""
+ page = self.get_prompt_cache_page_size()
+ assert full_slots.numel() % page == 0
+ if full_slots.numel() == 0:
+ return torch.zeros((0,), dtype=torch.bool)
+ swa = self.mem_manager.full_to_swa_indexs[full_slots.cuda().long().reshape(-1)]
+ return (swa.view(-1, page) >= 0).all(dim=1).cpu()
+
+ def swa_page_valid_from_watermark(self, req_idx: int, cache_len: int) -> torch.Tensor:
+ """Per-page validity at insertion time, pure CPU: the swa mappings of a request's own tokens
+ are only reclaimed by the out-of-window watermark (the valve does not touch active requests,
+ and cascading only happens on free), so page p is fully resident <=> page start 128p >= watermark.
+
+ Equivalent to compute_swa_page_valid for the request's own tokens at insertion time, but
+ without the GPU gather/sync -- on the router critical path each insertion saves one wait on
+ all in-flight kernels. Bitmap rows for the borrowed prefix (pages in [0, ready)) are dropped
+ when radix insert slices the payload (existing nodes keep their own bitmap), so their values
+ have no effect."""
+ page = self.get_prompt_cache_page_size()
+ # A reset watermark is stored as -1; clamp to 0 so every page compares as valid.
+ mark = max(0, self._swa_evict_marks[req_idx])
+ n_pages = int(cache_len) // page
+ return torch.arange(n_pages, dtype=torch.long) * page >= mark
+
+ def slice_prompt_cache_payload(self, payload: DeepseekV4PromptCachePayload, start: int, end: int):
+ """Return a new payload covering tokens [start, end): cache_len = end - start and, when the
+ source has a bitmap, a cloned slice of swa_page_valid over pages [start // page, end // page)."""
+ start = int(start)
+ end = int(end)
+ page = self.get_prompt_cache_page_size()
+ # The radix page guarantees split points are page-aligned, so the bitmap can be cut on whole pages.
+ return DeepseekV4PromptCachePayload(
+ cache_len=end - start,
+ swa_page_valid=payload.swa_page_valid[start // page : end // page].clone()
+ if payload.swa_page_valid is not None
+ else None,
+ )
+
+ def concat_prompt_cache_payloads(self, payloads: List[DeepseekV4PromptCachePayload]):
+ """Merge payloads in order: cache_len is the sum of the parts; swa_page_valid is the
+ concatenation of the bitmaps, or None if any part has no bitmap. Returns None for an empty list."""
+ if len(payloads) == 0:
+ return None
+ bitmaps = [p.swa_page_valid for p in payloads]
+ return DeepseekV4PromptCachePayload(
+ cache_len=sum(p.cache_len for p in payloads),
+ swa_page_valid=torch.cat(bitmaps, dim=0) if all(b is not None for b in bitmaps) else None,
+ )
+
+ def build_prompt_cache_payload(
+ self,
+ req_idx: int,
+ cache_len: int,
+ ) -> DeepseekV4PromptCachePayload:
+ """Build the insertion payload. Compressor state is not part of the payload (c4 lives and
+ dies with the swa pages; c128 naturally resets to zero at group boundaries), so cache_len is
+ no longer constrained by the sequence end -- any 128-aligned prefix can be inserted.
+ swa_page_valid is not filled here: it must use the mapping at insertion time (the infer batch
+ fills it in before insert)."""
+ # req_idx is not read by this method; only cache_len ends up in the payload.
+ assert self.mem_manager is not None
+ return DeepseekV4PromptCachePayload(cache_len=int(cache_len))
+
+ def free(self, free_req_indexes, free_token_index):
+ """Dense/swa/compressed slots are all reclaimed in cascade via mem_manager.free(free_token_index)."""
+ # Reset each request's per-request runtime state before delegating to the parent free.
+ for req_index in free_req_indexes:
+ self.clear_runtime_state(req_index)
+ super().free(free_req_indexes, free_token_index)
+ return
+
+ def free_req(self, free_req_index: int):
+ """Reset the request's runtime state, then return the parent class's free_req result."""
+ self.clear_runtime_state(free_req_index)
+ return super().free_req(free_req_index)
+
+ def free_all(self):
+ """Run the parent free_all, then reset every watermark (max_request_num + 1 entries) to -1."""
+ super().free_all()
+ self._swa_evict_marks = [-1 for _ in range(self.max_request_num + 1)]
+ return
diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index f619b1d88f..3d376d160d 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -20,6 +20,7 @@
from lightllm.models.phi3.model import Phi3TpPartModel
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel
+from lightllm.models.deepseek_v4.model import DeepseekV4TpPartModel
from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel
from lightllm.models.internvl.model import (
InternVLLlamaTpPartModel,
diff --git a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py
index 30e5a59248..a8f851de2a 100644
--- a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py
+++ b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py
@@ -29,6 +29,8 @@ def _rotary_kernel(
BLOCK_SEQ: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
NUM_STAGE: tl.constexpr,
+ HAS_K: tl.constexpr,
+ INVERSE: tl.constexpr,
):
head_start_index = tl.program_id(0)
seq_block_index = tl.program_id(1)
@@ -44,6 +46,8 @@ def _rotary_kernel(
off_dimcos_sin = seq_index * stride_cosbs + cos_range * stride_cosd
cos = tl.load(Cos + off_dimcos_sin)
sin = tl.load(Sin + off_dimcos_sin)
+ if INVERSE:
+ sin = -sin
if HEAD_PARALLEL_NUM == 1:
for q_head_index in tl.static_range(0, HEAD_Q, step=1):
@@ -56,18 +60,19 @@ def _rotary_kernel(
tl.store(Q + off_q0, out_q0)
tl.store(Q + off_q1, out_q1)
- for k_head_index in tl.static_range(0, HEAD_K, step=1):
- off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd
- off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd
+ if HAS_K:
+ for k_head_index in tl.static_range(0, HEAD_K, step=1):
+ off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd
+ off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd
- k0 = tl.load(K + off_k0)
- k1 = tl.load(K + off_k1)
+ k0 = tl.load(K + off_k0)
+ k1 = tl.load(K + off_k1)
- out_k0 = k0 * cos - k1 * sin
- out_k1 = k0 * sin + k1 * cos
+ out_k0 = k0 * cos - k1 * sin
+ out_k1 = k0 * sin + k1 * cos
- tl.store(K + off_k0, out_k0)
- tl.store(K + off_k1, out_k1)
+ tl.store(K + off_k0, out_k0)
+ tl.store(K + off_k1, out_k1)
else:
for q_head_index in tl.range(head_start_index, HEAD_Q, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE):
off_q0 = seq_index * stride_qbs + q_head_index * stride_qh + dim_range0 * stride_qd
@@ -79,18 +84,19 @@ def _rotary_kernel(
tl.store(Q + off_q0, out_q0)
tl.store(Q + off_q1, out_q1)
- for k_head_index in tl.range(head_start_index, HEAD_K, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE):
- off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd
- off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd
+ if HAS_K:
+ for k_head_index in tl.range(head_start_index, HEAD_K, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE):
+ off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd
+ off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd
- k0 = tl.load(K + off_k0)
- k1 = tl.load(K + off_k1)
+ k0 = tl.load(K + off_k0)
+ k1 = tl.load(K + off_k1)
- out_k0 = k0 * cos - k1 * sin
- out_k1 = k0 * sin + k1 * cos
+ out_k0 = k0 * cos - k1 * sin
+ out_k1 = k0 * sin + k1 * cos
- tl.store(K + off_k0, out_k0)
- tl.store(K + off_k1, out_k1)
+ tl.store(K + off_k0, out_k0)
+ tl.store(K + off_k1, out_k1)
return
@@ -109,7 +115,10 @@ def get_test_configs():
def get_static_key(q, k):
- head_num_q, head_num_k, head_dim = q.shape[1], k.shape[1], q.shape[2]
+ assert q is not None, "q can not be None"
+ head_num_q = q.shape[1]
+ head_num_k = k.shape[1] if k is not None else 0
+ head_dim = q.shape[2]
return {
"Q_HEAD_NUM": head_num_q,
"K_HEAD_NUM": head_num_k,
@@ -126,12 +135,17 @@ def get_static_key(q, k):
mutates_args=["q", "k"],
)
@torch.no_grad()
-def rotary_emb_fwd(q, k, cos, sin, run_config=None):
+def rotary_emb_fwd(q, k, cos, sin, inverse=False, run_config=None):
+ assert q is not None, "q can not be None"
+ has_k = k is not None and k.shape[1] != 0
total_len = q.shape[0]
- head_num_q, head_num_k = q.shape[1], k.shape[1]
+ head_num_q = q.shape[1]
+ head_num_k = k.shape[1] if k is not None else 0
head_dim = q.shape[2]
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
- assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}"
+ if k is not None:
+ assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}"
+ assert k.shape[2] == head_dim, f"k shape {k.shape} q head_dim {head_dim}"
assert triton.next_power_of_2(head_dim) == head_dim
if not run_config:
@@ -157,9 +171,9 @@ def rotary_emb_fwd(q, k, cos, sin, run_config=None):
stride_qbs=q.stride(0),
stride_qh=q.stride(1),
stride_qd=q.stride(2),
- stride_kbs=k.stride(0),
- stride_kh=k.stride(1),
- stride_kd=k.stride(2),
+ stride_kbs=k.stride(0) if k is not None else 0,
+ stride_kh=k.stride(1) if k is not None else 0,
+ stride_kd=k.stride(2) if k is not None else 0,
stride_cosbs=cos.stride(0),
stride_cosd=cos.stride(1),
stride_sinbs=sin.stride(0),
@@ -171,6 +185,8 @@ def rotary_emb_fwd(q, k, cos, sin, run_config=None):
BLOCK_SEQ=BLOCK_SEQ,
BLOCK_DMODEL=head_dim,
NUM_STAGE=num_stages,
+ HAS_K=has_k,
+ INVERSE=inverse,
num_warps=num_warps,
num_stages=num_stages,
)
diff --git a/lightllm/models/deepseek_v4/__init__.py b/lightllm/models/deepseek_v4/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/deepseek_v4/infer_struct.py b/lightllm/models/deepseek_v4/infer_struct.py
new file mode 100644
index 0000000000..ca2ac83b03
--- /dev/null
+++ b/lightllm/models/deepseek_v4/infer_struct.py
@@ -0,0 +1,58 @@
+import torch
+from lightllm.common.basemodel import InferStateInfo
+from lightllm.common.req_manager import DeepseekV4ReqManager
+from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager
+
+
+class DeepseekV4InferStateInfo(InferStateInfo):
+ # Per-forward inference state for DeepSeek-V4: per-token rope cos/sin rows for the two rope
+ # variants plus layer-independent sliding-window index metadata, filled in by
+ # init_some_extra_state.
+ # NOTE(review): the string literal below follows the annotations, so Python does not treat it
+ # as the class docstring (a docstring must be the first statement in the class body).
+ req_manager: DeepseekV4ReqManager
+ mem_manager: DeepseekV4MemoryManager
+
+ """Per-token interleaved-rope cos/sin for the two rope variants (sliding / compressed), following
+ the gemma4 two-variant convention (_cos_cached_* -> position_cos_*). The full rope tables are
+ model constants and live on the model / layer infers, not here."""
+
+ def __init__(self):
+ super().__init__()
+ self.position_cos_sliding = None
+ self.position_sin_sliding = None
+ self.position_cos_compress = None
+ self.position_sin_compress = None
+ # layer-independent sparse-index metadata, built once per forward in init_some_extra_state
+ # (None until then so copy_for_cuda_graph's tensor-attr loop skips them).
+ self.dsv4_sparse_req_idx = None
+ self.dsv4_swa_indices = None
+ self.dsv4_swa_lengths = None
+
+ def init_some_extra_state(self, model):
+ # Fills the rope rows, the per-token request ids, the sliding-window indices and the
+ # prefill pad-row count for this forward pass.
+ super().init_some_extra_state(model) # sets position_ids, b_q_seq_len, b_q_start_loc (prefill)
+ pos = self.position_ids
+ self.position_cos_sliding = torch.index_select(model._cos_cached_sliding, 0, pos) # [T, rope_dim//2]
+ self.position_sin_sliding = torch.index_select(model._sin_cached_sliding, 0, pos)
+ self.position_cos_compress = torch.index_select(model._cos_cached_compress, 0, pos)
+ self.position_sin_compress = torch.index_select(model._sin_cached_compress, 0, pos)
+ # Per-token request id (decode: one token per req; prefill: ragged -> repeat by q-len).
+ # Layer-independent; the swa kernel + build_metadata's c4/c128 readers all reuse it.
+ if self.is_prefill:
+ self.dsv4_sparse_req_idx = torch.repeat_interleave(self.b_req_idx, self.b_q_seq_len.long())
+ else:
+ self.dsv4_sparse_req_idx = self.b_req_idx
+ # Sliding-window indices: layer-independent (full_to_swa is global, window is const), so build
+ # once here via one fused kernel instead of recomputing per layer. const [T, window] shape is
+ # cuda-graph-safe (no max_kv_seq_len dependence) and auto-staged by copy_for_cuda_graph.
+ from lightllm.models.deepseek_v4.triton_kernel.build_swa_index_dsv4 import build_swa_index
+
+ self.dsv4_swa_indices, self.dsv4_swa_lengths = build_swa_index(
+ req_idx=self.dsv4_sparse_req_idx,
+ positions=self.position_ids,
+ req_to_token_indexs=self.req_manager.req_to_token_indexs,
+ full_to_swa_indexs=self.mem_manager.full_to_swa_indexs,
+ window=int(self.mem_manager.sliding_window),
+ )
+ # Number of q rows belonging to the HOLD tail request that pads a prefill-cudagraph bucket.
+ # Its attention reads the HOLD slots (contents raced by concurrent writes, different every
+ # round), so its output must be zeroed; otherwise the pad rows' hidden states are
+ # nondeterministic -> MoE routing jitter -> the shared-expert batch composition changes ->
+ # the GEMM reduction order for real rows changes (ulp level), and after amplification over
+ # 44 layers low-confidence tokens flip.
+ self._dsv4_prefill_pad_q_len = 0
+ if self.is_prefill and self.b_req_idx.numel() > 0:
+ # NOTE(review): both .item() calls read values back to the host.
+ if int(self.b_req_idx[-1].item()) == self.req_manager.HOLD_REQUEST_ID:
+ self._dsv4_prefill_pad_q_len = int((self.b_seq_len[-1] - self.b_ready_cache_len[-1]).item())
diff --git a/lightllm/models/deepseek_v4/layer_infer/__init__.py b/lightllm/models/deepseek_v4/layer_infer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/deepseek_v4/layer_infer/compressor.py b/lightllm/models/deepseek_v4/layer_infer/compressor.py
new file mode 100644
index 0000000000..c66fd22de5
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_infer/compressor.py
@@ -0,0 +1,471 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+from triton.language.extra import libdevice
+
+from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager
+from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import (
+ DSV4_C4_STATE_RING,
+ DSV4_C128_STATE_RING,
+ DSV4_SWA_PAGE_SIZE,
+)
+
+
+# Bundle of per-layer tensors and flags built by prepare_compress_states and consumed by
+# prepare_partial_states / fused_compress. As constructed there: out_slots is the
+# full->compressed mapping gathered at mem_index, state_buffer/out_buffer come from the memory
+# manager (or a dense bf16 scratch for the indexer path), and kv_score is left as None.
+@dataclass
+class CoreCompressorMetadata:
+ layer_idx: int
+ compress_ratio: int
+ out_slots: torch.Tensor
+ mem_index: torch.Tensor
+ state_buffer: torch.Tensor
+ out_buffer: torch.Tensor
+ out_page_size: int
+ position_ids: torch.Tensor
+ b_req_idx: torch.Tensor
+ b_seq_len: torch.Tensor
+ b_ready_cache_len: Optional[torch.Tensor]
+ b_q_start_loc: Optional[torch.Tensor]
+ req_to_token_indexs: torch.Tensor
+ full_to_swa_indexs: torch.Tensor
+ token_to_batch_idx: Optional[torch.Tensor]
+ kv_score: Optional[torch.Tensor]
+ is_prefill: bool
+
+
+# One program per token: adds row (position % COMPRESS_RATIO) of `ape` to the score half of
+# that token's kv_score row (columns [STATE_WIDTH, 2 * STATE_WIDTH)), in place.
+@triton.jit
+def _add_ape_to_kv_score_kernel(
+ kv_score,
+ kv_score_stride0,
+ kv_score_stride1,
+ ape,
+ ape_stride0,
+ positions,
+ STATE_WIDTH: tl.constexpr,
+ COMPRESS_RATIO: tl.constexpr,
+ BLOCK: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ offs = tl.arange(0, BLOCK)
+ # BLOCK may exceed STATE_WIDTH (the launcher passes next_power_of_2); mask the overhang.
+ mask = offs < STATE_WIDTH
+
+ position = tl.load(positions + token_idx)
+ ape_row = position % COMPRESS_RATIO
+ score = tl.load(kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1, mask=mask)
+ # ape is addressed with unit stride along its last dimension.
+ ape_value = tl.load(ape + ape_row * ape_stride0 + offs, mask=mask)
+ tl.store(
+ kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1,
+ score + ape_value,
+ mask=mask,
+ )
+ return
+
+
+# One program per token: copies the token's kv and score halves from kv_score into a row of
+# state_buffer, but only for tokens close enough to the sequence end (see the early returns)
+# and whose full slot has a valid (>= 0) swa mapping. b_req_idx is accepted but not read here.
+@triton.jit
+def _save_partial_states_kernel(
+ kv_score,
+ kv_score_stride0,
+ kv_score_stride1,
+ positions,
+ token_to_batch_idx,
+ b_req_idx,
+ b_seq_len,
+ mem_index,
+ full_to_swa,
+ state_buffer,
+ STATE_WIDTH: tl.constexpr,
+ STATE_LAST_DIM: tl.constexpr,
+ COMPRESS_RATIO: tl.constexpr,
+ IS_C4: tl.constexpr,
+ IS_PREFILL: tl.constexpr,
+ SWA_PAGE_SIZE: tl.constexpr,
+ STATE_RING: tl.constexpr,
+ BLOCK: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ # In decode the batch row is the token index itself; in prefill it is looked up.
+ batch_idx = tl.load(token_to_batch_idx + token_idx) if IS_PREFILL else token_idx
+ position = tl.load(positions + token_idx)
+ seq_len = tl.load(b_seq_len + batch_idx)
+
+ if IS_C4:
+ # Skip when the position STATE_RING steps ahead is still inside this sequence and falls
+ # in the same swa page.
+ same_page_next = (position % SWA_PAGE_SIZE) + STATE_RING < SWA_PAGE_SIZE
+ if same_page_next and position + STATE_RING < seq_len:
+ return
+ else:
+ # Skip unless the token lies within the last COMPRESS_RATIO positions of the sequence.
+ if position + COMPRESS_RATIO < seq_len:
+ return
+
+ full_slot = tl.load(mem_index + token_idx).to(tl.int64)
+ swa_slot = tl.load(full_to_swa + full_slot).to(tl.int64)
+ if swa_slot < 0:
+ return
+ # STATE_RING state rows per swa page, selected by swa_slot modulo the ring size.
+ state_row = (swa_slot // SWA_PAGE_SIZE) * STATE_RING + (swa_slot % STATE_RING)
+
+ offs = tl.arange(0, BLOCK)
+ mask = offs < STATE_WIDTH
+ kv = tl.load(kv_score + token_idx * kv_score_stride0 + offs * kv_score_stride1, mask=mask)
+ score = tl.load(kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1, mask=mask)
+ state_base = state_buffer + state_row * STATE_LAST_DIM
+ tl.store(state_base + offs, kv, mask=mask)
+ tl.store(state_base + STATE_WIDTH + offs, score, mask=mask)
+ return
+
+
+# One program per token. A program does work only when the token has a valid output slot and
+# sits on the last position of a compression group. It then gathers a WINDOW_SIZE window of
+# (kv, score) rows (from this launch's kv_score or from state_buffer), softmax-pools them along
+# the window axis, RMS-normalises the result, applies interleaved rope to the dims past
+# NOPE_DIM, and stores either a dense bf16 vector (OUTPUT_BF16) or the paged fp8 + scale + bf16
+# rope layout.
+@triton.jit
+def _fused_compress_norm_rope_insert_kernel(
+ kv_score,
+ kv_score_stride0,
+ kv_score_stride1,
+ state_buffer,
+ positions,
+ token_to_batch_idx,
+ b_req_idx,
+ b_seq_len,
+ b_ready_cache_len,
+ b_q_start_loc,
+ req_to_token,
+ req_to_token_stride0,
+ full_to_swa,
+ out_slots,
+ norm_weight,
+ rms_eps,
+ cos_table,
+ cos_stride0,
+ sin_table,
+ sin_stride0,
+ out_buffer,
+ HEAD_DIM: tl.constexpr,
+ STATE_WIDTH: tl.constexpr,
+ STATE_LAST_DIM: tl.constexpr,
+ COMPRESS_RATIO: tl.constexpr,
+ WINDOW_SIZE: tl.constexpr,
+ IS_C4: tl.constexpr,
+ IS_PREFILL: tl.constexpr,
+ SWA_PAGE_SIZE: tl.constexpr,
+ STATE_RING: tl.constexpr,
+ ROPE_HEAD_DIM: tl.constexpr,
+ FP8_MAX: tl.constexpr,
+ SCALE_MIN: tl.constexpr,
+ NOPE_DIM: tl.constexpr,
+ QUANT_BLOCK: tl.constexpr,
+ SCALE_BYTES: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ BYTES_PER_PAGE: tl.constexpr,
+ BLOCK: tl.constexpr,
+ OUTPUT_BF16: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ out_slot = tl.load(out_slots + token_idx).to(tl.int64)
+ # Negative slot: no compressed slot is mapped for this token, nothing to emit.
+ if out_slot < 0:
+ return
+
+ position = tl.load(positions + token_idx)
+ # Only the last token of a group (position + 1 divisible by the ratio) emits output.
+ if (position + 1) % COMPRESS_RATIO != 0:
+ return
+
+ batch_idx = tl.load(token_to_batch_idx + token_idx) if IS_PREFILL else token_idx
+ req_idx = tl.load(b_req_idx + batch_idx).to(tl.int64)
+ seq_len = tl.load(b_seq_len + batch_idx)
+ if IS_PREFILL:
+ ready_len = tl.load(b_ready_cache_len + batch_idx)
+ q_start = tl.load(b_q_start_loc + batch_idx)
+ else:
+ ready_len = position
+ q_start = token_idx
+
+ # Window of WINDOW_SIZE consecutive positions ending at `position`.
+ token_offsets = tl.arange(0, WINDOW_SIZE)
+ start = position - WINDOW_SIZE + 1
+ gather_pos = start + token_offsets
+ valid_pos = (gather_pos >= 0) & (gather_pos < seq_len)
+ # use_current: the row is taken from this launch's kv_score (prefill: any valid position at
+ # or past ready_len; decode: only the current position). Other valid rows come from state_buffer.
+ use_current = (gather_pos >= ready_len) & valid_pos if IS_PREFILL else gather_pos == position
+ current_idx = q_start + (gather_pos - ready_len) if IS_PREFILL else token_idx + token_offsets * 0
+
+ # Both branches resolve position -> full slot -> swa slot -> state row identically; they
+ # differ only in head_offset.
+ if IS_C4:
+ full_slot = tl.load(
+ req_to_token + req_idx * req_to_token_stride0 + gather_pos,
+ mask=valid_pos & (~use_current),
+ other=0,
+ ).to(tl.int64)
+ swa_slot = tl.load(full_to_swa + full_slot, mask=valid_pos & (~use_current), other=-1).to(tl.int64)
+ state_row = (swa_slot // SWA_PAGE_SIZE) * STATE_RING + (swa_slot % STATE_RING)
+ state_valid = valid_pos & (~use_current) & (swa_slot >= 0)
+ # Window rows at offset >= COMPRESS_RATIO read columns shifted by HEAD_DIM; earlier rows
+ # read from column 0.
+ head_offset = tl.where(token_offsets >= COMPRESS_RATIO, HEAD_DIM, 0)
+ else:
+ full_slot = tl.load(
+ req_to_token + req_idx * req_to_token_stride0 + gather_pos,
+ mask=valid_pos & (~use_current),
+ other=0,
+ ).to(tl.int64)
+ swa_slot = tl.load(full_to_swa + full_slot, mask=valid_pos & (~use_current), other=-1).to(tl.int64)
+ state_row = (swa_slot // SWA_PAGE_SIZE) * STATE_RING + (swa_slot % STATE_RING)
+ state_valid = valid_pos & (~use_current) & (swa_slot >= 0)
+ head_offset = token_offsets * 0
+
+ offs = tl.arange(0, BLOCK)
+ dim_mask = offs < HEAD_DIM
+ current_mask = use_current[:, None] & dim_mask[None, :]
+ state_mask = state_valid[:, None] & dim_mask[None, :]
+
+ # Masked-out entries load kv = 0 and score = -inf.
+ cur_kv = tl.load(
+ kv_score + current_idx[:, None] * kv_score_stride0 + (head_offset[:, None] + offs[None, :]) * kv_score_stride1,
+ mask=current_mask,
+ other=0.0,
+ )
+ cur_score = tl.load(
+ kv_score
+ + current_idx[:, None] * kv_score_stride0
+ + (STATE_WIDTH + head_offset[:, None] + offs[None, :]) * kv_score_stride1,
+ mask=current_mask,
+ other=float("-inf"),
+ )
+ state_kv = tl.load(
+ state_buffer + state_row[:, None] * STATE_LAST_DIM + head_offset[:, None] + offs[None, :],
+ mask=state_mask,
+ other=0.0,
+ )
+ state_score = tl.load(
+ state_buffer + state_row[:, None] * STATE_LAST_DIM + STATE_WIDTH + head_offset[:, None] + offs[None, :],
+ mask=state_mask,
+ other=float("-inf"),
+ )
+
+ # Softmax over the window axis (dim 0), then a score-weighted sum of the kv rows.
+ # NOTE(review): columns with offs >= HEAD_DIM have every score at -inf; this looks safe only
+ # when BLOCK == HEAD_DIM (HEAD_DIM a power of two) -- confirm against the supported head dims.
+ kv = tl.where(current_mask, cur_kv, state_kv)
+ score = tl.where(current_mask, cur_score, state_score)
+ score = tl.softmax(score, dim=0)
+ compressed_kv = tl.sum(kv * score, axis=0)
+
+ # RMS norm over the head dimension with a learned weight.
+ rms_w = tl.load(norm_weight + offs, mask=dim_mask, other=0.0)
+ variance = tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_DIM
+ rrms = tl.rsqrt(variance + rms_eps)
+ normed = compressed_kv * rrms * rms_w
+
+ # Interleaved rope on adjacent (even, odd) pairs. Pairs below NOPE_DIM // 2 load cos = 1 and
+ # sin = 0 through the mask defaults, so they pass through unchanged.
+ num_pairs: tl.constexpr = BLOCK // 2
+ nope_pairs: tl.constexpr = NOPE_DIM // 2
+ pair_2d = tl.reshape(normed, (num_pairs, 2))
+ even, odd = tl.split(pair_2d)
+ pair_idx = tl.arange(0, num_pairs)
+ rope_pair_local = pair_idx - nope_pairs
+ is_rope_pair = rope_pair_local >= 0
+ cs_idx = tl.maximum(rope_pair_local, 0)
+ # Rope row: the position rounded down to a multiple of COMPRESS_RATIO.
+ compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO
+ cos_v = tl.load(cos_table + compressed_pos * cos_stride0 + cs_idx, mask=is_rope_pair, other=1.0)
+ sin_v = tl.load(sin_table + compressed_pos * sin_stride0 + cs_idx, mask=is_rope_pair, other=0.0)
+ new_even = even * cos_v - odd * sin_v
+ new_odd = odd * cos_v + even * sin_v
+ rotated = tl.interleave(new_even, new_odd)
+
+ if OUTPUT_BF16:
+ # indexer-K path: emit the post-rope full HEAD_DIM vector as dense bf16 (token-indexed),
+ # leaving the fp8 single-amax pack to destindex_copy_indexer_k_dsv4 (the c4_indexer_pool
+ # ABI differs from the latent slab: whole-vector fp8 + one fp32 scale, no bf16 rope tail).
+ tl.store(out_buffer + token_idx * HEAD_DIM + offs, rotated.to(tl.bfloat16), mask=dim_mask)
+ return
+
+ # Paged layout as addressed here: each page starts with PAGE_SIZE token records of
+ # (NOPE_DIM + ROPE_HEAD_DIM * 2) bytes, followed by SCALE_BYTES scale bytes per token.
+ # NOTE(review): out_buffer is presumably a byte (uint8) buffer on this path -- confirm.
+ page = out_slot // PAGE_SIZE
+ token_in_page = out_slot % PAGE_SIZE
+ data_base = page * BYTES_PER_PAGE + token_in_page * (NOPE_DIM + ROPE_HEAD_DIM * 2)
+ scale_base = page * BYTES_PER_PAGE + PAGE_SIZE * (NOPE_DIM + ROPE_HEAD_DIM * 2) + token_in_page * SCALE_BYTES
+
+ # Per-QUANT_BLOCK power-of-two scaling: the scale exponent is ceil(log2(absmax / FP8_MAX)),
+ # floored through SCALE_MIN; values are rounded through bf16 before quantisation.
+ n_quant_blocks: tl.constexpr = BLOCK // QUANT_BLOCK
+ n_nope_blocks: tl.constexpr = NOPE_DIM // QUANT_BLOCK
+ quant_input = normed.to(tl.bfloat16).to(tl.float32)
+ quant_2d = tl.reshape(quant_input, (n_quant_blocks, QUANT_BLOCK))
+ abs_2d = tl.abs(quant_2d)
+ block_absmax = tl.max(abs_2d, axis=1)
+ scale_exp = tl.ceil(libdevice.log2(tl.maximum(block_absmax / FP8_MAX, SCALE_MIN))).to(tl.int32)
+ # Build the fp32 value 2**scale_exp by writing the biased exponent into bits 23 and up.
+ scale = ((scale_exp + 127) << 23).to(tl.float32, bitcast=True)
+ kv_fp8 = tl.clamp(quant_2d / scale[:, None], -FP8_MAX, FP8_MAX).to(tl.float8e4nv)
+ kv_u8 = tl.reshape(kv_fp8.to(tl.uint8, bitcast=True), (BLOCK,))
+ # Only the first NOPE_DIM bytes are stored as fp8.
+ tl.store(out_buffer + data_base + offs, kv_u8, mask=offs < NOPE_DIM)
+
+ # One biased-exponent byte per nope quant block; remaining scale bytes are written as 0.
+ scale_idx = tl.arange(0, SCALE_BYTES)
+ scale_bytes = tl.where(scale_idx < n_nope_blocks, scale_exp + 127, 0).to(tl.uint8)
+ tl.store(out_buffer + scale_base + scale_idx, scale_bytes)
+
+ # The rope dims are stored as bf16 immediately after the NOPE_DIM fp8 bytes.
+ rope_local = offs - NOPE_DIM
+ rope_mask = (offs >= NOPE_DIM) & dim_mask
+ rope_ptr = (out_buffer + data_base + NOPE_DIM).to(tl.pointer_type(tl.bfloat16))
+ tl.store(rope_ptr + rope_local, rotated.to(tl.bfloat16), mask=rope_mask)
+ return
+
+
+def prepare_compress_states(*, infer_state, layer_idx: int, compress_ratio: int, is_in_indexer: bool = False):
+ """Build the CoreCompressorMetadata for one layer.
+
+ Returns None when compress_ratio == 0. Otherwise selects the output-slot mapping, state
+ buffer and output buffer for the given ratio (4 or 128; the indexer path accepts only 4) and,
+ in prefill, computes (and caches on infer_state) the per-token batch index.
+ Raises AssertionError for an unsupported compress_ratio."""
+ if compress_ratio == 0:
+ return None
+
+ mem_manager: DeepseekV4MemoryManager = infer_state.mem_manager
+ if is_in_indexer:
+ # c4 Lightning-Indexer key compression: same window/state machinery as the c4 latent
+ # compressor but with index_head_dim, a separate state pool, and a DENSE bf16 scratch
+ # out_buffer (the kernel's OUTPUT_BF16 path); the fp8 pack into c4_indexer_pool is done
+ # afterwards by pack_indexer_k_to_cache.
+ assert compress_ratio == 4, "只有 c4(CSA) 层有 indexer-K"
+ out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)]
+ state_buffer = mem_manager.get_c4_indexer_state_buffer(layer_idx)
+ out_buffer = torch.empty(
+ (infer_state.mem_index.numel(), mem_manager.indexer_head_dim),
+ dtype=torch.bfloat16,
+ device=infer_state.mem_index.device,
+ )
+ out_page_size = 1 # unused under OUTPUT_BF16 (token-indexed dense scratch, not paged)
+ else:
+ if compress_ratio == 4:
+ out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)]
+ state_buffer = mem_manager.get_c4_state_buffer(layer_idx)
+ out_pool = mem_manager.c4_pool
+ elif compress_ratio == 128:
+ out_slots = mem_manager.full_to_c128_indexs[infer_state.mem_index.long().reshape(-1)]
+ state_buffer = mem_manager.get_c128_state_buffer(layer_idx)
+ out_pool = mem_manager.c128_pool
+ else:
+ raise AssertionError(f"invalid DeepSeek-V4 compress ratio {compress_ratio}")
+ out_buffer = mem_manager.get_compressed_kv_buffer(layer_idx)
+ out_page_size = out_pool.page_size
+
+ # Decode: the kernels use the token index as the batch index and never load this tensor, so
+ # b_req_idx only serves as a placeholder pointer argument.
+ token_to_batch_idx = infer_state.b_req_idx
+ if infer_state.is_prefill:
+ # Reuse the cached per-token batch index when its length still matches this forward.
+ token_to_batch_idx = getattr(infer_state, "_dsv4_token_to_batch_idx", None)
+ if token_to_batch_idx is None or token_to_batch_idx.numel() != infer_state.position_ids.numel():
+ q_lens = (infer_state.b_seq_len - infer_state.b_ready_cache_len).to(torch.long)
+ batch_idx = torch.arange(infer_state.b_req_idx.shape[0], device=infer_state.b_req_idx.device)
+ token_to_batch_idx = torch.repeat_interleave(batch_idx, q_lens).to(torch.int32)
+ infer_state._dsv4_token_to_batch_idx = token_to_batch_idx
+
+ return CoreCompressorMetadata(
+ layer_idx=layer_idx,
+ compress_ratio=compress_ratio,
+ out_slots=out_slots,
+ mem_index=infer_state.mem_index,
+ state_buffer=state_buffer,
+ out_buffer=out_buffer,
+ out_page_size=out_page_size,
+ position_ids=infer_state.position_ids,
+ b_req_idx=infer_state.b_req_idx,
+ b_seq_len=infer_state.b_seq_len,
+ b_ready_cache_len=infer_state.b_ready_cache_len,
+ b_q_start_loc=infer_state.b_q_start_loc,
+ req_to_token_indexs=infer_state.req_manager.req_to_token_indexs,
+ full_to_swa_indexs=mem_manager.full_to_swa_indexs,
+ token_to_batch_idx=token_to_batch_idx,
+ kv_score=None,
+ is_prefill=infer_state.is_prefill,
+ )
+
+
+def prepare_partial_states(
+ *,
+ kv_score: torch.Tensor,
+ metadata: Optional[CoreCompressorMetadata],
+ ape: torch.Tensor,
+ compress_ratio: int,
+):
+ """Add `ape` to the score half of kv_score in place (one kernel program per row).
+
+ No-op when metadata is None or kv_score has no rows. The last dimension of kv_score is
+ treated as [kv | score], each of width kv_score.shape[-1] // 2."""
+ if metadata is None or kv_score.shape[0] == 0:
+ return
+ state_width = kv_score.shape[-1] // 2
+ _add_ape_to_kv_score_kernel[(kv_score.shape[0],)](
+ kv_score,
+ kv_score.stride(0),
+ kv_score.stride(1),
+ ape,
+ ape.stride(0),
+ metadata.position_ids,
+ STATE_WIDTH=state_width,
+ COMPRESS_RATIO=compress_ratio,
+ BLOCK=triton.next_power_of_2(state_width),
+ num_warps=4,
+ )
+ return
+
+
+def fused_compress(
+ *,
+ kv_score: torch.Tensor,
+ metadata: Optional[CoreCompressorMetadata],
+ norm_weight: torch.Tensor,
+ ape: torch.Tensor,
+ eps: float,
+ head_dim: int,
+ qk_rope_head_dim: int,
+ compress_ratio: int,
+ cos_table: torch.Tensor,
+ sin_table: torch.Tensor,
+ output_bf16: bool = False,
+):
+ """Launch the fused compress/norm/rope/insert kernel, then the partial-state save kernel,
+ with one program per kv_score row each.
+
+ No-op when metadata is None or kv_score has no rows. Results are written into
+ metadata.out_buffer and metadata.state_buffer; nothing is returned.
+ NOTE(review): `ape` is accepted but not referenced in this function body."""
+ if metadata is None or kv_score.shape[0] == 0:
+ return
+
+ state_width = kv_score.shape[-1] // 2
+ state_last_dim = metadata.state_buffer.shape[-1]
+ is_c4 = compress_ratio == 4
+ state_ring = DSV4_C4_STATE_RING if is_c4 else DSV4_C128_STATE_RING
+ block_state = triton.next_power_of_2(state_width)
+ block_head = triton.next_power_of_2(head_dim)
+
+ _fused_compress_norm_rope_insert_kernel[(kv_score.shape[0],)](
+ kv_score,
+ kv_score.stride(0),
+ kv_score.stride(1),
+ metadata.state_buffer,
+ metadata.position_ids,
+ metadata.token_to_batch_idx,
+ metadata.b_req_idx,
+ metadata.b_seq_len,
+ # b_seq_len stands in as a placeholder pointer when the optional tensors are None; the
+ # kernel only loads these two under IS_PREFILL.
+ metadata.b_ready_cache_len if metadata.b_ready_cache_len is not None else metadata.b_seq_len,
+ metadata.b_q_start_loc if metadata.b_q_start_loc is not None else metadata.b_seq_len,
+ metadata.req_to_token_indexs,
+ metadata.req_to_token_indexs.stride(0),
+ metadata.full_to_swa_indexs,
+ metadata.out_slots,
+ norm_weight,
+ eps,
+ cos_table,
+ cos_table.stride(0),
+ sin_table,
+ sin_table.stride(0),
+ metadata.out_buffer,
+ HEAD_DIM=head_dim,
+ STATE_WIDTH=state_width,
+ STATE_LAST_DIM=state_last_dim,
+ COMPRESS_RATIO=compress_ratio,
+ # c4 pools over a window of two groups (8 positions); other ratios over one group.
+ WINDOW_SIZE=compress_ratio * (2 if is_c4 else 1),
+ IS_C4=is_c4,
+ IS_PREFILL=metadata.is_prefill,
+ SWA_PAGE_SIZE=DSV4_SWA_PAGE_SIZE,
+ STATE_RING=state_ring,
+ ROPE_HEAD_DIM=qk_rope_head_dim,
+ FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
+ SCALE_MIN=1e-4,
+ NOPE_DIM=head_dim - qk_rope_head_dim,
+ QUANT_BLOCK=64,
+ # One scale byte per 64-wide nope block plus one extra byte.
+ SCALE_BYTES=(head_dim - qk_rope_head_dim) // 64 + 1,
+ PAGE_SIZE=metadata.out_page_size,
+ BYTES_PER_PAGE=metadata.out_buffer.shape[-1],
+ BLOCK=block_head,
+ OUTPUT_BF16=output_bf16,
+ num_warps=4,
+ )
+
+ # Launched after the compress kernel, which reads state_buffer rows that this kernel overwrites.
+ _save_partial_states_kernel[(kv_score.shape[0],)](
+ kv_score,
+ kv_score.stride(0),
+ kv_score.stride(1),
+ metadata.position_ids,
+ metadata.token_to_batch_idx,
+ metadata.b_req_idx,
+ metadata.b_seq_len,
+ metadata.mem_index,
+ metadata.full_to_swa_indexs,
+ metadata.state_buffer,
+ STATE_WIDTH=state_width,
+ STATE_LAST_DIM=state_last_dim,
+ COMPRESS_RATIO=compress_ratio,
+ IS_C4=is_c4,
+ IS_PREFILL=metadata.is_prefill,
+ SWA_PAGE_SIZE=DSV4_SWA_PAGE_SIZE,
+ STATE_RING=state_ring,
+ BLOCK=block_state,
+ num_warps=4,
+ )
+ return
diff --git a/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py
new file mode 100644
index 0000000000..080ebabd89
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py
@@ -0,0 +1,75 @@
+import torch
+
+# Importing the vLLM mHC module is needed for its side effect only (hence the unused-import
+# noqa); presumably it registers the torch.ops.vllm.* custom ops used below. Any import failure
+# is re-raised as a RuntimeError so the missing dependency is reported at import time.
+try:
+ import vllm.model_executor.layers.mhc # noqa: F401
+except Exception as e:
+ raise RuntimeError("DeepSeek-V4 requires vLLM mHC custom ops; failed to import vllm MHC kernels") from e
+
+
+# vllm DeepseekV4DecoderLayer.hc_post_alpha
+HC_POST_ALPHA = 2.0
+
+
+def hc_pre(residual, hc_fn, hc_scale, hc_base, rms_eps, hc_eps, sinkhorn_iters, norm_weight, norm_eps):
+ """Standalone hc_pre for the first layer. residual:[T, hc, dim] ->
+ (x[T,dim], residual, post_mix[T,hc,1], res_mix[T,hc,hc]); the sub-layer RMSNorm is fused via norm_weight."""
+ # hc_eps is passed as both hc_pre_eps and hc_sinkhorn_eps. The op returns
+ # (post_mix, res_mix, x); this wrapper reorders them and hands back the `residual` argument
+ # as received.
+ post_mix, res_mix, x = torch.ops.vllm.mhc_pre_tilelang(
+ residual=residual,
+ fn=hc_fn,
+ hc_scale=hc_scale,
+ hc_base=hc_base,
+ rms_eps=rms_eps,
+ hc_pre_eps=hc_eps,
+ hc_sinkhorn_eps=hc_eps,
+ hc_post_mult_value=HC_POST_ALPHA,
+ sinkhorn_repeat=sinkhorn_iters,
+ norm_weight=norm_weight,
+ norm_eps=norm_eps,
+ )
+ return x, residual, post_mix, res_mix
+
+
+def hc_fused_post_pre(
+ x, residual, post_mix, res_mix, hc_fn, hc_scale, hc_base, rms_eps, hc_eps, sinkhorn_iters, norm_weight, norm_eps
+):
+ """hc_post of the previous sub-layer fused with hc_pre of the next one (norm fused too).
+ Returns (x[T,dim], residual[T,hc,dim], post_mix, res_mix)."""
+ # The incoming post_mix/res_mix are passed as post_layer_mix/comb_res_mix; hc_eps is passed as
+ # both hc_pre_eps and hc_sinkhorn_eps. The op returns (residual, post_mix, res_mix, x), which
+ # is reordered for the caller.
+ residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre_tilelang(
+ x=x,
+ residual=residual,
+ post_layer_mix=post_mix,
+ comb_res_mix=res_mix,
+ fn=hc_fn,
+ hc_scale=hc_scale,
+ hc_base=hc_base,
+ rms_eps=rms_eps,
+ hc_pre_eps=hc_eps,
+ hc_sinkhorn_eps=hc_eps,
+ hc_post_mult_value=HC_POST_ALPHA,
+ sinkhorn_repeat=sinkhorn_iters,
+ norm_weight=norm_weight,
+ norm_eps=norm_eps,
+ )
+ return x, residual, post_mix, res_mix
+
+
+def hc_post(x, residual, post_mix, res_mix):
+ """Complete the hc_post left pending by the last layer. -> streams [T, hc, dim]."""
+ # Thin wrapper: forwards the four arguments positionally to the vLLM op and returns its result.
+ return torch.ops.vllm.mhc_post_tilelang(x, residual, post_mix, res_mix)
+
+
+def hc_head(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, rms_eps, hc_eps, alloc_func):
+ """Final stream collapse before the lm_head. streams:[N, hc*dim] -> [N, dim]."""
+ # alloc_func is called as alloc_func(shape, dtype=..., device=...) to obtain the output tensor.
+ out = alloc_func((streams.shape[0], dim), dtype=streams.dtype, device=streams.device)
+ # The op's return value is ignored; presumably it writes its result into `out`.
+ torch.ops.vllm.hc_head_fused_kernel_tilelang(
+ streams.view(-1, hc_mult, dim).contiguous(),
+ hc_fn,
+ hc_scale,
+ hc_base,
+ out,
+ dim,
+ rms_eps,
+ hc_eps,
+ hc_mult,
+ )
+ return out
diff --git a/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py
new file mode 100644
index 0000000000..8eddfb3b9d
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py
@@ -0,0 +1,27 @@
+from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
+from .hyper_connection import hc_head, hc_post
+from ..infer_struct import DeepseekV4InferStateInfo
+
+
+class DeepseekV4PostLayerInfer(LlamaPostLayerInfer):
+ """Collapse the hc_mult residual streams (hc_head) to [T, hidden], then final norm + lm_head."""
+
+ def token_forward(self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight):
+ cfg = layer_weight.network_config_
+ if isinstance(input_embdings, tuple):
+ # truncated-layer runs (autotune warmup) end before the last layer's _hc_ffn_out
+ # collapse; finish the pending hc_post here.
+ streams = hc_post(*input_embdings)
+ input_embdings = streams.reshape(streams.shape[0], -1)
+ collapsed = hc_head(
+ input_embdings,
+ layer_weight.hc_head_fn_.weight,
+ layer_weight.hc_head_scale_.weight,
+ layer_weight.hc_head_base_.weight,
+ cfg["hc_mult"],
+ cfg["hidden_size"],
+ cfg["rms_norm_eps"],
+ cfg.get("hc_eps", 1e-6),
+ self.alloc_tensor,
+ )
+ return super().token_forward(collapsed, infer_state, layer_weight)
diff --git a/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py
new file mode 100644
index 0000000000..b95f5a14a8
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py
@@ -0,0 +1,24 @@
+import torch
+import torch.distributed as dist
+from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
+from lightllm.distributed.communication_op import all_reduce
+from ..infer_struct import DeepseekV4InferStateInfo
+
+
+class DeepseekV4PreLayerInfer(LlamaPreLayerInfer):
+ """Token embedding, then expand to the hc_mult parallel residual streams [T, hc_mult*hidden]."""
+
+ def __init__(self, network_config):
+ super().__init__(network_config)
+ self.hc_mult = network_config["hc_mult"]
+ return
+
+ def context_forward(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight):
+ input_embdings = super().context_forward(input_ids, infer_state, layer_weight)
+ t, hidden = input_embdings.shape
+ return input_embdings.unsqueeze(1).expand(t, self.hc_mult, hidden).reshape(t, self.hc_mult * hidden)
+
+ def token_forward(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight):
+ input_embdings = super().token_forward(input_ids, infer_state, layer_weight)
+ t, hidden = input_embdings.shape
+ return input_embdings.unsqueeze(1).expand(t, self.hc_mult, hidden).reshape(t, self.hc_mult * hidden)
diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py
new file mode 100644
index 0000000000..57b171a8d1
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,712 @@
+import os
+import torch
+import torch.distributed as dist
+from lightllm.common.basemodel import TransformerLayerInferTpl
+from lightllm.common.basemodel.attention.base_att import AttControl
+from lightllm.distributed.communication_op import all_reduce
+from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer
+from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import DeepseekV4TransformerLayerWeight
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor
+from lightllm.utils.vllm_utils import vllm_ops
+from .hyper_connection import hc_pre, hc_fused_post_pre, hc_post
+from .compressor import fused_compress as fused_compress_op
+from .compressor import prepare_partial_states
+from .compressor import prepare_compress_states
+from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
+from ..infer_struct import DeepseekV4InferStateInfo
+
+
+class DeepseekV4TransformerLayerInfer(Deepseek3_2TransformerLayerInfer):
+ def __init__(self, layer_num, network_config):
+ TransformerLayerInferTpl.__init__(self, layer_num, network_config)
+ self.eps_ = network_config["rms_norm_eps"]
+ self.embed_dim_ = network_config["hidden_size"]
+ self.num_heads = network_config["num_attention_heads"]
+ self.head_dim_ = network_config["head_dim"]
+ self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
+ self.qk_nope_head_dim = self.head_dim_ - self.qk_rope_head_dim
+ self.v_head_dim = self.head_dim_
+ self.o_groups = network_config["o_groups"]
+ self.hc_mult = network_config["hc_mult"]
+ self.sinkhorn_iters = network_config["hc_sinkhorn_iters"]
+ self.hc_eps = network_config["hc_eps"]
+ self.compress_ratio = network_config["compress_ratios"][layer_num]
+ self.is_hash = layer_num < network_config["num_hash_layers"]
+ self.is_last_layer = layer_num == network_config["n_layer"] - 1
+ # complex64 rope table for this layer's variant (sliding / compressed); set by
+ # DeepseekV4TpPartModel._init_to_get_rotary once the tables are built. The full compress
+ # cos/sin tables (compressor entry rope uses entry positions, not token positions) are
+ # wired there too.
+ self.freqs_cis = None
+ self.cos_compress_table = None
+ self.sin_compress_table = None
+ self.num_experts_per_tok = network_config["num_experts_per_tok"]
+ self.routed_scaling_factor = network_config["routed_scaling_factor"]
+ self.swiglu_limit = network_config["swiglu_limit"]
+ self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5)
+ self.tp_q_head_num_ = self.num_heads // self.tp_world_size_
+ self.tp_groups = self.o_groups // self.tp_world_size_
+ self.enable_ep_moe = get_env_start_args().enable_ep_moe
+ self.compressor = CompressorInfer(
+ layer_idx=self.layer_num_, network_config=self.network_config_, tp_world_size=self.tp_world_size_
+ )
+ self.index_infer = DeepseekV4IndexInfer(
+ layer_idx=self.layer_num_, network_config=self.network_config_, tp_world_size=self.tp_world_size_
+ )
+
+ # ------------------------------------------------------------------ forward (HC-threaded)
+ def _hc_attn_in(self, input_embdings, layer_weight: DeepseekV4TransformerLayerWeight):
+ """Layer input -> attention input (attn_norm fused). First layer gets the raw streams
+ and runs a standalone hc_pre; later layers get (x, residual, post_mix, res_mix) and fuse
+ the previous layer's ffn hc_post with this layer's attn hc_pre."""
+ if torch.is_tensor(input_embdings):
+ residual = input_embdings.view(-1, self.hc_mult, self.embed_dim_)
+ return hc_pre(
+ residual,
+ layer_weight.hc_attn_fn_.weight,
+ layer_weight.hc_attn_scale_.weight,
+ layer_weight.hc_attn_base_.weight,
+ self.eps_,
+ self.hc_eps,
+ self.sinkhorn_iters,
+ layer_weight.attn_norm_.weight,
+ self.eps_,
+ )
+ x, residual, post_mix, res_mix = input_embdings
+ return hc_fused_post_pre(
+ x,
+ residual,
+ post_mix,
+ res_mix,
+ layer_weight.hc_attn_fn_.weight,
+ layer_weight.hc_attn_scale_.weight,
+ layer_weight.hc_attn_base_.weight,
+ self.eps_,
+ self.hc_eps,
+ self.sinkhorn_iters,
+ layer_weight.attn_norm_.weight,
+ self.eps_,
+ )
+
+ def _hc_ffn_in(self, x, residual, post_mix, res_mix, layer_weight: DeepseekV4TransformerLayerWeight):
+ """Attention output -> ffn input (ffn_norm fused): fused attn hc_post + ffn hc_pre."""
+ return hc_fused_post_pre(
+ x,
+ residual,
+ post_mix,
+ res_mix,
+ layer_weight.hc_ffn_fn_.weight,
+ layer_weight.hc_ffn_scale_.weight,
+ layer_weight.hc_ffn_base_.weight,
+ self.eps_,
+ self.hc_eps,
+ self.sinkhorn_iters,
+ layer_weight.ffn_norm_.weight,
+ self.eps_,
+ )
+
+ def _hc_ffn_out(self, x, residual, post_mix, res_mix):
+ """Mid layers leave the ffn hc_post pending for the next layer's fused post+pre; the last
+ layer completes it and hands the flat streams [T, hc_mult*hidden] back to the model loop."""
+ if not self.is_last_layer:
+ return x, residual, post_mix, res_mix
+ streams = hc_post(x, residual, post_mix, res_mix)
+ return streams.reshape(streams.shape[0], -1)
+
+ def context_forward(
+ self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ x, residual, post_mix, res_mix = self._hc_attn_in(input_embdings, layer_weight)
+ x = self.context_attention_forward(x, infer_state, layer_weight)
+ x, residual, post_mix, res_mix = self._hc_ffn_in(x, residual, post_mix, res_mix, layer_weight)
+ x = self._ffn(x, infer_state, layer_weight)
+ return self._hc_ffn_out(x, residual, post_mix, res_mix)
+
+ def token_forward(
+ self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ x, residual, post_mix, res_mix = self._hc_attn_in(input_embdings, layer_weight)
+ x = self.token_attention_forward(x, infer_state, layer_weight)
+ x, residual, post_mix, res_mix = self._hc_ffn_in(x, residual, post_mix, res_mix, layer_weight)
+ x = self._ffn(x, infer_state, layer_weight)
+ return self._hc_ffn_out(x, residual, post_mix, res_mix)
+
+ # ------------------------------------------------------------------ shared projections / cache
+ def _select_rope(self, infer_state: DeepseekV4InferStateInfo):
+ if self.compress_ratio:
+ return infer_state.position_cos_compress, infer_state.position_sin_compress
+ return infer_state.position_cos_sliding, infer_state.position_sin_sliding
+
+ def _get_qkv(
+ self,
+ input: torch.Tensor,
+ infer_state: DeepseekV4InferStateInfo,
+ layer_weight: DeepseekV4TransformerLayerWeight,
+ ):
+ from lightllm.third_party.sglang_jit.dsv4 import fused_q_norm_rope
+
+ input = self._tpsp_allgather(input=input, infer_state=infer_state)
+ T = input.shape[0]
+ qa = layer_weight.q_norm_(layer_weight.wq_a_.mm(input), eps=self.eps_)
+ q_in = layer_weight.wq_b_.mm(qa).view(T, self.tp_q_head_num_, self.head_dim_)
+ # per-(token, head) weightless self-RMSNorm + interleaved rope on the last rope_dim dims,
+ # fused in one sglang dsv4 jit kernel (fp32 norm/rotation, bf16 in between -- same as eager).
+ q = self.alloc_tensor(q_in.shape, dtype=q_in.dtype, device=q_in.device)
+ fused_q_norm_rope(q_in, q, self.eps_, self.freqs_cis, infer_state.position_ids)
+ # kv: rmsnorm + rope + fp8 pack + scatter into the swa pool, all done by one sglang jit kernel
+ # (same as sglang _compute_kv_to_cache), replacing the eager norm/rope/cat + _post_cache_kv.
+ # The bf16 kv intermediate has no other consumer: flashmla attention reads the cache, compressor/indexer use x.
+ infer_state.mem_manager.pack_mla_kv_to_cache_fused_norm_rope(
+ layer_index=self.layer_num_,
+ mem_index=infer_state.mem_index,
+ kv=layer_weight.wkv_.mm(input),
+ kv_weight=layer_weight.kv_norm_.weight,
+ eps=self.eps_,
+ freqs_cis=self.freqs_cis,
+ positions=infer_state.position_ids,
+ )
+ return q, qa, input
+
+ def _get_o(self, o, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight):
+ # o: [T, tp_q_head_num_, head_dim_] after inverse rope -> grouped low-rank O -> [T, embed_dim_]
+ position_cos, position_sin = self._select_rope(infer_state)
+ rotary_emb_fwd(o[..., -self.qk_rope_head_dim :], None, position_cos, position_sin, inverse=True)
+ T = o.shape[0]
+ o = o.reshape(T, self.tp_groups, -1).transpose(0, 1).contiguous() # [groups, T, per_group_in]
+ o = layer_weight.wo_a_.bmm(o).transpose(0, 1).reshape(T, -1) # [T, groups*o_lora]
+ o = layer_weight.wo_b_.mm(o)
+ return self._tpsp_reduce(input=o, infer_state=infer_state)
+
+ # ------------------------------------------------------------------ attention (prefill)
+ def context_attention_forward(
+ self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ # _get_qkv writes the chunk's packed latent into the swa pool (fused kernel) before
+ # attention reads it back via full_to_swa indices (this custom forward bypasses the
+ # tpl _post_cache_kv path).
+ q, q_lora, full_x = self._get_qkv(x, infer_state, layer_weight)
+ o = self._context_attention_wrapper_run(q, q_lora, full_x, infer_state, layer_weight)
+ return self._get_o(o, infer_state, layer_weight)
+
+ def _context_attention_wrapper_run(
+ self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ if torch.cuda.is_current_stream_capturing():
+ q = q.contiguous()
+ q_lora = q_lora.contiguous()
+ x = x.contiguous()
+ _q = tensor_to_no_ref_tensor(q)
+ _q_lora = tensor_to_no_ref_tensor(q_lora)
+ _x = tensor_to_no_ref_tensor(x)
+
+ pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
+ pre_capture_graph.__exit__(None, None, None)
+
+ infer_state.prefill_cuda_graph_create_graph_obj()
+ infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
+ # Same graph-split output handoff as the template, but avoid its dry-run because
+ # DSV4 attention mutates compressor/cache state before returning.
+ o = self.alloc_tensor((q.shape[0], self.tp_q_head_num_, self.head_dim_), dtype=q.dtype, device=q.device)
+ _o = tensor_to_no_ref_tensor(o)
+
+ def att_func(new_infer_state: DeepseekV4InferStateInfo):
+ tmp_o = self._context_attention_kernel(_q, _q_lora, _x, new_infer_state, layer_weight)
+ assert tmp_o.shape == _o.shape
+ _o.copy_(tmp_o)
+ return
+
+ infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=att_func, after_graph=pre_capture_graph)
+ return o
+
+ return self._context_attention_kernel(q, q_lora, x, infer_state, layer_weight)
+
+ def _context_attention_kernel(
+ self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ self.compressor.prepare_states(x, infer_state, layer_weight)
+ self.compressor.fused_compress(infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table)
+ # Write this step's c4 Lightning-Indexer keys (no-op on non-c4 layers) BEFORE build_metadata so the
+ # scorer (gather + deep_gemm.fp8_mqa_logits) reads fresh+accumulated entries from the indexer pool.
+ self.index_infer.write_indexer_k(x, infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table)
+ # Build the FINAL flash_mla index tensors here (model side), so att_control is a thin
+ # transport of ready-to-forward tensors -- not indexer raw material. Must stay after
+ # fused_compress (c4 reads the indexer-K pool it writes) and before prefill_att (keeps the
+ # c4 scorer/topk at the same cuda-graph capture position).
+ meta = self.index_infer.build_metadata(x, q_lora, infer_state, layer_weight)
+ att_control = AttControl(
+ nsa_prefill=True,
+ nsa_prefill_dict={
+ "flashmla_kvcache": True,
+ "layer_index": self.layer_num_,
+ "compress_ratio": self.compress_ratio,
+ "head_dim_v": self.v_head_dim,
+ "softmax_scale": self.softmax_scale,
+ "attn_sink": layer_weight.attn_sink_.weight,
+ **meta,
+ },
+ )
+ out = infer_state.prefill_att_state.prefill_att(
+ q=q,
+ k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_),
+ v=None,
+ att_control=att_control,
+ )
+ pad_q_len = getattr(infer_state, "_dsv4_prefill_pad_q_len", 0)
+ if pad_q_len:
+ # pad rows read the HOLD slot (see infer_struct._dsv4_prefill_pad_q_len); zero them to stay deterministic
+ out[-pad_q_len:] = 0
+ return out
+
+ # ------------------------------------------------------------------ attention (decode)
+ def token_attention_forward(
+ self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ q, q_lora, full_x = self._get_qkv(x, infer_state, layer_weight)
+ o = self._token_attention_kernel(q, q_lora, full_x, infer_state, layer_weight)
+ return self._get_o(o, infer_state, layer_weight)
+
+ def _token_attention_kernel(
+ self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ self.compressor.prepare_states(x, infer_state, layer_weight)
+ self.compressor.fused_compress(infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table)
+ self.index_infer.write_indexer_k(x, infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table)
+ meta = self.index_infer.build_metadata(x, q_lora, infer_state, layer_weight)
+ att_control = AttControl(
+ nsa_decode=True,
+ nsa_decode_dict={
+ "flashmla_kvcache": True,
+ "layer_index": self.layer_num_,
+ "compress_ratio": self.compress_ratio,
+ "head_dim_v": self.v_head_dim,
+ "softmax_scale": self.softmax_scale,
+ "attn_sink": layer_weight.attn_sink_.weight,
+ **meta,
+ },
+ )
+ return infer_state.decode_att_state.decode_att(
+ q=q,
+ k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_),
+ v=None,
+ att_control=att_control,
+ )
+
+ # ------------------------------------------------------------------ moe
+ def _routed_experts(self, x, weights, indices, layer_weight: DeepseekV4TransformerLayerWeight):
+ return layer_weight.experts_.experts_with_preselected(
+ input_tensor=x,
+ topk_weights=weights,
+ topk_ids=indices,
+ clamp_limit=float(self.swiglu_limit),
+ )
+
+ def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight):
+ x = x.view(-1, self.embed_dim_)
+ if not self.enable_ep_moe:
+ x = self._tpsp_allgather(input=x, infer_state=infer_state)
+
+ logits = layer_weight.gate_weight_.mm(x.float()).contiguous()
+ weights, indices = self._select_experts(logits, infer_state, layer_weight)
+ # The shared expert must run before the routed experts: on the fp8 path (FuseMoeTriton) fused_experts
+ # is in-place, so once _routed_experts returns, x has been overwritten with the routed output.
+ # Reuse Llama's _ffn_tp (fused gate_up matmul + silu_and_mul triton kernel, no swiglu clamp) to match
+ # the reference DeepseekV4MLP (=LlamaMLP); swiglu_limit clamps routed experts only (see _routed_experts).
+ shared = self._ffn_tp(input=x, infer_state=infer_state, layer_weight=layer_weight)
+ routed = self._routed_experts(x, weights, indices, layer_weight)
+ if self.enable_ep_moe:
+ if self.tp_world_size_ > 1:
+ all_reduce(
+ shared,
+ op=dist.ReduceOp.SUM,
+ group=infer_state.dist_group,
+ async_op=False,
+ )
+ return routed + shared
+ out = routed + shared
+ return self._tpsp_reduce(input=out, infer_state=infer_state)
+
+ def _select_experts(
+ self, logits, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ return self._select_experts_vllm(logits, infer_state, layer_weight)
+
+ def _select_experts_vllm(
+ self, logits, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
+ ):
+ M = logits.shape[0]
+ bias = None
+ input_tokens = None
+ hash_indices_table = None
+ indices_dtype = torch.int64
+ if self.is_hash:
+ hash_indices_table = layer_weight.gate_tid2eid_.weight
+ if not hash_indices_table.is_contiguous():
+ hash_indices_table = hash_indices_table.contiguous()
+ indices_dtype = hash_indices_table.dtype
+ input_tokens = infer_state.input_ids.to(dtype=indices_dtype).contiguous()
+ else:
+ bias = layer_weight.gate_bias_.weight
+
+ weights = self.alloc_tensor((M, self.num_experts_per_tok), dtype=torch.float32, device=logits.device)
+ indices = self.alloc_tensor((M, self.num_experts_per_tok), dtype=indices_dtype, device=logits.device)
+ token_expert_indices = self.alloc_tensor((M, self.num_experts_per_tok), dtype=torch.int32, device=logits.device)
+ vllm_ops.topk_hash_softplus_sqrt(
+ weights,
+ indices,
+ token_expert_indices,
+ logits,
+ True,
+ self.routed_scaling_factor,
+ bias,
+ input_tokens,
+ hash_indices_table,
+ )
+ return weights, indices.long()
+
+
+class CompressorInfer:
+ """Window-softmax compressor. is_in_indexer=False compresses the c4/c128 latent KV into the
+ paged fp8 slab (attention extra_k); is_in_indexer=True reuses the SAME machinery (mirroring
+ sglang's Compressor(is_in_indexer=...)) with the indexer weights/dims/state pool to produce the
+ per-c4-entry Lightning-Indexer keys, emitted as dense bf16 (OUTPUT_BF16) then fp8-packed into
+ c4_indexer_pool by the caller. Indexer mode is c4-only."""
+
+ def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int, is_in_indexer: bool = False):
+ super().__init__()
+ self.layer_idx_ = layer_idx
+ self.network_config_ = network_config
+ self.tp_world_size_ = tp_world_size
+ self.is_in_indexer = is_in_indexer
+ self.compress_ratio = network_config["compress_ratios"][layer_idx]
+ self.head_dim = network_config["head_dim"]
+ self.index_head_dim = network_config["index_head_dim"]
+ self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
+ self.eps = network_config["rms_norm_eps"]
+ self._metadata = None
+
+ def prepare_states(
+ self,
+ x: torch.Tensor,
+ infer_state: DeepseekV4InferStateInfo,
+ layer_weight: DeepseekV4TransformerLayerWeight,
+ ):
+ self._metadata = prepare_compress_states(
+ infer_state=infer_state,
+ layer_idx=self.layer_idx_,
+ compress_ratio=self.compress_ratio,
+ is_in_indexer=self.is_in_indexer,
+ )
+ if self._metadata is not None:
+ if self.is_in_indexer:
+ # indexer wkv/wgate are two separate replicated weights; cat -> [T, 2*coff*idx_hd]
+ # (same [kv | score] layout the fused compressor_wkv_gate_ produces for attention).
+ kv = layer_weight.idx_cmp_wkv_.mm(x)
+ gate = layer_weight.idx_cmp_wgate_.mm(x)
+ self._metadata.kv_score = torch.cat([kv, gate], dim=-1).float()
+ ape = layer_weight.idx_cmp_ape_.weight
+ else:
+ self._metadata.kv_score = layer_weight.compressor_wkv_gate_.mm(x).float()
+ ape = layer_weight.compressor_ape_.weight
+ prepare_partial_states(
+ kv_score=self._metadata.kv_score,
+ metadata=self._metadata,
+ ape=ape,
+ compress_ratio=self.compress_ratio,
+ )
+ return self._metadata
+
+ def fused_compress(
+ self,
+ infer_state: DeepseekV4InferStateInfo,
+ layer_weight: DeepseekV4TransformerLayerWeight,
+ cos_table: torch.Tensor,
+ sin_table: torch.Tensor,
+ ):
+ if self.compress_ratio == 0:
+ return None
+ metadata = self._metadata
+ if metadata is None:
+ raise RuntimeError("DeepSeek-V4 compressor.prepare_states must run before fused_compress")
+ if self.is_in_indexer:
+ norm_weight = layer_weight.idx_cmp_norm_.weight
+ ape = layer_weight.idx_cmp_ape_.weight
+ head_dim = self.index_head_dim
+ else:
+ norm_weight = layer_weight.compressor_norm_.weight
+ ape = layer_weight.compressor_ape_.weight
+ head_dim = self.head_dim
+ return fused_compress_op(
+ kv_score=metadata.kv_score,
+ metadata=metadata,
+ norm_weight=norm_weight,
+ ape=ape,
+ eps=self.eps,
+ head_dim=head_dim,
+ qk_rope_head_dim=self.qk_rope_head_dim,
+ compress_ratio=self.compress_ratio,
+ cos_table=cos_table,
+ sin_table=sin_table,
+ output_bf16=self.is_in_indexer,
+ )
+
+
+class DeepseekV4IndexInfer:
+ """Model-side builder for the FlashMLA sparse-index metadata. Mirrors deepseek3_2's NsaInfer
+ boundary (the model owns ALL index construction; the attention backend only forwards final
+ tensors to flash_mla.flash_mla_with_kvcache) AND its c4 implementation: hadamard'd fp8 q/K, a
+ ragged gather of the compressed c4 keys, deep_gemm.fp8_mqa_logits, then topk -- adapted for the
+ replicated indexer (no gather-q/all_reduce), the c4-compressed entry space, and topk-512 (it does
+ not subclass NsaInfer only because of those data-shape differences). swa metadata is precomputed in
+ init_some_extra_state; this class owns the c4/c128 entry gather (build_compress_index) AND the c4
+ Lightning-Indexer scoring (gather + deep_gemm.fp8_mqa_logits + topk). Holds only static per-layer
+ config; all per-request data flows in via args. Invoke from _context/_token_attention_kernel
+ (after compressor.fused_compress, before *_att) so the c4 scorer/topk keep the same cuda-graph
+ capture position they had when this lived in the backend. The indexer is replicated (no TP collective)."""
+
+ def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int):
+ self.layer_idx_ = layer_idx
+ self.compress_ratio = network_config["compress_ratios"][layer_idx]
+ self.index_topk = network_config["index_topk"]
+ self.index_head_dim = network_config["index_head_dim"]
+ self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
+ self.index_n_heads = network_config["index_n_heads"]
+ self.tp_world_size_ = tp_world_size
+ self.indexer_score_scale = self.index_head_dim ** -0.5
+ self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5
+ # c4 layers own a second compressor (is_in_indexer) that writes the Lightning-Indexer key
+ # pool every step; _c4_indices gathers it back + scores via deep_gemm.fp8_mqa_logits.
+ self.indexer_compressor = (
+ CompressorInfer(layer_idx, network_config, tp_world_size, is_in_indexer=True)
+ if self.compress_ratio == 4
+ else None
+ )
+
+ def write_indexer_k(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight, cos_table, sin_table):
+ """c4-only: compress this step's tokens into per-c4-entry indexer keys and pack them into
+ c4_indexer_pool. MUST run before build_metadata so the scorer (gather + deep_gemm.fp8_mqa_logits)
+ reads the finished entries; runs every step (incl. in the decode graph) so keys accumulate for
+ later long-context scoring. No-op on c128 / dense layers."""
+ if self.compress_ratio != 4:
+ return
+ self.indexer_compressor.prepare_states(x, infer_state, layer_weight)
+ self.indexer_compressor.fused_compress(infer_state, layer_weight, cos_table, sin_table)
+ scratch = self.indexer_compressor._metadata.out_buffer # [T, index_head_dim] bf16 (group-end rows valid)
+ # Rotate K (post norm+rope) by the SAME 1/sqrt(d) Hadamard the q kernel applies, so
+ # (Hq)·(Hk)=q·k (H orthogonal) and the fp8 quant of K stays accurate.
+ from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform
+
+ scratch = hadamard_transform(scratch, scale=self.index_head_dim ** -0.5)
+ mem_manager = infer_state.mem_manager
+ positions = infer_state.position_ids
+ out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)]
+ # only group-end tokens finish a c4 entry; mask the rest to -1 so the packer skips them
+ # (mid-group tokens share the group's c4 slot -> avoids racing a finished slot).
+ completed = ((positions + 1) % 4 == 0) & (out_slots >= 0)
+ masked_slots = torch.where(completed, out_slots, torch.full_like(out_slots, -1)).to(torch.int32)
+ mem_manager.pack_indexer_k_to_cache(self.layer_idx_, masked_slots, scratch)
+
+ def build_metadata(self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight):
+ """Return the final flash_mla index tensors for this layer's compress variant. swa indices and
+ the per-token req_idx are layer-independent and precomputed once in init_some_extra_state
+ (read here); only the c4 scorer / c128 gather is per-layer. The backend pairs these with the
+ (data-independent, layer-keyed) fp8 cache-byte views it owns."""
+ swa_indices = infer_state.dsv4_swa_indices.unsqueeze(1)
+ swa_lengths = infer_state.dsv4_swa_lengths
+ req_idx = infer_state.dsv4_sparse_req_idx
+ positions = infer_state.position_ids
+ extra_indices = extra_lengths = None
+ if self.compress_ratio == 4:
+ idx_q_fp8, weights = self._indexer_q_weight(x, q_lora, infer_state, layer_weight)
+ extra_indices, extra_lengths = self._c4_indices(infer_state, idx_q_fp8, weights, positions)
+ elif self.compress_ratio == 128:
+ extra_indices, extra_lengths = self._c128_indices(infer_state, req_idx, positions)
+ return {
+ "swa_indices": swa_indices,
+ "swa_lengths": swa_lengths,
+ "extra_indices": extra_indices,
+ "extra_lengths": extra_lengths,
+ }
+
+ def _c128_indices(self, infer_state: DeepseekV4InferStateInfo, req_idx, positions):
+ from ..triton_kernel.build_compress_index_dsv4 import build_compress_index
+
+ cap = ((max(1, int(infer_state.max_kv_seq_len) // 128) + 63) // 64) * 64
+ indices, lengths = build_compress_index(
+ req_idx,
+ positions,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.mem_manager.full_to_c128_indexs,
+ ratio=128,
+ cap=cap,
+ )
+ return indices.unsqueeze(1), lengths
+
+ def _indexer_q_weight(self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight):
+ """fp8 indexer q (mirrors deepseek3_2 NsaInfer): wq_b -> rope(last rope dims) -> 1/sqrt(d)
+ Hadamard -> per-token fp8 quant. Returns (idx_q_fp8 [T,H,d], weights [T,H]); the per-token q
+ fp8 scale and the head_dim^-0.5 * n_heads^-0.5 score scale are folded into weights -- the
+ deep_gemm.fp8_mqa_logits contract (fp8 q carries no companion scale). Replicated -> full heads."""
+ from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant
+ from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform
+
+ cos_tok = infer_state.position_cos_compress
+ sin_tok = infer_state.position_sin_compress
+ token_num = q_lora.shape[0]
+ if x.shape[0] != token_num:
+ raise RuntimeError(
+ f"DeepSeek-V4 indexer expects full-token hidden states, got x={x.shape[0]} q_lora={token_num}"
+ )
+ idx_q = layer_weight.idx_wq_b_.mm(q_lora).view(token_num, self.index_n_heads, self.index_head_dim)
+ rotary_emb_fwd(idx_q[..., -self.qk_rope_head_dim :], None, cos_tok, sin_tok)
+ idx_q = hadamard_transform(idx_q, scale=self.index_head_dim ** -0.5)
+ idx_q_fp8, q_scale = act_quant(idx_q, self.index_head_dim, None) # fp8 [T,H,d], scale [T,H,1]
+ weights = layer_weight.idx_weights_proj_.mm(x).float() * self.indexer_weight_scale # [T, H]
+ weights = weights.unsqueeze(-1) * q_scale # fold per-token q scale
+ return idx_q_fp8, weights.squeeze(-1).contiguous()
+
+ def _c4_indices(self, infer_state: DeepseekV4InferStateInfo, idx_q_fp8, weights, positions):
+ """c4 scorer via ds3.2-style gather + deep_gemm.fp8_mqa_logits. Gather each request's causal c4
+ keys into a padded-per-request ragged fp8 buffer (k row r*c4_cap+e), score every query token
+ over its absolute [ks, ke) range, then masked topk-512 -> c4 slots. Fixed shapes (c4_cap pinned
+ per graph bucket) keep the decode cuda graph capturable."""
+ mem_manager = infer_state.mem_manager
+ index_topk = self.index_topk
+ max_entries = max(1, int(infer_state.max_kv_seq_len) // 4)
+ c4_cap = ((max_entries + 63) // 64) * 64
+
+ # entry space fits the budget -> every causal entry is selected; no scoring needed. The
+ # captured decode graph (graph_max_len -> max_entries > topk) always takes the scorer branch
+ # below, so this only shortcuts tiny eager contexts.
+ if max_entries <= index_topk:
+ from ..triton_kernel.build_compress_index_dsv4 import build_compress_index
+
+ slots, lengths = build_compress_index(
+ infer_state.dsv4_sparse_req_idx,
+ positions,
+ infer_state.req_manager.req_to_token_indexs,
+ mem_manager.full_to_c4_indexs,
+ ratio=4,
+ cap=c4_cap,
+ )
+ return slots.unsqueeze(1), lengths
+
+ b_req_idx = infer_state.b_req_idx
+ batch = b_req_idx.shape[0]
+ device = positions.device
+ c4_len = torch.div(infer_state.b_seq_len, 4, rounding_mode="floor").to(torch.int32) # entries/req
+
+ if os.getenv("LIGHTLLM_DSV4_PAGED_INDEXER", "0") == "1":
+ out = self._c4_indices_paged(
+ infer_state=infer_state,
+ idx_q_fp8=idx_q_fp8,
+ weights=weights,
+ positions=positions,
+ c4_len=c4_len,
+ c4_cap=c4_cap,
+ )
+ if out is not None:
+ return out
+
+ import deep_gemm
+ from ..triton_kernel.gather_c4_indexer_k_dsv4 import gather_c4_indexer_k_ragged
+
+ k_fp8, k_scale, ragged_slots = gather_c4_indexer_k_ragged(
+ mem_manager,
+ self.layer_idx_,
+ b_req_idx,
+ c4_len,
+ c4_cap,
+ infer_state.req_manager.req_to_token_indexs,
+ )
+ # batch position of each query token -> absolute [ks, ke) into the padded buffer.
+ if infer_state.is_prefill:
+ token_batch_pos = torch.repeat_interleave(
+ torch.arange(batch, device=device, dtype=torch.int32), infer_state.b_q_seq_len
+ )
+ else:
+ token_batch_pos = torch.arange(batch, device=device, dtype=torch.int32)
+ valid_len = ((positions + 1) // 4).to(torch.int32) # causal candidate count per query
+ ks = token_batch_pos * c4_cap
+ ke = ks + valid_len
+ logits = deep_gemm.fp8_mqa_logits(
+ idx_q_fp8, (k_fp8, k_scale), weights, ks, ke, clean_logits=False, max_seqlen_k=c4_cap
+ ) # [T, c4_cap] f32, left-aligned: logits[t, j] = q_t . k[ks[t]+j]
+ col = torch.arange(c4_cap, device=device)
+ logits = logits.masked_fill(col.unsqueeze(0) >= valid_len.unsqueeze(1), float("-inf"))
+ top = logits.topk(index_topk, dim=-1).indices.to(torch.int32) # relative positions in [0, valid_len)
+ abs_idx = top + ks.unsqueeze(1) # absolute compact row
+ top_slots = ragged_slots[abs_idx.long()] # compact row -> c4 pool slot
+ invalid = top >= valid_len.unsqueeze(1) # topk over -inf padding when valid_len < index_topk
+ top_slots = torch.where(invalid, torch.full_like(top_slots, -1), top_slots)
+ topk_lengths = torch.clamp(torch.minimum(valid_len, torch.full_like(valid_len, index_topk)), min=1)
+ return top_slots.unsqueeze(1), topk_lengths.contiguous()
+
+ def _c4_indices_paged(self, infer_state, idx_q_fp8, weights, positions, c4_len, c4_cap):
+ import deep_gemm
+ from lightllm.third_party.sglang_jit.dsv4 import topk_transform_512
+ from ..triton_kernel.gather_c4_indexer_k_dsv4 import build_c4_indexer_page_table
+
+ mem_manager = infer_state.mem_manager
+ index_topk = self.index_topk
+ device = positions.device
+ b_req_idx = infer_state.b_req_idx
+ batch = b_req_idx.shape[0]
+ validate = os.getenv("LIGHTLLM_DSV4_PAGED_INDEXER_VALIDATE", "0") == "1"
+ validate_now = validate and not torch.cuda.is_current_stream_capturing()
+
+ page_table, valid_flag = build_c4_indexer_page_table(
+ mem_manager,
+ b_req_idx,
+ c4_len,
+ c4_cap,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.req_manager.HOLD_REQUEST_ID,
+ validate=validate_now,
+ )
+ if validate_now and int(valid_flag.item()) == 0:
+ if os.getenv("LIGHTLLM_DSV4_PAGED_INDEXER_STRICT", "0") == "1":
+ raise RuntimeError("DeepSeek-V4 paged indexer requires page-aligned c4 slots")
+ return None
+
+ if infer_state.is_prefill:
+ token_batch_pos = torch.repeat_interleave(
+ torch.arange(batch, device=device, dtype=torch.int32), infer_state.b_q_seq_len
+ )
+ row_page_table = page_table[token_batch_pos.long()].contiguous()
+ else:
+ row_page_table = page_table
+
+ valid_len = ((positions + 1) // 4).to(torch.int32)
+ ctx_lens = torch.clamp(valid_len, min=1).reshape(-1, 1).contiguous()
+ kv_cache = mem_manager.c4_indexer_pool.get_layer_buffer(mem_manager.layer_to_c4_idx[self.layer_idx_]).view(
+ mem_manager.c4_indexer_pool.num_pages,
+ mem_manager.c4_indexer_pool.page_size,
+ 1,
+ self.index_head_dim + 4,
+ )
+ metadata = deep_gemm.get_paged_mqa_logits_metadata(
+ ctx_lens,
+ mem_manager.c4_indexer_pool.page_size,
+ deep_gemm.get_num_sms(),
+ )
+ logits = deep_gemm.fp8_paged_mqa_logits(
+ idx_q_fp8.unsqueeze(1),
+ kv_cache,
+ weights,
+ ctx_lens,
+ row_page_table,
+ metadata,
+ c4_cap,
+ False,
+ )
+ top_slots = torch.empty((idx_q_fp8.shape[0], index_topk), dtype=torch.int32, device=device)
+ topk_transform_512(
+ logits,
+ valid_len,
+ row_page_table,
+ top_slots,
+ mem_manager.c4_indexer_pool.page_size,
+ )
+ topk_lengths = torch.clamp(torch.minimum(valid_len, torch.full_like(valid_len, index_topk)), min=1)
+ return top_slots.unsqueeze(1), topk_lengths.contiguous()
diff --git a/lightllm/models/deepseek_v4/layer_weights/__init__.py b/lightllm/models/deepseek_v4/layer_weights/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 0000000000..54f29ce574
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,37 @@
+import torch
+from lightllm.common.basemodel import PreAndPostLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ EmbeddingWeight,
+ LMHeadWeight,
+ RMSNormWeight,
+ ParameterWeight,
+)
+
+
+class DeepseekV4PreAndPostLayerWeight(PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config):
+ super().__init__(data_type, network_config)
+
+ hidden = network_config["hidden_size"]
+ vocab = network_config["vocab_size"]
+ hc_mult = network_config["hc_mult"]
+
+ # embeddings / lm_head / final norm (bf16, vocab tensor-parallel). V4 has no `model.` prefix
+ # and does not tie embeddings (tie_word_embeddings=false).
+ self.wte_weight_ = EmbeddingWeight(
+ dim=hidden, vocab_size=vocab, weight_name="embed.weight", data_type=self.data_type_
+ )
+ self.lm_head_weight_ = LMHeadWeight(
+ dim=hidden, vocab_size=vocab, weight_name="head.weight", data_type=self.data_type_
+ )
+ self.final_norm_weight_ = RMSNormWeight(dim=hidden, weight_name="norm.weight", data_type=self.data_type_)
+
+ # final hyper-connection head (collapses the hc_mult residual streams before the lm_head)
+ self.hc_head_fn_ = ParameterWeight(
+ weight_name="hc_head_fn", data_type=torch.float32, weight_shape=(hc_mult, hc_mult * hidden)
+ )
+ self.hc_head_base_ = ParameterWeight(
+ weight_name="hc_head_base", data_type=torch.float32, weight_shape=(hc_mult,)
+ )
+ self.hc_head_scale_ = ParameterWeight(weight_name="hc_head_scale", data_type=torch.float32, weight_shape=(1,))
+ return
diff --git a/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py
new file mode 100644
index 0000000000..5896027b38
--- /dev/null
+++ b/lightllm/models/deepseek_v4/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,331 @@
+import torch
+from lightllm.common.basemodel import TransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ ROWMMWeight,
+ COLMMWeight,
+ ROWBMMWeight,
+ RMSNormWeight,
+ ParameterWeight,
+ TpAttSinkWeight,
+ FusedMoeWeight,
+)
+from ..triton_kernel.quant_convert import dequant_fp8_block_to_bf16
+
+
+class DeepseekV4TransformerLayerWeight(TransformerLayerWeight):
+ """Per-layer weights for DeepSeek-V4-Flash.
+
+ DS4 does not share DS2/DS3.2's ``model.layers.*.self_attn/mlp`` layout. Its attention is
+ HC + CSA, and routed experts are checkpointed as MXFP4 (fp4 release) or
+ FP8 block-128 (fp8 release, same layout as the dense fp8 weights).
+ """
+
+ def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
+ super().__init__(layer_num, data_type, network_config, quant_cfg)
+ return
+
+ def _parse_config(self):
+ """Cache the config fields this layer uses and derive its per-layer flags and name prefix."""
+ cfg = self.network_config_
+ self.hidden = cfg["hidden_size"]
+ self.n_heads = cfg["num_attention_heads"]
+ self.head_dim = cfg["head_dim"]
+ self.rope_dim = cfg["qk_rope_head_dim"]
+ self.q_lora_rank = cfg["q_lora_rank"]
+ self.o_lora_rank = cfg["o_lora_rank"]
+ self.o_groups = cfg["o_groups"]
+ self.index_n_heads = cfg["index_n_heads"]
+ self.index_head_dim = cfg["index_head_dim"]
+ self.n_routed_experts = cfg["n_routed_experts"]
+ self.moe_inter = cfg["moe_intermediate_size"]
+ self.num_hash_layers = cfg["num_hash_layers"]
+ self.vocab_size = cfg["vocab_size"]
+ self.hc_mult = cfg["hc_mult"]
+ # leading dimension of the hc_*_fn / hc_*_base parameters declared in _init_hyper_connection
+ self.mix_hc = (2 + self.hc_mult) * self.hc_mult
+ # compress_ratios is indexed by layer number: 0 means no compressor, 4 additionally enables the indexer
+ self.compress_ratio = cfg["compress_ratios"][self.layer_num_]
+ self.has_compressor = self.compress_ratio != 0
+ self.has_indexer = self.compress_ratio == 4
+ # the first num_hash_layers layers load gate.tid2eid instead of gate.bias (see _init_moe)
+ self.is_hash = self.layer_num_ < self.num_hash_layers
+ # heads and output groups must split evenly across the tensor-parallel ranks
+ assert self.n_heads % self.tp_world_size_ == 0
+ assert self.o_groups % self.tp_world_size_ == 0
+ # every checkpoint name of this layer starts with this prefix
+ self.prefix = f"layers.{self.layer_num_}"
+
+ def _init_weight(self):
+ self._init_qkvo()
+ if self.has_compressor:
+ self._init_compressor()
+ if self.has_indexer:
+ self._init_indexer()
+ self._init_moe()
+ self._init_norm()
+ self._init_hyper_connection()
+
+ # ------------------------------------------------------------------ attention
+ def _init_qkvo(self):
+ """Declare the attention weights: low-rank q (wq_a/wq_b), the single-head kv projection, the
+ q/kv norms, the attention sink and the grouped low-rank output projection (wo_a/wo_b)."""
+ p = f"{self.prefix}.attn"
+ # q low-rank (a replicated, b column-parallel over heads), kv single head (replicated)
+ # tp_rank=0 / tp_world_size=1 is how a weight is kept unsplit (replicated) in this file.
+ self.wq_a_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[self.q_lora_rank],
+ weight_names=f"{p}.wq_a.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("wq_a"),
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ self.wq_b_ = ROWMMWeight(
+ in_dim=self.q_lora_rank,
+ out_dims=[self.n_heads * self.head_dim],
+ weight_names=f"{p}.wq_b.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("wq_b"),
+ )
+ self.wkv_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[self.head_dim],
+ weight_names=f"{p}.wkv.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("wkv"),
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ self.q_norm_ = RMSNormWeight(dim=self.q_lora_rank, weight_name=f"{p}.q_norm.weight", data_type=self.data_type_)
+ self.kv_norm_ = RMSNormWeight(dim=self.head_dim, weight_name=f"{p}.kv_norm.weight", data_type=self.data_type_)
+ self.attn_sink_ = TpAttSinkWeight(
+ all_q_head_num=self.n_heads, weight_name=f"{p}.attn_sink", data_type=torch.float32
+ )
+ # grouped low-rank output projection: wo_a is a per-group batched matmul [groups, in, o_lora],
+ # wo_b is row-parallel [groups*o_lora -> hidden]. wo_a is reshaped in load_hf_weights.
+ # NOTE(review): "row-parallel" above describes the split of the input dimension; the class
+ # used for wo_b is COLMMWeight -- confirm the two naming conventions agree.
+ per_group_in = self.n_heads * self.head_dim // self.o_groups
+ # wo_a takes no quant method; _dequant_in_place expands an FP8 wo_a to the layer dtype instead.
+ self.wo_a_ = ROWBMMWeight(
+ dim0=self.o_groups,
+ dim1=per_group_in,
+ dim2=self.o_lora_rank,
+ weight_names=f"{p}.wo_a.weight",
+ data_type=self.data_type_,
+ quant_method=None,
+ )
+ self.wo_b_ = COLMMWeight(
+ in_dim=self.o_groups * self.o_lora_rank,
+ out_dims=[self.hidden],
+ weight_names=f"{p}.wo_b.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("wo_b"),
+ )
+
+ # ------------------------------------------------------------------ compressor / indexer
+ def _init_compressor(self):
+ """Declare the KV compressor weights (fused wkv/wgate projection, norm and `ape` parameter).
+ Only called for layers with a non-zero compress ratio (see _init_weight)."""
+ prefix = f"{self.prefix}.attn.compressor"
+ head_dim = self.head_dim
+ ratio = self.compress_ratio
+
+ # ratio-4 layers use doubled projection / ape widths; every other ratio uses head_dim as is
+ coff = 2 if ratio == 4 else 1
+ # wkv/wgate are bf16 (no scale) and replicated (single KV head).
+ self.compressor_wkv_gate_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[coff * head_dim, coff * head_dim],
+ weight_names=[f"{prefix}.wkv.weight", f"{prefix}.wgate.weight"],
+ data_type=self.data_type_,
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ self.compressor_norm_ = RMSNormWeight(
+ dim=head_dim, weight_name=f"{prefix}.norm.weight", data_type=self.data_type_
+ )
+ # `ape` has one row per position inside a compression group (ratio rows), kept in float32
+ self.compressor_ape_ = ParameterWeight(
+ weight_name=f"{prefix}.ape", data_type=torch.float32, weight_shape=(ratio, coff * head_dim)
+ )
+
+ def _init_indexer(self):
+ """Declare the indexer weights (query projection, per-head weights projection and the
+ indexer's own compressor). Only called for layers whose compress ratio is 4."""
+ p = f"{self.prefix}.attn.indexer"
+ # The Lightning-Indexer is REPLICATED across TP ranks (like sglang/vllm), not head-sharded:
+ # q_lora and the attn input are already full on every rank, so each rank scores all
+ # index_n_heads locally and the c4 top-k is identical everywhere -- no gather/all_reduce.
+ # wq_b is FP8 in the checkpoint -> de-quantized to bf16 at load.
+ # NOTE(review): a quant method is requested for idx_wq_b below; when one is returned,
+ # _dequant_in_place renames the scale instead of de-quantizing -- confirm which path is intended.
+ self.idx_wq_b_ = ROWMMWeight(
+ in_dim=self.q_lora_rank,
+ out_dims=[self.index_n_heads * self.index_head_dim],
+ weight_names=f"{p}.wq_b.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("idx_wq_b"),
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ self.idx_weights_proj_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[self.index_n_heads],
+ weight_names=f"{p}.weights_proj.weight",
+ data_type=self.data_type_,
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ coff = 2 # indexer compressor always uses ratio 4 (overlap)
+ # unlike the main compressor, wkv and wgate are declared as two separate weights here
+ self.idx_cmp_wkv_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[coff * self.index_head_dim],
+ weight_names=f"{p}.compressor.wkv.weight",
+ data_type=self.data_type_,
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ self.idx_cmp_wgate_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[coff * self.index_head_dim],
+ weight_names=f"{p}.compressor.wgate.weight",
+ data_type=self.data_type_,
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ self.idx_cmp_norm_ = RMSNormWeight(
+ dim=self.index_head_dim, weight_name=f"{p}.compressor.norm.weight", data_type=self.data_type_
+ )
+ # the 4 rows match the fixed ratio of the indexer compressor
+ self.idx_cmp_ape_ = ParameterWeight(
+ weight_name=f"{p}.compressor.ape", data_type=torch.float32, weight_shape=(4, coff * self.index_head_dim)
+ )
+
+ # ------------------------------------------------------------------ moe
+ def _init_moe(self):
+ """Declare the MoE weights: the router gate (plus tid2eid table or bias), the shared expert
+ and the fused routed experts."""
+ p = f"{self.prefix}.ffn"
+ # router gate (replicated). Stored as fp32: the topk_hash_softplus_sqrt router wants fp32 logits,
+ # so keep the gate matmul in fp32 — but store the (constant) weight as fp32 once here instead of
+ # re-casting it to fp32 on every forward in _ffn.
+ self.gate_weight_ = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[self.n_routed_experts],
+ weight_names=f"{p}.gate.weight",
+ data_type=torch.float32,
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+ # hash layers load an int64 (vocab_size, num_experts_per_tok) table named gate.tid2eid;
+ # all other layers load a float32 per-expert gate.bias instead.
+ if self.is_hash:
+ self.gate_tid2eid_ = ParameterWeight(
+ weight_name=f"{p}.gate.tid2eid",
+ data_type=torch.int64,
+ weight_shape=(self.vocab_size, self.network_config_["num_experts_per_tok"]),
+ )
+ else:
+ self.gate_bias_ = ParameterWeight(
+ weight_name=f"{p}.gate.bias", data_type=torch.float32, weight_shape=(self.n_routed_experts,)
+ )
+ # shared expert (dense, bf16 after de-quant): w1=gate, w3=up fused (row), w2=down (col).
+ # Named gate_up_proj/down_proj so the inherited Llama `_ffn_tp` (fused gate_up matmul +
+ # silu_and_mul triton kernel, no swiglu clamp) drives it directly. Order [w1, w3] = [gate, up]
+ # matches silu_and_mul_fwd's blocked layout (first half gate, second half up).
+ sp = f"{p}.shared_experts"
+ self.gate_up_proj = ROWMMWeight(
+ in_dim=self.hidden,
+ out_dims=[self.moe_inter, self.moe_inter],
+ weight_names=[f"{sp}.w1.weight", f"{sp}.w3.weight"],
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("shared_gate"),
+ )
+ self.down_proj = COLMMWeight(
+ in_dim=self.moe_inter,
+ out_dims=[self.hidden],
+ weight_names=f"{sp}.w2.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("shared_down"),
+ )
+ # routed experts: the quant method is looked up on quant_cfg under the "fused_moe" key;
+ # _dequant_in_place later reads experts_.quant_method.weight_scale_suffix.
+ self.experts_ = FusedMoeWeight(
+ gate_proj_name="w1",
+ down_proj_name="w2",
+ up_proj_name="w3",
+ e_score_correction_bias_name="",
+ weight_prefix=f"{p}.experts",
+ n_routed_experts=self.n_routed_experts,
+ hidden_size=self.hidden,
+ moe_intermediate_size=self.moe_inter,
+ data_type=self.data_type_,
+ quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"),
+ layer_num=self.layer_num_,
+ network_config=self.network_config_,
+ )
+
+ def _init_norm(self):
+ self.attn_norm_ = RMSNormWeight(
+ dim=self.hidden, weight_name=f"{self.prefix}.attn_norm.weight", data_type=self.data_type_
+ )
+ self.ffn_norm_ = RMSNormWeight(
+ dim=self.hidden, weight_name=f"{self.prefix}.ffn_norm.weight", data_type=self.data_type_
+ )
+
+ def _init_hyper_connection(self):
+ p = self.prefix
+ self.hc_attn_fn_ = ParameterWeight(
+ weight_name=f"{p}.hc_attn_fn",
+ data_type=torch.float32,
+ weight_shape=(self.mix_hc, self.hc_mult * self.hidden),
+ )
+ self.hc_attn_base_ = ParameterWeight(
+ weight_name=f"{p}.hc_attn_base", data_type=torch.float32, weight_shape=(self.mix_hc,)
+ )
+ self.hc_attn_scale_ = ParameterWeight(
+ weight_name=f"{p}.hc_attn_scale", data_type=torch.float32, weight_shape=(3,)
+ )
+ self.hc_ffn_fn_ = ParameterWeight(
+ weight_name=f"{p}.hc_ffn_fn",
+ data_type=torch.float32,
+ weight_shape=(self.mix_hc, self.hc_mult * self.hidden),
+ )
+ self.hc_ffn_base_ = ParameterWeight(
+ weight_name=f"{p}.hc_ffn_base", data_type=torch.float32, weight_shape=(self.mix_hc,)
+ )
+ self.hc_ffn_scale_ = ParameterWeight(
+ weight_name=f"{p}.hc_ffn_scale", data_type=torch.float32, weight_shape=(3,)
+ )
+
+ # ------------------------------------------------------------------ loading
+ def load_hf_weights(self, weights):
+ self._dequant_in_place(weights)
+ return super().load_hf_weights(weights)
+
+ def _fp8_scale_renames(self):
+ """Map weight name -> the scale name its quant method loads (e.g. `weight_scale_inv`
+ for DeepGEMM). Read from each MM weight's own `weight_scale_names`, so the rename
+ target always matches what that weight will look up; no-quant weights have None
+ entries and are skipped."""
+ renames = {}
+ for attr in self.__dict__.values():
+ weight_names = getattr(attr, "weight_names", ())
+ scale_names = getattr(attr, "weight_scale_names", ())
+ for weight_name, scale_name in zip(weight_names, scale_names):
+ if scale_name is not None:
+ renames[weight_name] = scale_name
+ return renames
+
+ def _dequant_in_place(self, weights):
+ """Rewrite this layer's `.scale` entries of `weights` in place and reshape wo_a.
+
+ `weights` is the dict of tensors from one checkpoint shard; keys outside this layer's
+ prefix are not touched.
+ """
+ p = self.prefix + "."
+ scale_renames = self._fp8_scale_renames()
+ # Convert every `.scale` belonging to this layer. Weights are loaded incrementally
+ # per safetensors shard, so the paired weight may live in another shard:
+ # - routed expert `.scale` follows the fused_moe quant method's weight_scale_suffix:
+ # MXFP4 consumes `.scale` as-is, FP8 DeepGEMM expects `.weight_scale_inv` (rename only);
+ # - FP8 matmul scales only need renaming for DeepGEMM, no weight required;
+ # - FP8 pairs on no-quant paths (wo_a's ROWBMMWeight) are expanded to bf16,
+ # the only case that truly requires weight and scale in the same shard.
+ expert_scale_suffix = self.experts_.quant_method.weight_scale_suffix
+ # the key list is materialized first because the loop body adds and deletes keys
+ for scale_k in [k for k in list(weights.keys()) if k.startswith(p) and k.endswith(".scale")]:
+ if scale_k.startswith(f"{p}ffn.experts."):
+ # routed experts: re-key the scale under the quant method's suffix (cast to fp32) only
+ # when that suffix differs from "scale"
+ if expert_scale_suffix is not None and expert_scale_suffix != "scale":
+ weights[scale_k[: -len("scale")] + expert_scale_suffix] = weights[scale_k].to(torch.float32)
+ del weights[scale_k]
+ continue
+ # name of the weight this scale belongs to: "<name>.scale" -> "<name>.weight"
+ k = scale_k[: -len(".scale")] + ".weight"
+ target = scale_renames.get(k)
+ if target is not None: # FP8 e4m3, block-128 scale, run by DeepGEMM directly
+ weights[target] = weights[scale_k].to(torch.float32)
+ del weights[scale_k]
+ else:
+ # no scale name registered for this weight: expand it with its scale and drop the scale.
+ # Indexing weights[k] raises KeyError if the paired weight is not in this shard.
+ weights[k] = dequant_fp8_block_to_bf16(weights[k], weights[scale_k]).to(self.data_type_)
+ del weights[scale_k]
+ # grouped-O: reshape [groups*o_lora, in] -> [groups, in, o_lora] for the batched matmul
+ # the dim() == 2 guard skips a wo_a that already has the 3-D layout
+ woa = f"{self.prefix}.attn.wo_a.weight"
+ if woa in weights and weights[woa].dim() == 2:
+ w = weights[woa]
+ per_group_in = self.n_heads * self.head_dim // self.o_groups
+ weights[woa] = w.view(self.o_groups, self.o_lora_rank, per_group_in).transpose(1, 2).contiguous()
+ return
diff --git a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py
new file mode 100644
index 0000000000..887e72f433
--- /dev/null
+++ b/lightllm/models/deepseek_v4/model.py
@@ -0,0 +1,279 @@
+import copy
+import importlib.util
+import os
+
+import torch
+from lightllm.models.registry import ModelRegistry
+from lightllm.models.llama.model import LlamaTpPartModel
+from lightllm.common.req_manager import DeepseekV4ReqManager
+from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager
+from lightllm.models.deepseek_v4.layer_weights.pre_and_post_layer_weight import (
+ DeepseekV4PreAndPostLayerWeight,
+)
+from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import (
+ DeepseekV4TransformerLayerWeight,
+)
+from lightllm.models.deepseek_v4.layer_infer.pre_layer_infer import (
+ DeepseekV4PreLayerInfer,
+)
+from lightllm.models.deepseek_v4.layer_infer.post_layer_infer import (
+ DeepseekV4PostLayerInfer,
+)
+from lightllm.models.deepseek_v4.layer_infer.transformer_layer_infer import (
+ DeepseekV4TransformerLayerInfer,
+)
+from lightllm.common.basemodel.attention import get_nsa_prefill_att_backend_class, get_nsa_decode_att_backend_class
+from lightllm.models.deepseek_v4.infer_struct import DeepseekV4InferStateInfo
+from lightllm.models.llama.yarn_rotary_utils import (
+ find_correction_range,
+ linear_ramp_mask,
+)
+from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num, get_env_start_args
+from lightllm.utils.log_utils import init_logger
+from lightllm.distributed.communication_op import dist_group_manager
+
+logger = init_logger(__name__)
+DSV4_DECODE_CUDAGRAPH_MAX_LEN = 8192
+
+
+@ModelRegistry("deepseek_v4")
+class DeepseekV4TpPartModel(LlamaTpPartModel):
+ req_manager: DeepseekV4ReqManager
+ mem_manager: DeepseekV4MemoryManager
+
+ pre_and_post_weight_class = DeepseekV4PreAndPostLayerWeight
+ transformer_weight_class = DeepseekV4TransformerLayerWeight
+
+ pre_layer_infer_class = DeepseekV4PreLayerInfer
+ post_layer_infer_class = DeepseekV4PostLayerInfer
+ transformer_layer_infer_class = DeepseekV4TransformerLayerInfer
+
+ infer_state_class = DeepseekV4InferStateInfo
+
+ def _verify_params(self):
+ assert self.load_way == "HF", "only support HF format weights"
+ assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
+ assert self.config["o_groups"] % self.tp_world_size_ == 0
+ assert self.config["index_n_heads"] % self.tp_world_size_ == 0
+ return
+
+ def _init_req_manager(self):
+ create_max_seq_len = 0
+ if self.batch_max_tokens is not None:
+ create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens)
+ if self.max_seq_length is not None:
+ create_max_seq_len = max(create_max_seq_len, self.max_seq_length)
+
+ self._dsv4_req_manager_seq_len = create_max_seq_len
+ layer_num = self.config["n_layer"] + get_added_mtp_kv_layer_num()
+ self._dsv4_compress_rates = self._get_compress_rates(layer_num)
+ self.req_manager = DeepseekV4ReqManager(
+ self.max_req_num,
+ create_max_seq_len,
+ compress_rates=self._dsv4_compress_rates,
+ head_dim=self.config["head_dim"],
+ indexer_head_dim=self.config["index_head_dim"],
+ sliding_window=self.config["sliding_window"],
+ )
+ return
+
+ def _get_compress_rates(self, layer_num):
+ rates = list(self.config["compress_ratios"])
+ return rates[:layer_num]
+
+ def _init_mem_manager(self):
+ """Create the DeepSeek-V4 KV memory manager and bind it to the request manager."""
+ layer_num = self.config["n_layer"] + get_added_mtp_kv_layer_num()
+ # reuse the rates stored by _init_req_manager when present; the fallback is computed either way
+ compress_rates = getattr(self, "_dsv4_compress_rates", self._get_compress_rates(layer_num))
+ sliding_window = int(self.config["sliding_window"])
+ # Extra swa headroom beyond the active windows: the transient footprint of in-flight prefill
+ # chunks (slots that leave the window are only reclaimed at the next prep)
+ # + the window tails held by the radix cache (roughly one window per cached sequence).
+ swa_extra_token_num = int(self.batch_max_tokens or 0) + self.max_req_num * sliding_window
+ self.mem_manager = DeepseekV4MemoryManager(
+ self.max_total_token_num,
+ dtype=self.data_type,
+ head_num=1,
+ head_dim=self.config["head_dim"],
+ layer_num=layer_num,
+ compress_rates=compress_rates,
+ indexer_head_dim=self.config["index_head_dim"],
+ max_request_num=self.max_req_num,
+ sliding_window=sliding_window,
+ swa_extra_token_num=swa_extra_token_num,
+ mem_fraction=self.mem_fraction,
+ )
+ # the request manager must already exist: it is handed the memory manager created above
+ assert isinstance(self.req_manager, DeepseekV4ReqManager)
+ self.req_manager.bind_mem_manager(self.mem_manager)
+ return
+
+ def _init_cudagraph(self):
+ """Cap graph_max_len_in_batch at DSV4_DECODE_CUDAGRAPH_MAX_LEN (logging when the cap applies),
+ then run the base-class cudagraph setup. Nothing is capped when cudagraph is disabled."""
+ if not self.disable_cudagraph and self.graph_max_len_in_batch > DSV4_DECODE_CUDAGRAPH_MAX_LEN:
+ logger.info(
+ "DeepSeek-V4 caps decode cudagraph max_len_in_batch from %s to %s for the current "
+ "graph-safe sparse-attention path; longer decode batches run eager.",
+ self.graph_max_len_in_batch,
+ DSV4_DECODE_CUDAGRAPH_MAX_LEN,
+ )
+ self.graph_max_len_in_batch = DSV4_DECODE_CUDAGRAPH_MAX_LEN
+ return super()._init_cudagraph()
+
+ def _init_att_backend(self):
+ """Select the NSA prefill/decode attention backends; only llm_kv_type "fp8kv_dsa" is accepted.
+
+ Raises:
+ RuntimeError: if llm_kv_type is set to anything other than "None" or "fp8kv_dsa".
+ """
+ args = get_env_start_args()
+ # an unset kv type (the string "None") defaults to fp8kv_dsa; this writes back to the start args
+ if args.llm_kv_type == "None":
+ args.llm_kv_type = "fp8kv_dsa"
+ # TODO: support other kv types
+ if args.llm_kv_type != "fp8kv_dsa":
+ raise RuntimeError("DeepSeek-V4 requires llm_kv_type=fp8kv_dsa for packed FlashMLA sparse attention")
+ self.prefill_att_backend = get_nsa_prefill_att_backend_class(index=0)(model=self)
+ self.decode_att_backend = get_nsa_decode_att_backend_class(index=0)(model=self)
+ return
+
+ def _init_custom(self):
+ self._init_to_get_rotary()
+ dist_group_manager.new_deepep_group(
+ self.config["n_routed_experts"],
+ self.config["hidden_size"],
+ self.config.get("num_experts_per_tok", 1),
+ self.config.get("moe_intermediate_size", self.config.get("intermediate_size")),
+ )
+ return
+
+ def _init_to_get_rotary(self):
+ """Build the two complex64 rope tables (sliding and compress) on the GPU, expose their
+ real/imag parts as cos/sin tables, and attach the matching table to every infer layer."""
+ # Interleaved (GPT-J) rope. Build complex64 freqs_cis tables (_freqs_cis_*) following the
+ # gemma4 two-variant convention; the fused sglang q kernel consumes them directly, while
+ # _cos_cached_*/_sin_cached_* are .real/.imag views of the same storage for the kv rope,
+ # inverse rope and compressor paths (deepseek2's interleaved triton rotary_emb_fwd).
+ # Sliding-window layers use base rope_theta (no YaRN);
+ # compressed (CSA/HCA) layers use compress_rope_theta with configured rope_scaling.
+ # Kept fp32 for accuracy (the apply upcasts anyway).
+ cfg = self.config
+ # a missing or null rope_scaling entry is treated as an empty dict
+ rs = cfg.get("rope_scaling", {}) or {}
+ dim = cfg["qk_rope_head_dim"]
+ # The rope tables MUST span every absolute position any request can produce (the served
+ # max_req_total_len / max_position_embeddings, up to 1M). Capping them shorter makes
+ # init_some_extra_state's index_select(cos/sin, position_ids) read OOB past the table at
+ # contexts beyond the cap (device-side assert / crash). ~268MB total at 1M, fp32x32 x4 views.
+ max_seq = max(int(self.max_seq_length), int(cfg.get("max_position_embeddings", 8192)))
+ freq_exponents = torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim
+ positions = torch.arange(max_seq, dtype=torch.float32, device="cuda")
+
+ # sliding table: plain inverse frequencies from rope_theta
+ sliding_freqs = 1.0 / (cfg["rope_theta"] ** freq_exponents)
+ f = torch.outer(positions, sliding_freqs) # [max_seq, dim//2]
+ self._freqs_cis_sliding = torch.complex(f.cos(), f.sin())
+
+ # compress table: inverse frequencies from compress_rope_theta, YaRN-blended when configured
+ compress_freqs = 1.0 / (cfg["compress_rope_theta"] ** freq_exponents)
+ rope_type = rs.get("rope_type", rs.get("type", "default"))
+ orig_max = rs.get("original_max_position_embeddings", 0)
+ if rope_type == "yarn" and orig_max > 0:
+ beta_fast = rs.get("beta_fast", 32)
+ beta_slow = rs.get("beta_slow", 1)
+ factor = rs.get("factor", 1)
+ # an explicit null factor is derived from the position ranges instead
+ if factor is None:
+ factor = cfg.get("max_position_embeddings", max_seq) / orig_max
+ low, high = find_correction_range(beta_fast, beta_slow, dim, cfg["compress_rope_theta"], orig_max)
+ smooth = 1 - linear_ramp_mask(low, high, dim // 2).cuda()
+ # per-frequency blend: weight `smooth` keeps the original frequency, the rest is divided by factor
+ compress_freqs = compress_freqs / factor * (1 - smooth) + compress_freqs * smooth
+ f = torch.outer(positions, compress_freqs) # [max_seq, dim//2]
+ self._freqs_cis_compress = torch.complex(f.cos(), f.sin())
+ self._cos_cached_sliding = self._freqs_cis_sliding.real
+ self._sin_cached_sliding = self._freqs_cis_sliding.imag
+ self._cos_cached_compress = self._freqs_cis_compress.real
+ self._sin_cached_compress = self._freqs_cis_compress.imag
+ # Each layer uses exactly one rope variant; wire its table once here (layers are already
+ # built: _init_infer_layer runs before _init_custom) instead of relaying via infer_state.
+ # The compressor needs the full compress tables (entry rope positions != token positions).
+ for layer in self.layers_infer:
+ layer.freqs_cis = self._freqs_cis_compress if layer.compress_ratio else self._freqs_cis_sliding
+ layer.cos_compress_table = self._cos_cached_compress
+ layer.sin_compress_table = self._sin_cached_compress
+ return
+
+
+class DeepSeekV4Tokenizer:
+ """Tokenizer wrapper for DeepSeek-V4's Python prompt encoding."""
+
+ # DeepSeek-V4 has a per-request thinking mode (...) toggled via
+ # chat_template_kwargs={"thinking": true}. It has no Jinja chat_template string,
+ # so advertise thinking support explicitly for tokenizer_supports_force_thinking().
+ supports_thinking = True
+
+ def __init__(self, tokenizer, model_dir):
+ self.tokenizer = tokenizer
+ self.model_dir = model_dir
+ self._encoding_module = None
+ self._added_vocab = None
+
+ def __getattr__(self, name):
+ return getattr(self.tokenizer, name)
+
+ def get_added_vocab(self):
+ if self._added_vocab is None:
+ self._added_vocab = self.tokenizer.get_added_vocab()
+ return self._added_vocab
+
+ def _get_encoding_module(self):
+ if self._encoding_module is not None:
+ return self._encoding_module
+
+ encoding_path = os.path.join(self.model_dir, "encoding", "encoding_dsv4.py")
+ if not os.path.exists(encoding_path):
+ raise FileNotFoundError(f"DeepSeek-V4 encoding file not found: {encoding_path}")
+
+ spec = importlib.util.spec_from_file_location("lightllm_deepseek_v4_encoding_dsv4", encoding_path)
+ if spec is None or spec.loader is None:
+ raise ImportError(f"failed to load DeepSeek-V4 encoding module from {encoding_path}")
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ self._encoding_module = module
+ return module
+
+ def apply_chat_template(
+ self,
+ conversation=None,
+ messages=None,
+ tools=None,
+ tokenize=False,
+ add_generation_prompt=True,
+ thinking=None,
+ enable_thinking=None,
+ **kwargs,
+ ):
+ msgs = conversation if conversation is not None else messages
+ if msgs is None:
+ raise ValueError("Either 'conversation' or 'messages' must be provided")
+
+ msgs = copy.deepcopy(msgs)
+
+ if tools:
+ wrapped_tools = []
+ for tool in tools:
+ if "function" in tool:
+ wrapped_tools.append(tool)
+ else:
+ wrapped_tools.append({"type": "function", "function": tool})
+
+ injected = False
+ for msg in msgs:
+ if msg.get("role") == "system":
+ existing = msg.get("tools") or []
+ msg["tools"] = existing + wrapped_tools
+ injected = True
+ break
+
+ if not injected:
+ msgs.insert(0, {"role": "system", "content": "", "tools": wrapped_tools})
+
+ if thinking is None:
+ thinking = bool(enable_thinking) if enable_thinking is not None else False
+ thinking_mode = "thinking" if thinking else "chat"
+ encoding = self._get_encoding_module()
+ prompt = encoding.encode_messages(
+ msgs,
+ thinking_mode=thinking_mode,
+ drop_thinking=kwargs.get("drop_thinking", True),
+ add_default_bos_token=kwargs.get("add_default_bos_token", True),
+ reasoning_effort=kwargs.get("reasoning_effort"),
+ )
+
+ if tokenize:
+ return self.tokenizer.encode(prompt, add_special_tokens=False)
+ return prompt
diff --git a/lightllm/models/deepseek_v4/triton_kernel/__init__.py b/lightllm/models/deepseek_v4/triton_kernel/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py
new file mode 100644
index 0000000000..b09192498d
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py
@@ -0,0 +1,78 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _build_compress_index_kernel(
+ req_idx_ptr,
+ pos_ptr,
+ req_to_token_ptr,
+ req_to_token_stride0,
+ full_to_c_ptr,
+ index_ptr,
+ length_ptr,
+ cap,
+ RATIO: tl.constexpr,
+ BLOCK_E: tl.constexpr,
+):
+ t = tl.program_id(0)
+ eb = tl.program_id(1)
+ req = tl.load(req_idx_ptr + t).to(tl.int64)
+ pos = tl.load(pos_ptr + t).to(tl.int64)
+ raw_len = (pos + 1) // RATIO
+
+ e = eb * BLOCK_E + tl.arange(0, BLOCK_E)
+ e_mask = e < cap
+ valid = (e < raw_len) & e_mask
+ # group-end token of compressed entry e: position e*RATIO + (RATIO-1).
+ end_pos = e * RATIO + (RATIO - 1)
+ safe_pos = tl.where(valid, end_pos, 0)
+ full_slot = tl.load(req_to_token_ptr + req * req_to_token_stride0 + safe_pos, mask=valid, other=0).to(tl.int64)
+ c_slot = tl.load(full_to_c_ptr + full_slot, mask=valid, other=-1).to(tl.int32)
+ tl.store(index_ptr + t * cap + e, c_slot, mask=e_mask)
+
+ if eb == 0:
+ tl.store(length_ptr + t, tl.maximum(raw_len, 1).to(tl.int32))
+
+
+def build_compress_index(
+ req_idx: torch.Tensor,
+ positions: torch.Tensor,
+ req_to_token_indexs: torch.Tensor,
+ full_to_c_indexs: torch.Tensor,
+ ratio: int,
+ cap: int,
+):
+ """Fused two-level group-end gather for the c4/c128 compressed-entry index tables.
+
+ For token t (at request `req_idx[t]`, absolute `positions[t]`) and compressed entry e:
+ slot[t, e] = full_to_c[ req_to_token[req, e*ratio + (ratio-1)] ] (the group-end token's full slot)
+ with slot = -1 where e >= (pos+1)//ratio (beyond the causal compressed length) or where the
+ full->c map is unset. Returns (index [T, cap] int32, length [T] int32 = clamp((pos+1)//ratio, 1)).
+
+ Replaces the eager _gather_compress_slots/_c128/c4-causal torch chain. `cap` must be a multiple of
+ 64 (FlashMLA topk alignment); the tiled grid (T, ceil(cap/BLOCK_E)) scales to 1M-context caps.
+ cuda-graph-safe: cap is fixed per graph bucket, shapes static.
+ """
+ # NOTE(review): the 64-alignment of `cap` is documented above but not checked in this function.
+ T = positions.shape[0]
+ # both outputs are allocated uninitialized; the kernel writes every element when T > 0 and cap > 0
+ index = torch.empty((T, cap), dtype=torch.int32, device=positions.device)
+ length = torch.empty((T,), dtype=torch.int32, device=positions.device)
+ # nothing to launch for an empty batch
+ if T == 0:
+ return index, length
+ BLOCK_E = 256
+ grid = (T, triton.cdiv(cap, BLOCK_E))
+ _build_compress_index_kernel[grid](
+ req_idx,
+ positions,
+ req_to_token_indexs,
+ req_to_token_indexs.stride(0),
+ full_to_c_indexs,
+ index,
+ length,
+ cap,
+ RATIO=ratio,
+ BLOCK_E=BLOCK_E,
+ num_warps=4,
+ )
+ return index, length
diff --git a/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py
new file mode 100644
index 0000000000..e5b7b80bb4
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py
@@ -0,0 +1,75 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _build_swa_index_kernel(
+ req_idx_ptr,
+ pos_ptr,
+ req_to_token_ptr,
+ req_to_token_stride0,
+ full_to_swa_ptr,
+ swa_index_ptr,
+ swa_length_ptr,
+ WINDOW: tl.constexpr,
+ BLOCK_W: tl.constexpr,
+):
+ # One program per token: writes WINDOW index entries and one length for that token.
+ token_idx = tl.program_id(0)
+ req = tl.load(req_idx_ptr + token_idx).to(tl.int64)
+ pos = tl.load(pos_ptr + token_idx).to(tl.int64)
+
+ # BLOCK_W lanes cover the window; lanes at or beyond WINDOW are masked off
+ w = tl.arange(0, BLOCK_W)
+ w_mask = w < WINDOW
+ # most-recent-first window, identical to the eager _swa_indices (offset = position - arange).
+ offset = pos - w
+ valid = (offset >= 0) & w_mask
+ # invalid lanes use offset 0 so no negative address is formed; their loads are masked anyway
+ safe_offset = tl.where(valid, offset, 0)
+ full_slot = tl.load(req_to_token_ptr + req * req_to_token_stride0 + safe_offset, mask=valid, other=0).to(tl.int64)
+ swa_slot = tl.load(full_to_swa_ptr + full_slot, mask=valid, other=-1)
+ # positions before the start of the sequence are stored as -1
+ out = tl.where(valid, swa_slot, -1).to(tl.int32)
+ tl.store(swa_index_ptr + token_idx * WINDOW + w, out, mask=w_mask)
+
+ # length = pos + 1 clamped into [1, WINDOW]
+ length = tl.minimum(tl.maximum(pos + 1, 1), WINDOW).to(tl.int32)
+ tl.store(swa_length_ptr + token_idx, length)
+
+
+def build_swa_index(
+ req_idx: torch.Tensor,
+ positions: torch.Tensor,
+ req_to_token_indexs: torch.Tensor,
+ full_to_swa_indexs: torch.Tensor,
+ window: int,
+):
+ """Per-token sliding-window FlashMLA index table, built ONCE per forward (layer-independent:
+ full_to_swa is a single global map and the window is a model constant, so every layer's swa
+ indices are identical). Replaces DeepseekV4IndexInfer._swa_indices: for token t at
+ (req_idx, position) gather the last `window` tokens' full slots via req_to_token, then map
+ full -> swa; out-of-range positions store -1.
+
+ Returns (swa_index [T, window] int32, swa_length [T] int32). `window` is 128 (a multiple of the
+ FlashMLA 64 alignment) so no extra pad is needed; the reader adds the s_q axis via unsqueeze(1).
+ Const output shape (no max_kv_seq_len dependence) makes this cuda-graph-safe to stage from
+ init_some_extra_state via copy_for_cuda_graph.
+
+ Raises AssertionError when `window` is not a multiple of 64.
+ """
+ # window must stay 64-aligned: the output is the FlashMLA `indices` tensor directly (no separate
+ # _pad_last_dim), and the extra-cache fork requires the topk dim to be a multiple of 64.
+ assert window % 64 == 0, f"DeepSeek-V4 sliding_window must be a multiple of 64 for FlashMLA, got {window}"
+ T = positions.shape[0]
+ # outputs are allocated uninitialized; the kernel fills every element when T > 0
+ swa_index = torch.empty((T, window), dtype=torch.int32, device=positions.device)
+ swa_length = torch.empty((T,), dtype=torch.int32, device=positions.device)
+ # nothing to launch for an empty batch
+ if T == 0:
+ return swa_index, swa_length
+ # one program per token; BLOCK_W is the window rounded up to a power of two
+ _build_swa_index_kernel[(T,)](
+ req_idx,
+ positions,
+ req_to_token_indexs,
+ req_to_token_indexs.stride(0),
+ full_to_swa_indexs,
+ swa_index,
+ swa_length,
+ WINDOW=window,
+ BLOCK_W=triton.next_power_of_2(window),
+ num_warps=4,
+ )
+ return swa_index, swa_length
diff --git a/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py
new file mode 100644
index 0000000000..3510b92c30
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py
@@ -0,0 +1,92 @@
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel_destindex_copy_indexer_k_dsv4(
+ K,
+ Dest_loc,
+ O_fp8,
+ O_f32,
+ stride_k_bs,
+ stride_k_d,
+ FP8_MIN: tl.constexpr,
+ FP8_MAX: tl.constexpr,
+ SCALE_MIN: tl.constexpr,
+ HEAD_DIM: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ BYTES_PER_PAGE: tl.constexpr,
+):
+ # Quantize one row of K to fp8 with a single per-row fp32 scale and write both into the page
+ # slab. O_fp8 and O_f32 are two pointers into the same buffer (the host wrapper passes an fp8
+ # view and an fp32 view of one flat uint8 slab), so offsets into O_fp8 are byte offsets and
+ # offsets into O_f32 are byte offsets divided by 4.
+ # One program instance per row: program_id(0) indexes both K and Dest_loc.
+ cur_index = tl.program_id(0)
+ dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
+ # negative dest (unmapped slot) is a no-op, not an OOB write into a neighboring page.
+ if dest_index < 0:
+ return
+
+ # Split the flat destination slot into (page, position inside that page).
+ page = dest_index // PAGE_SIZE
+ token_in_page = dest_index % PAGE_SIZE
+
+ # Load the full HEAD_DIM-wide row as fp32 and take its absolute maximum for the scale.
+ offs_d = tl.arange(0, HEAD_DIM)
+ vals = tl.load(K + cur_index * stride_k_bs + offs_d * stride_k_d).to(tl.float32)
+ amax = tl.max(tl.abs(vals), axis=0)
+ # per-token plain fp32 scale (not ue8m0), matching DeepseekV4MemoryManager._pack_indexer_k
+ # SCALE_MIN floors the scale so an all-zero row does not divide by zero.
+ scale = tl.maximum(amax / FP8_MAX, SCALE_MIN)
+ k_fp8 = tl.clamp(vals / scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)
+
+ # Data region: token rows of HEAD_DIM fp8 bytes are packed back-to-back from the page start.
+ data_base = page * BYTES_PER_PAGE + token_in_page * HEAD_DIM
+ tl.store(O_fp8 + data_base + offs_d, k_fp8)
+ # Scale region: begins PAGE_SIZE * HEAD_DIM bytes into the page; // 4 converts that byte
+ # offset into an element index of the fp32 view, then one fp32 per token follows.
+ scale_idx = (page * BYTES_PER_PAGE + PAGE_SIZE * HEAD_DIM) // 4 + token_in_page
+ tl.store(O_f32 + scale_idx, scale)
+ return
+
+
+@torch.no_grad()
+def destindex_copy_indexer_k_dsv4(
+    K: torch.Tensor,
+    DestLoc: torch.Tensor,
+    O_buffer: torch.Tensor,
+    page_size: int,
+):
+    """Quantize indexer keys to fp8 and scatter them into one layer's packed page slab.
+
+    Args:
+        K: bf16 indexer keys; reshaped to ``[T, 128]`` before the launch.
+        DestLoc: ``[T]`` integer destination token slots. The kernel skips negative
+            slots; every other slot must already be a valid slot of ``O_buffer``.
+        O_buffer: contiguous uint8 slab whose last dimension is the byte size of one
+            page. Each page holds ``page_size`` rows of 128 fp8 bytes followed by
+            ``page_size`` fp32 scales, so the page size in bytes must be a multiple
+            of 4 and at least ``page_size * (128 + 4)``.
+        page_size: number of token slots per page.
+
+    Returns:
+        None. ``O_buffer`` is written in place; an empty ``DestLoc`` is a no-op.
+    """
+    seq_len = DestLoc.shape[0]
+    if seq_len == 0:
+        return None
+
+    head_dim = 128
+    scale_bytes = 4
+
+    K = K.reshape(-1, head_dim)
+    assert K.shape[0] == seq_len, f"Expected K shape[0]={seq_len}, got {K.shape[0]}"
+    assert K.dtype == torch.bfloat16, f"Expected bf16 indexer K, got {K.dtype}"
+
+    bytes_per_page = O_buffer.shape[-1]
+    assert O_buffer.dtype == torch.uint8
+    assert O_buffer.is_contiguous()
+    # the fp32 scale tail is addressed through a float32 view, so pages must be 4-byte aligned
+    assert bytes_per_page % 4 == 0
+    assert bytes_per_page >= page_size * (head_dim + scale_bytes)
+
+    # the same flat slab is handed to the kernel twice, typed once as fp8 and once as fp32
+    slab = O_buffer.view(-1)
+    slab_as_fp8 = slab.view(torch.float8_e4m3fn)
+    slab_as_f32 = slab.view(torch.float32)
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+
+    grid = (seq_len,)
+    _fwd_kernel_destindex_copy_indexer_k_dsv4[grid](
+        K,
+        DestLoc,
+        slab_as_fp8,
+        slab_as_f32,
+        K.stride(0),
+        K.stride(1),
+        FP8_MIN=fp8_info.min,
+        FP8_MAX=fp8_info.max,
+        SCALE_MIN=1e-4,
+        HEAD_DIM=head_dim,
+        PAGE_SIZE=page_size,
+        BYTES_PER_PAGE=bytes_per_page,
+        num_warps=1,
+        num_stages=1,
+    )
+    return None
diff --git a/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py
new file mode 100644
index 0000000000..a3ec6ed8cf
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py
@@ -0,0 +1,121 @@
+import torch
+
+import triton
+import triton.language as tl
+from triton.language.extra import libdevice
+
+
+@triton.jit
+def _fwd_kernel_destindex_copy_kv_flashmla_dsv4(
+ KV,
+ Dest_loc,
+ O_fp8,
+ O_bf16,
+ O_u8,
+ stride_kv_bs,
+ stride_kv_d,
+ FP8_MIN: tl.constexpr,
+ FP8_MAX: tl.constexpr,
+ SCALE_MIN: tl.constexpr,
+ NOPE_DIM: tl.constexpr,
+ ROPE_DIM: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+ NUM_GROUPS: tl.constexpr,
+ SCALE_BYTES: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ BYTES_PER_PAGE: tl.constexpr,
+):
+ # Pack one KV row into the page slab: the first NOPE_DIM values are fp8-quantized in groups of
+ # GROUP_SIZE with one power-of-two scale per group, the next ROPE_DIM values are copied as-is.
+ # O_fp8 / O_bf16 / O_u8 are three pointers into the same buffer (the host wrapper passes fp8,
+ # bf16 and uint8 views of one flat slab), so O_fp8/O_u8 offsets are bytes and O_bf16 offsets
+ # are bytes // 2.
+ # One program instance per row: program_id(0) indexes both KV and Dest_loc.
+ cur_index = tl.program_id(0)
+ dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
+ # negative dest (unmapped slot, e.g. full_to_c* rows that never closed a group) is a no-op,
+ # not an OOB write into a neighboring page.
+ if dest_index < 0:
+ return
+
+ # Per-token data stride is NOPE_DIM fp8 bytes plus ROPE_DIM two-byte values; the scale tail
+ # starts after PAGE_SIZE such rows and holds SCALE_BYTES bytes per token.
+ page = dest_index // PAGE_SIZE
+ token_in_page = dest_index % PAGE_SIZE
+ data_base = page * BYTES_PER_PAGE + token_in_page * (NOPE_DIM + ROPE_DIM * 2)
+ scale_base = page * BYTES_PER_PAGE + PAGE_SIZE * (NOPE_DIM + ROPE_DIM * 2) + token_in_page * SCALE_BYTES
+
+ # nope: per-group ue8m0 quant. SCALE_BYTES(=NUM_GROUPS+1) lanes cover the exponent bytes
+ # plus the trailing zero pad byte in one store. libdevice.log2 (not tl.log2, which is the
+ # approx instruction) and the bit-packed 2**e keep this bit-exact with the torch oracle
+ # DeepseekV4MemoryManager._pack_mla_kv.
+ offs_g = tl.arange(0, SCALE_BYTES)
+ offs_e = tl.arange(0, GROUP_SIZE)
+ group_mask = offs_g < NUM_GROUPS
+ # [SCALE_BYTES, GROUP_SIZE] tile; lanes with offs_g >= NUM_GROUPS are masked and read as 0.
+ kv_ptrs = KV + cur_index * stride_kv_bs + (offs_g[:, None] * GROUP_SIZE + offs_e[None, :]) * stride_kv_d
+ vals = tl.load(kv_ptrs, mask=group_mask[:, None], other=0.0).to(tl.float32)
+ amax = tl.max(tl.abs(vals), axis=1)
+ # Round the scale up to the next power of two; SCALE_MIN keeps log2 away from zero.
+ scale_exp = tl.ceil(libdevice.log2(tl.maximum(amax / FP8_MAX, SCALE_MIN))).to(tl.int32)
+ # Build 2**scale_exp as fp32 by placing the biased exponent at bit 23 and bitcasting.
+ scale = ((scale_exp + 127) << 23).to(tl.float32, bitcast=True)
+ kv_fp8 = tl.clamp(vals / scale[:, None], min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)
+ tl.store(O_fp8 + data_base + offs_g[:, None] * GROUP_SIZE + offs_e[None, :], kv_fp8, mask=group_mask[:, None])
+ # Stored scale byte is the biased exponent; lanes outside the real groups write 0 (the pad).
+ scale_bytes = tl.where(group_mask, scale_exp + 127, 0).to(tl.uint8)
+ tl.store(O_u8 + scale_base + offs_g, scale_bytes)
+
+ # rope: bf16 passthrough into the data region right after the nope bytes
+ offs_r = tl.arange(0, ROPE_DIM)
+ rope = tl.load(KV + cur_index * stride_kv_bs + (NOPE_DIM + offs_r) * stride_kv_d)
+ # // 2 turns the byte offset into an element index of the two-byte view.
+ tl.store(O_bf16 + (data_base + NOPE_DIM) // 2 + offs_r, rope)
+ return
+
+
+@torch.no_grad()
+def destindex_copy_kv_flashmla_dsv4(
+    KV: torch.Tensor,
+    DestLoc: torch.Tensor,
+    O_buffer: torch.Tensor,
+    page_size: int,
+):
+    """Pack bf16 latent KV rows into one layer's fp8 page slab.
+
+    Each row is 512 values wide: the first 448 are quantized to fp8 (e4m3) in seven
+    groups of 64 with one power-of-two exponent byte per group, and the last 64 are
+    stored unchanged as bf16. A page keeps all token data rows first, then a tail of
+    8 scale bytes per token (7 exponents + 1 zero pad).
+
+    Args:
+        KV: bf16 tensor, reshaped to ``[T, 512]`` before the launch.
+        DestLoc: ``[T]`` integer destination token slots (page = slot // page_size).
+            The kernel skips negative slots; other slots must already be valid for
+            ``O_buffer``.
+        O_buffer: contiguous uint8 slab whose last dimension is the byte size of one
+            page; it must be even and at least ``page_size * (448 + 128 + 8)``.
+        page_size: number of token slots per page.
+
+    Returns:
+        None. ``O_buffer`` is written in place; an empty ``DestLoc`` is a no-op.
+    """
+    seq_len = DestLoc.shape[0]
+    if seq_len == 0:
+        return None
+
+    nope_dim = 448
+    rope_dim = 64
+    group_size = 64
+    num_groups = nope_dim // group_size
+    head_dim = nope_dim + rope_dim
+    # one exponent byte per group plus a single pad byte
+    scale_bytes = num_groups + 1
+
+    KV = KV.reshape(-1, head_dim)
+    assert KV.shape[0] == seq_len, f"Expected KV shape[0]={seq_len}, got {KV.shape[0]}"
+    assert KV.dtype == torch.bfloat16, f"Expected bf16 KV (rope bytes are stored as-is), got {KV.dtype}"
+
+    bytes_per_page = O_buffer.shape[-1]
+    assert O_buffer.dtype == torch.uint8
+    assert O_buffer.is_contiguous()
+    # the rope part is written through a bf16 view, so pages must be 2-byte aligned
+    assert bytes_per_page % 2 == 0
+    assert bytes_per_page >= page_size * (nope_dim + rope_dim * 2 + scale_bytes)
+
+    # one flat slab, exposed to the kernel under three element types
+    slab = O_buffer.view(-1)
+    slab_as_fp8 = slab.view(torch.float8_e4m3fn)
+    slab_as_bf16 = slab.view(torch.bfloat16)
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+
+    grid = (seq_len,)
+    _fwd_kernel_destindex_copy_kv_flashmla_dsv4[grid](
+        KV,
+        DestLoc,
+        slab_as_fp8,
+        slab_as_bf16,
+        slab,
+        KV.stride(0),
+        KV.stride(1),
+        FP8_MIN=fp8_info.min,
+        FP8_MAX=fp8_info.max,
+        SCALE_MIN=1e-4,
+        NOPE_DIM=nope_dim,
+        ROPE_DIM=rope_dim,
+        GROUP_SIZE=group_size,
+        NUM_GROUPS=num_groups,
+        SCALE_BYTES=scale_bytes,
+        PAGE_SIZE=page_size,
+        BYTES_PER_PAGE=bytes_per_page,
+        num_warps=4,
+        num_stages=1,
+    )
+    return None
diff --git a/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py
new file mode 100644
index 0000000000..a7a0a4be85
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py
@@ -0,0 +1,200 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _gather_c4_indexer_k_kernel(
+ req_idx_ptr, # [batch] int — req_manager slot per batch position
+ c4_len_ptr, # [batch] int — number of causal c4 entries per request (= seq_len // ratio)
+ req_to_token_ptr,
+ req_to_token_stride0,
+ full_to_c4_ptr,
+ SlabFp8_ptr, # c4 indexer pool, viewed as fp8 (flat)
+ SlabF32_ptr, # same pool, viewed as f32 (flat)
+ Kout_fp8_ptr, # [batch*c4_cap, HEAD_DIM] fp8
+ Kout_scale_ptr, # [batch*c4_cap] f32
+ Slots_out_ptr, # [batch*c4_cap] int32 (compact->c4-slot map; -1 for padding)
+ c4_cap,
+ RATIO: tl.constexpr,
+ HEAD_DIM: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ BYTES_PER_PAGE: tl.constexpr,
+ SCALE_OFFSET: tl.constexpr, # page_size * head_dim (byte offset of the scale tail)
+):
+ # Copy one (request r, compressed entry e) key + scale from the paged slab into output row
+ # r * c4_cap + e. Entries at or beyond c4_len[r] only get their slot marked -1.
+ # entry index on grid-X (limit ~2^31), batch on grid-Y (<= running_max_req_size): c4_cap reaches
+ # 65536 at 256K context, which would blow the 65535 grid-Y cap if entries were the grid-Y axis.
+ e = tl.program_id(0)
+ r = tl.program_id(1)
+ # int64: out_pos*HEAD_DIM can exceed int32 at high batch + long context (read side is already int64).
+ out_pos = r.to(tl.int64) * c4_cap + e
+ c4_len = tl.load(c4_len_ptr + r)
+ if e >= c4_len:
+ # padding entry: mark slot invalid; K is never read (ke bounds the scorer range).
+ tl.store(Slots_out_ptr + out_pos, -1)
+ return
+
+ # group-end token of compressed entry e lives at position e*RATIO + (RATIO-1).
+ # Two-level lookup: (request, position) -> full slot -> c4 slot.
+ req = tl.load(req_idx_ptr + r).to(tl.int64)
+ end_tok = e * RATIO + (RATIO - 1)
+ full_slot = tl.load(req_to_token_ptr + req * req_to_token_stride0 + end_tok).to(tl.int64)
+ c4_slot = tl.load(full_to_c4_ptr + full_slot).to(tl.int64)
+ # A negative c4 slot is treated as "no entry": the slab loads below are masked off.
+ valid = c4_slot >= 0
+
+ # inline PackedPagePool byte addressing (matches destindex_copy_indexer_k_dsv4 / gather_indexer_k):
+ # fp8 K at page*bytes_per_page + tok*head_dim; fp32 scale at (page*bytes_per_page + scale_off)//4 + tok.
+ page = c4_slot // PAGE_SIZE
+ tok = c4_slot % PAGE_SIZE
+ data_base = page * BYTES_PER_PAGE + tok * HEAD_DIM
+ scale_base = (page * BYTES_PER_PAGE + SCALE_OFFSET) // 4 + tok
+
+ offs_d = tl.arange(0, HEAD_DIM)
+ k_fp8 = tl.load(SlabFp8_ptr + data_base + offs_d, mask=valid, other=0.0)
+ k_scale = tl.load(SlabF32_ptr + scale_base, mask=valid, other=0.0)
+ # The output row is always written; an invalid slot yields zero key/scale and slot -1.
+ tl.store(Kout_fp8_ptr + out_pos * HEAD_DIM + offs_d, k_fp8)
+ tl.store(Kout_scale_ptr + out_pos, k_scale)
+ tl.store(Slots_out_ptr + out_pos, tl.where(valid, c4_slot, -1).to(tl.int32))
+
+
+@torch.no_grad()
+def gather_c4_indexer_k_ragged(
+    mem_manager,
+    layer_index: int,
+    b_req_idx: torch.Tensor,
+    c4_len: torch.Tensor,
+    c4_cap: int,
+    req_to_token_indexs: torch.Tensor,
+):
+    """Collect every request's c4 indexer keys into a fixed-stride, per-request padded buffer.
+
+    For batch row ``r`` and compressed entry ``e < c4_len[r]`` the kernel resolves
+        c4_slot = full_to_c4[req_to_token[b_req_idx[r], e * 4 + 3]]
+    and copies that slot's fp8 key and fp32 scale from the layer's packed page slab to
+    output row ``r * c4_cap + e``. Rows with ``e >= c4_len[r]`` are padding: only their
+    slot entry is written (as -1), their key/scale contents are left uninitialized.
+
+    Args:
+        mem_manager: provides ``c4_indexer_pool``, ``indexer_head_dim``,
+            ``layer_to_c4_idx`` and ``full_to_c4_indexs``.
+        layer_index: model layer index, translated through ``layer_to_c4_idx``.
+        b_req_idx: ``[batch]`` request slots.
+        c4_len: ``[batch]`` number of valid compressed entries per request.
+        c4_cap: per-request row capacity of the output (the row stride between requests).
+        req_to_token_indexs: request -> full token slot table (row stride taken from
+            ``stride(0)``).
+
+    Returns:
+        ``(k_fp8, k_scale, slots)`` with shapes ``[batch * c4_cap, head_dim]`` (fp8),
+        ``[batch * c4_cap]`` (float32) and ``[batch * c4_cap]`` (int32, -1 where there is
+        no c4 slot). Shapes depend only on ``batch`` and ``c4_cap``.
+    """
+    indexer_pool = mem_manager.c4_indexer_pool
+    head_dim = mem_manager.indexer_head_dim
+    c4_layer = mem_manager.layer_to_c4_idx[layer_index]
+
+    # flatten the layer slab and expose it under the two element types the kernel reads
+    slab = indexer_pool.get_layer_buffer(c4_layer).view(-1)
+    slab_as_fp8 = slab.view(torch.float8_e4m3fn)
+    slab_as_f32 = slab.view(torch.float32)
+
+    batch = b_req_idx.shape[0]
+    total_rows = batch * c4_cap
+    device = slab.device
+    k_fp8 = torch.empty((total_rows, head_dim), dtype=torch.float8_e4m3fn, device=device)
+    k_scale = torch.empty((total_rows,), dtype=torch.float32, device=device)
+    slots = torch.empty((total_rows,), dtype=torch.int32, device=device)
+
+    # entries on grid axis 0, batch rows on grid axis 1
+    grid = (c4_cap, batch)
+    _gather_c4_indexer_k_kernel[grid](
+        b_req_idx,
+        c4_len,
+        req_to_token_indexs,
+        req_to_token_indexs.stride(0),
+        mem_manager.full_to_c4_indexs,
+        slab_as_fp8,
+        slab_as_f32,
+        k_fp8,
+        k_scale,
+        slots,
+        c4_cap,
+        RATIO=4,
+        HEAD_DIM=head_dim,
+        PAGE_SIZE=indexer_pool.page_size,
+        BYTES_PER_PAGE=indexer_pool.bytes_per_page,
+        SCALE_OFFSET=indexer_pool.scale_offset_in_page,
+        num_warps=1,
+    )
+    return k_fp8, k_scale, slots
+
+
+@triton.jit
+def _build_c4_indexer_page_table_kernel(
+ req_idx_ptr, # [batch] int
+ c4_len_ptr, # [batch] int
+ req_to_token_ptr,
+ req_to_token_stride0,
+ full_to_c4_ptr,
+ page_table_ptr, # [batch, page_cap] int32
+ valid_flag_ptr, # [1] int32, initialized to 1; set to 0 on layout mismatch
+ page_cap,
+ hold_req_id,
+ RATIO: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ VALIDATE: tl.constexpr,
+):
+ # Fill page_table[r, p] with the physical page that logical page p of request r lives in,
+ # derived from the c4 slot of that logical page's first entry.
+ # p = logical page index (grid axis 0), r = batch row (grid axis 1).
+ p = tl.program_id(0)
+ r = tl.program_id(1)
+ req = tl.load(req_idx_ptr + r).to(tl.int64)
+ c4_len = tl.load(c4_len_ptr + r).to(tl.int64)
+ page_start = p * PAGE_SIZE
+ # Active = not the hold request and the page's first entry is within c4_len.
+ active = (req != hold_req_id) & (page_start < c4_len)
+
+ # Position of the first entry's group-end token: entry * RATIO + (RATIO - 1).
+ full_pos0 = page_start * RATIO + (RATIO - 1)
+ full_slot0 = tl.load(
+ req_to_token_ptr + req * req_to_token_stride0 + full_pos0,
+ mask=active,
+ other=0,
+ ).to(tl.int64)
+ c4_slot0 = tl.load(full_to_c4_ptr + full_slot0, mask=active, other=0).to(tl.int64)
+ phys_page = c4_slot0 // PAGE_SIZE
+ # Inactive pages are written as 0 so the whole table is always initialized.
+ tl.store(page_table_ptr + r * page_cap + p, tl.where(active, phys_page, 0).to(tl.int32))
+
+ if VALIDATE:
+ # Check that every in-range entry o of this logical page sits at phys_page * PAGE_SIZE + o;
+ # any mismatch or negative slot clears the single flag shared by all programs.
+ offs = tl.arange(0, PAGE_SIZE)
+ e = page_start + offs
+ valid = active & (e < c4_len)
+ full_pos = e * RATIO + (RATIO - 1)
+ full_slot = tl.load(
+ req_to_token_ptr + req * req_to_token_stride0 + full_pos,
+ mask=valid,
+ other=0,
+ ).to(tl.int64)
+ c4_slot = tl.load(full_to_c4_ptr + full_slot, mask=valid, other=-1).to(tl.int64)
+ expected = phys_page * PAGE_SIZE + offs
+ # Lanes outside the valid range count as OK so they cannot trip the flag.
+ ok = tl.where(valid, (c4_slot == expected) & (c4_slot >= 0), True)
+ if tl.min(ok.to(tl.int32), axis=0) == 0:
+ tl.store(valid_flag_ptr, 0)
+
+
+@torch.no_grad()
+def build_c4_indexer_page_table(
+    mem_manager,
+    b_req_idx: torch.Tensor,
+    c4_len: torch.Tensor,
+    c4_cap: int,
+    req_to_token_indexs: torch.Tensor,
+    hold_req_id: int,
+    validate: bool = False,
+):
+    """Build the per-request logical-page -> physical-page table for the c4 indexer pool.
+
+    The table is only meaningful when each logical page is stored in one physical page
+    at matching offsets, i.e. ``c4_slot(entry p * page_size + o) == page_table[p] *
+    page_size + o``. With ``validate=True`` the kernel checks that invariant and clears
+    the returned flag on any mismatch, so the caller can fall back to another path.
+
+    Args:
+        mem_manager: provides ``c4_indexer_pool`` (for ``page_size``) and
+            ``full_to_c4_indexs``.
+        b_req_idx: ``[batch]`` request slots.
+        c4_len: ``[batch]`` number of valid compressed entries per request.
+        c4_cap: per-request entry capacity; must be a multiple of the pool page size.
+        req_to_token_indexs: request -> full token slot table.
+        hold_req_id: request id whose rows are treated as inactive (pages written as 0).
+        validate: enable the in-kernel layout check.
+
+    Returns:
+        ``(page_table, valid_flag)``: an int32 ``[batch, c4_cap // page_size]`` table and
+        an int32 ``[1]`` flag that starts at 1 and is set to 0 by a failed validation.
+    """
+    page_size = mem_manager.c4_indexer_pool.page_size
+    page_cap, leftover = divmod(c4_cap, page_size)
+    assert leftover == 0
+
+    batch = b_req_idx.shape[0]
+    device = b_req_idx.device
+    page_table = torch.empty((batch, page_cap), dtype=torch.int32, device=device)
+    valid_flag = torch.ones((1,), dtype=torch.int32, device=device)
+
+    # logical pages on grid axis 0, batch rows on grid axis 1
+    grid = (page_cap, batch)
+    _build_c4_indexer_page_table_kernel[grid](
+        b_req_idx,
+        c4_len,
+        req_to_token_indexs,
+        req_to_token_indexs.stride(0),
+        mem_manager.full_to_c4_indexs,
+        page_table,
+        valid_flag,
+        page_cap,
+        int(hold_req_id),
+        RATIO=4,
+        PAGE_SIZE=page_size,
+        VALIDATE=validate,
+        num_warps=1,
+    )
+    return page_table, valid_flag
diff --git a/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py b/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py
new file mode 100644
index 0000000000..47d87d4932
--- /dev/null
+++ b/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py
@@ -0,0 +1,16 @@
+import torch
+
+
+def e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor:
+    """Return ``scale`` converted to float32.
+
+    Intended for float8_e8m0fnu scales (each byte encodes 2**(byte-127)); the dtype
+    conversion itself performs the decoding.
+    """
+    as_fp32 = scale.to(dtype=torch.float32)
+    return as_fp32
+
+
+def dequant_fp8_block_to_bf16(weight_e4m3: torch.Tensor, scale_e8m0: torch.Tensor, block_size: int = 128):
+    """De-quantize an FP8 e4m3 weight [out, in] with block-[bs,bs] ue8m0 scale to bf16.
+
+    Args:
+        weight_e4m3: fp8 weight; moved to CUDA and made contiguous before dequantization.
+        scale_e8m0: block scales; decoded to float32, moved to CUDA and made contiguous.
+        block_size: edge length of the square scaling block.
+
+    Returns:
+        The de-quantized weight as a ``torch.bfloat16`` CUDA tensor.
+    """
+    from lightllm.models.deepseek2.triton_kernel.weight_dequant import weight_dequant
+
+    w = weight_e4m3.cuda().contiguous()
+    s = e8m0_to_fp32(scale_e8m0).cuda().contiguous()
+    # weight_dequant runs with torch default dtype for the output, so the result is only bf16
+    # when the process default happens to be bf16. Cast explicitly to honor the bf16 contract;
+    # the cast is a no-op when the output is already bf16.
+    return weight_dequant(w, s, block_size).to(torch.bfloat16)
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 1bdf8f3427..e745492173 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -161,6 +161,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
"qwen",
"deepseekv31",
"deepseekv32",
+ "deepseekv4",
"glm47",
"kimi_k2",
"qwen3_coder",
@@ -620,8 +621,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
type=str,
default=None,
choices=["fp8", "fp4"],
- help="""Expert quantization dtype for EP MoE. Supported values are
- fp8 and fp4. Note that fp4 is only supported on SM100 GPUs.""",
+ help="""Requested dtype for MoE expert weights, fp8 or fp4. Resolves the fused_moe
+ quant method: fp8 -> deepgemm-fp8w8a8-b128; fp4 -> deepgemm-fp4fp8-b32 (online
+ quantization) on SM100 GPUs, or marlin-mxfp4w4a16-b32 (Marlin W4A16, TP only) on other GPUs.
+ Defaults to `expert_dtype` in config.json if present. Per-layer override:
+ --quant_cfg mix_bits with name `fused_moe`.""",
)
parser.add_argument(
"--vit_quant_type",
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 3cf431d650..c13f562af9 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -10,7 +10,11 @@
from .metrics.manager import start_metric_manager
from .embed_cache.manager import start_cache_manager
from lightllm.utils.log_utils import init_logger
-from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name
+from lightllm.utils.envs_utils import (
+ set_env_start_args,
+ set_unique_server_name,
+ get_unique_server_name,
+)
from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive
from .detokenization.manager import start_detokenization_process
from .router.manager import start_router_process
@@ -23,6 +27,8 @@
has_vision_module,
is_linear_att_mixed_model,
auto_set_max_req_total_len,
+ get_model_type,
+ get_config_json,
)
from lightllm.utils.dist_check_utils import auto_configure_allreduce_flags_from_args
@@ -108,6 +114,19 @@ def normal_or_p_d_start(args):
else:
args.enable_multimodal = True
+ model_type = get_model_type(args.model_dir)
+ if model_type == "deepseek_v4":
+ if args.run_mode != "normal":
+ raise NotImplementedError("DeepSeek-V4 currently supports only run_mode=normal in LightLLM.")
+ if args.enable_cpu_cache or args.enable_disk_cache:
+ raise NotImplementedError("DeepSeek-V4 CPU/disk KV cache is not supported yet.")
+ if args.mtp_mode is not None or args.mtp_draft_model_dir is not None or args.mtp_step != 0:
+ raise NotImplementedError("DeepSeek-V4 MTP/speculative decoding is not supported yet.")
+ if args.enable_ep_moe:
+ raise NotImplementedError("DeepSeek-V4 EP MoE is not supported yet; use TP for now.")
+ if "prompt_cache_kv_buffer" in get_config_json(args.model_dir):
+ raise NotImplementedError("DeepSeek-V4 prompt_cache_kv_buffer is not supported yet.")
+
if args.enable_cpu_cache:
# 生成一个用于创建cpu kv cache的共享内存id。
args.cpu_kv_cache_shm_id = uuid.uuid1().int % 123456789
@@ -333,7 +352,14 @@ def normal_or_p_d_start(args):
from lightllm.utils.config_utils import get_dtype
args.data_type = get_dtype(args.model_dir)
- assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]
+ assert args.data_type in [
+ "fp16",
+ "float16",
+ "bf16",
+ "bfloat16",
+ "fp32",
+ "float32",
+ ]
already_uesd_ports = [args.port]
if args.nccl_port is not None:
@@ -425,7 +451,6 @@ def normal_or_p_d_start(args):
)
if not args.disable_vision:
-
if not args.visual_use_proxy_mode:
from .visualserver.manager import start_visual_process
@@ -609,7 +634,14 @@ def visual_only_start(args):
from lightllm.utils.config_utils import get_dtype
args.data_type = get_dtype(args.model_dir)
- assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]
+ assert args.data_type in [
+ "fp16",
+ "float16",
+ "bf16",
+ "bfloat16",
+ "fp32",
+ "float32",
+ ]
logger.info(f"alloced ports: {can_use_ports}")
diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py
index 54d22a0d0d..0565a8f0cf 100644
--- a/lightllm/server/build_prompt.py
+++ b/lightllm/server/build_prompt.py
@@ -53,6 +53,13 @@ def tokenizer_supports_force_thinking() -> bool:
assert tokenizer is not None
+ # Tokenizers that encode prompts in Python (e.g. DeepSeek-V4) have no Jinja
+ # chat_template string to inspect, so advertise thinking support via an
+ # explicit attribute instead.
+ if getattr(tokenizer, "supports_thinking", False):
+ logger.info("tokenizer_supports_force_thinking : True (explicit attribute)")
+ return True
+
try:
ans = "thinking" in tokenizer.chat_template or "enable_thinking" in tokenizer.chat_template
logger.debug(f"chat_template: {tokenizer.chat_template}")
diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py
index dfcb2f8d9e..63c9f6ac8f 100644
--- a/lightllm/server/function_call_parser.py
+++ b/lightllm/server/function_call_parser.py
@@ -40,6 +40,7 @@
"[TOOL_CALLS]",
"<|tool▁calls▁begin|>",
"<|DSML|function_calls>",
+ "<|DSML|tool_calls>",
]
@@ -1480,11 +1481,14 @@ class DeepSeekV32Detector(BaseFormatDetector):
Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3.2
"""
- def __init__(self):
+ def __init__(self, block_name: str = "function_calls"):
super().__init__()
self.dsml_token = "|DSML|"
- self.bot_token = f"<{self.dsml_token}function_calls>"
- self.eot_token = f"{self.dsml_token}function_calls>"
+ # DeepSeek V3.2 wraps tool calls in a `function_calls` block; V4 uses
+ # `tool_calls`. Only the outer block name differs — the invoke/parameter
+ # grammar is identical — so subclasses just override block_name.
+ self.bot_token = f"<{self.dsml_token}{block_name}>"
+ self.eot_token = f"{self.dsml_token}{block_name}>"
self.invoke_start_prefix = f"<{self.dsml_token}invoke"
self.invoke_end_token = f"{self.dsml_token}invoke>"
self.param_end_token = f"{self.dsml_token}parameter>"
@@ -1962,6 +1966,32 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip()
+class DeepSeekV4Detector(DeepSeekV32Detector):
+    """
+    DSML tool-call detector for DeepSeek V4.
+
+    Reuses the V3.2 detector as-is and only switches the outer block name from
+    ``function_calls`` to ``tool_calls``; the begin/end tokens are derived from that
+    name by the parent constructor, while the ``invoke`` / ``parameter`` grammar is
+    inherited unchanged. The name matches the model's own encoding
+    (encoding_dsv4.py: ``tool_calls_block_name = "tool_calls"``) and system prompt.
+
+    Format Structure:
+    ```
+    <|DSML|tool_calls>
+    <|DSML|invoke name="get_weather">
+    <|DSML|parameter name="location" string="true">Hangzhou|DSML|parameter>
+    |DSML|invoke>
+    |DSML|tool_calls>
+    ```
+
+    Reference: https://huggingface.co/deepseek-ai/DeepSeek-V4
+    """
+
+    def __init__(self):
+        # the only V4-specific setting: the name of the outer DSML block
+        super().__init__("tool_calls")
+
+
class FunctionCallParser:
"""
Parser for function/tool calls in model outputs.
@@ -1975,6 +2005,7 @@ class FunctionCallParser:
"deepseekv3": DeepSeekV3Detector,
"deepseekv31": DeepSeekV31Detector,
"deepseekv32": DeepSeekV32Detector,
+ "deepseekv4": DeepSeekV4Detector,
"glm47": Glm47Detector,
"kimi_k2": KimiK2Detector,
"llama3": Llama32Detector,
diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index 21e26c5854..ff09c018be 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -2,7 +2,7 @@
import torch
import numpy as np
import collections
-from typing import Tuple, Dict, Set, List, Optional, Union
+from typing import Any, Tuple, Dict, Set, List, Optional, Union
from sortedcontainers import SortedSet
from .shared_arr import SharedArray
@@ -25,6 +25,7 @@ def __init__(self):
self.parent: TreeNode = None
self.token_id_key: torch.Tensor = None
self.token_mem_index_value: torch.Tensor = None # 用于记录存储的 token_index 为每个元素在 token mem 中的index位置
+ self.token_extra_value: Any = None
self.ref_counter = 0
self.time_id = time_gen.generate_time_id() # 用于标识时间周期
@@ -34,14 +35,17 @@ def __init__(self):
def get_compare_key(self):
return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id)
- def split_node(self, prefix_len):
+ def split_node(self, prefix_len, child_key_fn=None, extra_value_ops=None):
split_parent_node = TreeNode()
split_parent_node.parent = self.parent
- split_parent_node.parent.children[self.token_id_key[0].item()] = split_parent_node
+ split_parent_node.parent.children[child_key_fn(self.token_id_key)] = split_parent_node
split_parent_node.token_id_key = self.token_id_key[0:prefix_len]
split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len]
+ if self.token_extra_value is not None and extra_value_ops is not None:
+ split_parent_node.token_extra_value = extra_value_ops.slice(self.token_extra_value, 0, prefix_len)
+ self.token_extra_value = extra_value_ops.slice(self.token_extra_value, prefix_len, len(self.token_id_key))
split_parent_node.children = {}
- split_parent_node.children[self.token_id_key[prefix_len].item()] = self
+ split_parent_node.children[child_key_fn(self.token_id_key[prefix_len:])] = self
split_parent_node.ref_counter = self.ref_counter
new_len = len(split_parent_node.token_mem_index_value)
@@ -56,11 +60,12 @@ def split_node(self, prefix_len):
self.node_prefix_total_len = self.parent.node_prefix_total_len + new_len
return split_parent_node
- def add_and_return_new_child(self, token_id_key, token_mem_index_value):
+ def add_and_return_new_child(self, token_id_key, token_mem_index_value, token_extra_value=None, child_key=None):
child = TreeNode()
child.token_id_key = token_id_key
child.token_mem_index_value = token_mem_index_value
- first_token_key = child.token_id_key[0].item()
+ child.token_extra_value = token_extra_value
+ first_token_key = child.token_id_key[0].item() if child_key is None else child_key
assert first_token_key not in self.children.keys()
self.children[first_token_key] = child
child.parent = self
@@ -71,9 +76,17 @@ def add_and_return_new_child(self, token_id_key, token_mem_index_value):
return child
def remove_child(self, child_node: "TreeNode"):
- del self.children[child_node.token_id_key[0].item()]
- child_node.parent = None
- return
+ child_key = child_node.token_id_key[0].item()
+ if child_key in self.children:
+ del self.children[child_key]
+ child_node.parent = None
+ return
+ for key, value in list(self.children.items()):
+ if value is child_node:
+ del self.children[key]
+ child_node.parent = None
+ return
+ raise KeyError("child node not found")
def update_time(self):
self.time_id = time_gen.generate_time_id()
@@ -103,12 +116,22 @@ class RadixCache:
unique_name 主要用于解决单机,多实列部署时的shm冲突
"""
- def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None):
+ def __init__(
+ self,
+ unique_name,
+ total_token_num,
+ rank_in_node,
+ mem_manager=None,
+ page_size: int = 1,
+ extra_value_ops=None,
+ ):
from lightllm.common.kv_cache_mem_manager import MemoryManager
self.mem_manager: MemoryManager = mem_manager
self._key_dtype = torch.int64
self._value_dtype = torch.int64
+ self.page_size = max(1, int(page_size))
+ self.extra_value_ops = extra_value_ops
self.root_node = TreeNode()
self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype)
@@ -125,30 +148,66 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None)
)
self.tree_total_tokens_num.arr[0] = 0
- def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]:
+ def _align_len(self, length: int) -> int:
+     """Round ``length`` down to a multiple of ``self.page_size`` (unchanged when page_size <= 1)."""
+     n = int(length)
+     if self.page_size > 1:
+         # drop the partial trailing page
+         n -= n % self.page_size
+     return n
+
+ def align_len(self, length: int) -> int:
+     """Public wrapper around ``_align_len``: ``length`` rounded down to a whole number of pages."""
+     aligned = self._align_len(length)
+     return aligned
+
+ def _child_key(self, key: torch.Tensor):
+     """Return the dict key under which a child whose token key starts with ``key`` is stored.
+
+     With paging (page_size > 1) this is the tuple of the first ``page_size`` token ids;
+     otherwise it is the first token id as a Python scalar.
+     """
+     if self.page_size > 1:
+         first_page = key[: self.page_size]
+         return tuple(first_page.tolist())
+     return key[0].item()
+
+ def _match_len(self, key: torch.Tensor, node_key: torch.Tensor) -> int:
+     """Return the page-aligned length of the match between ``key`` and ``node_key``."""
+     # raw match length from the module-level ``match`` helper, then rounded down to whole pages
+     return self._align_len(match(key, node_key))
+
+ def _slice_extra(self, extra_value, start: int, end: int):
+     """Slice ``extra_value`` to ``[start, end)`` through ``extra_value_ops``; ``None`` passes through."""
+     if extra_value is not None:
+         # a non-None extra value can only be handled when slicing ops were configured
+         assert self.extra_value_ops is not None
+         return self.extra_value_ops.slice(extra_value, start, end)
+     return None
+
+ def _concat_extra(self, values: list):
+     """Concatenate the non-None entries of ``values`` through ``extra_value_ops``.
+
+     Returns ``None`` when no entry is set.
+     """
+     present = list(filter(lambda item: item is not None, values))
+     if not present:
+         return None
+     assert self.extra_value_ops is not None
+     return self.extra_value_ops.concat(present)
+
+ def insert(self, key, value=None, extra_value=None) -> Tuple[int, Optional[TreeNode]]:
if value is None:
value = key
+ align_len = self._align_len(len(key))
+ key = key[:align_len]
+ value = value[:align_len]
+ if extra_value is not None:
+ extra_value = self._slice_extra(extra_value, 0, align_len)
+
assert len(key) == len(value) # and len(key) >= 1
if len(key) == 0:
return 0, None
- return self._insert_helper(self.root_node, key, value)
+ return self._insert_helper(self.root_node, key, value, extra_value)
- def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]:
+ def _insert_helper(self, node: TreeNode, key, value, extra_value) -> Tuple[int, Optional[TreeNode]]:
handle_stack = collections.deque()
update_list = collections.deque()
- handle_stack.append((node, key, value))
+ handle_stack.append((node, key, value, extra_value))
ans_prefix_len = 0
ans_node = None
while len(handle_stack) != 0:
- node, key, value = handle_stack.popleft()
- ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value)
- if len(ans_tuple) == 4:
- (_prefix_len, new_node, new_key, new_value) = ans_tuple
+ node, key, value, extra_value = handle_stack.popleft()
+ ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value, extra_value=extra_value)
+ if len(ans_tuple) == 5:
+ (_prefix_len, new_node, new_key, new_value, new_extra_value) = ans_tuple
ans_prefix_len += _prefix_len
- handle_stack.append((new_node, new_key, new_value))
+ handle_stack.append((new_node, new_key, new_value, new_extra_value))
else:
_prefix_len, ans_node = ans_tuple
ans_prefix_len += _prefix_len
@@ -166,15 +225,15 @@ def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[Tree
return ans_prefix_len, ans_node
def _insert_helper_no_recursion(
- self, node: TreeNode, key: torch.Tensor, value: torch.Tensor
- ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor]]:
+ self, node: TreeNode, key: torch.Tensor, value: torch.Tensor, extra_value=None
+ ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor, Any]]:
if node.is_leaf():
self.evict_tree_set.discard(node)
- first_key_id = key[0].item()
+ first_key_id = self._child_key(key)
if first_key_id in node.children.keys():
child: TreeNode = node.children[first_key_id]
- prefix_len = match(key, child.token_id_key)
+ prefix_len = self._match_len(key, child.token_id_key)
if prefix_len == len(key):
if prefix_len == len(child.token_id_key):
if child.is_leaf():
@@ -184,10 +243,14 @@ def _insert_helper_no_recursion(
self.evict_tree_set.add(child)
return prefix_len, child
elif prefix_len < len(child.token_id_key):
+ if prefix_len == 0:
+ return 0, node
if child.is_leaf():
self.evict_tree_set.discard(child)
- split_parent_node = child.split_node(prefix_len)
+ split_parent_node = child.split_node(
+ prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops
+ )
if split_parent_node.is_leaf():
self.evict_tree_set.add(split_parent_node)
@@ -199,13 +262,23 @@ def _insert_helper_no_recursion(
assert False, "can not run to here"
elif prefix_len < len(key) and prefix_len < len(child.token_id_key):
+ if prefix_len == 0:
+ return 0, node
if child.is_leaf():
self.evict_tree_set.discard(child)
+ new_extra_value = self._slice_extra(extra_value, prefix_len, len(key))
key = key[prefix_len:]
value = value[prefix_len:]
- split_parent_node = child.split_node(prefix_len)
- new_node = split_parent_node.add_and_return_new_child(key, value)
+ split_parent_node = child.split_node(
+ prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops
+ )
+ new_node = split_parent_node.add_and_return_new_child(
+ key,
+ value,
+ token_extra_value=new_extra_value,
+ child_key=self._child_key(key),
+ )
# update total token num
self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
@@ -218,12 +291,23 @@ def _insert_helper_no_recursion(
self.evict_tree_set.add(child)
return prefix_len, new_node
elif prefix_len < len(key) and prefix_len == len(child.token_id_key):
- return (prefix_len, child, key[prefix_len:], value[prefix_len:])
+ return (
+ prefix_len,
+ child,
+ key[prefix_len:],
+ value[prefix_len:],
+ self._slice_extra(extra_value, prefix_len, len(key)),
+ )
else:
assert False, "can not run to here"
else:
- new_node = node.add_and_return_new_child(key, value)
+ new_node = node.add_and_return_new_child(
+ key,
+ value,
+ token_extra_value=extra_value,
+ child_key=first_key_id,
+ )
# update total token num
self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
if new_node.is_leaf():
@@ -231,7 +315,12 @@ def _insert_helper_no_recursion(
return 0, new_node
def match_prefix(self, key, update_refs=False):
- assert len(key) != 0
+ key = key[: self._align_len(len(key))]
+ if len(key) == 0:
+ return None, 0, None
+ key = self._trim_key_by_extra_value_validity(key)
+ if len(key) == 0:
+ return None, 0, None
ans_value_list = []
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
if tree_node != self.root_node:
@@ -245,6 +334,30 @@ def match_prefix(self, key, update_refs=False):
self.dec_node_ref_counter(self.root_node)
return None, 0, None
+ def _trim_key_by_extra_value_validity(self, key: torch.Tensor) -> torch.Tensor:
+ """Hit-validity trimming (enabled when extra_value_ops provides valid_match_length, e.g. the
+ DeepSeek-V4 per-page swa bitmap): a read-only probe traversal yields the natural hit and the path's
+ extra_value, key is cut at the valid boundary, and the normal traversal (ref updates / splits) then
+ walks only the cut prefix -- ref counts match the returned value; no post-hoc missed/extra decrement.
+
+ The probe traversal may split partially matched nodes (same semantics as a normal traversal; tree
+ invariants are unaffected). Trimming can only shorten the hit; it has no failure path."""
+ if self.extra_value_ops is None:
+ return key
+ valid_match_length = getattr(self.extra_value_ops, "valid_match_length", None)
+ if valid_match_length is None:
+ return key
+ probe_values = []
+ probe_node = self._match_prefix_helper(self.root_node, key, probe_values, update_refs=False)
+ if probe_node == self.root_node or len(probe_values) == 0:
+ return key
+ natural_len = sum(len(v) for v in probe_values)
+ extra_value = self.get_extra_value_by_node(probe_node)
+ valid_len = int(valid_match_length(extra_value, natural_len))
+ if valid_len < natural_len:
+ return key[:valid_len]
+ return key
+
def _match_prefix_helper(
self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False
) -> TreeNode:
@@ -290,20 +403,24 @@ def _match_prefix_helper_no_recursion(
if len(key) == 0:
return node
- first_key_id = key[0].item()
+ first_key_id = self._child_key(key)
if first_key_id not in node.children.keys():
return node
else:
child = node.children[first_key_id]
- prefix_len = match(key, child.token_id_key)
+ prefix_len = self._match_len(key, child.token_id_key)
if prefix_len == len(child.token_id_key):
ans_value_list.append(child.token_mem_index_value)
return (child, key[prefix_len:])
elif prefix_len < len(child.token_id_key):
+ if prefix_len == 0:
+ return node
if child.is_leaf():
self.evict_tree_set.discard(child)
- split_parent_node = child.split_node(prefix_len)
+ split_parent_node = child.split_node(
+ prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops
+ )
ans_value_list.append(split_parent_node.token_mem_index_value)
if update_refs:
@@ -334,6 +451,8 @@ def evict(self, need_remove_tokens, evict_callback):
), "error evict tree node state"
num_evicted += len(node.token_mem_index_value)
evict_callback(node.token_mem_index_value)
+ if self.extra_value_ops is not None and node.token_extra_value is not None:
+ self.extra_value_ops.free(node.token_extra_value)
# update total token num
self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value)
parent_node: TreeNode = node.parent
@@ -369,11 +488,12 @@ def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]:
child_node.token_mem_index_value = torch.cat(
[parent_node.token_mem_index_value, child_node.token_mem_index_value]
)
+ child_node.token_extra_value = self._concat_extra([parent_node.token_extra_value, child_node.token_extra_value])
child_node.node_value_len = len(child_node.token_mem_index_value)
child_node.time_id = max(parent_node.time_id, child_node.time_id)
grandparent_node = parent_node.parent
- key_in_grandparent = parent_node.token_id_key[0].item()
+ key_in_grandparent = self._child_key(parent_node.token_id_key)
grandparent_node.children[key_in_grandparent] = child_node
child_node.parent = grandparent_node
@@ -469,6 +589,19 @@ def get_mem_index_value_by_node(self, node: TreeNode) -> Optional[torch.Tensor]:
ans_list.reverse()
return torch.concat(ans_list, dim=0)
+ def get_extra_value_by_node(self, node: TreeNode):
+ if node is None or self.extra_value_ops is None:
+ return None
+
+ ans_list = []
+ while node is not None:
+ if node.token_extra_value is not None:
+ ans_list.append(node.token_extra_value)
+ node = node.parent
+
+ ans_list.reverse()
+ return self._concat_extra(ans_list)
+
def get_refed_tokens_num(self):
return self.refed_tokens_num.arr[0]
@@ -489,6 +622,36 @@ def _print_helper(self, node: TreeNode, indent):
self._print_helper(child, indent=indent + 2)
return
+ def reclaim_unreferenced_swa_pages(self, need_pages: int) -> None:
+ """DeepSeek-V4 swa pressure valve: when the page allocator bottoms out, walk the LRU order
+ (evict_tree_set) and reclaim swa pages only from ref_count==0 node chains (full slots and compressed
+ entries are kept -- the node can still serve mid-segment hits of longer prefixes), clearing payload
+ bitmap bits so later hits are trimmed to the shortened length. Ownership reuses the radix ref count: a
+ node borrowed by any active request has ref>0 and is untouched. If short, the allocator assert is the backstop."""
+ if self.mem_manager is None or self.extra_value_ops is None:
+ return
+ invalidate = getattr(self.extra_value_ops, "invalidate_swa_pages", None)
+ if invalidate is None:
+ return
+ allocator = self.mem_manager.swa_page_allocator
+ target = allocator.can_use_mem_size + int(need_pages)
+ for leaf in list(self.evict_tree_set):
+ if allocator.can_use_mem_size >= target:
+ break
+ node = leaf
+ # Reclaim from the leaf up the parent chain: ref counts accumulate upward (add_node_ref_counter walks
+ # the parent chain), so a ref==0 ancestor has no active borrower. Revisits are harmless (evict_swa/-1 skips).
+ # Re-check the target after each reclaimed node to avoid over-reclaiming (needlessly reducing hit availability).
+ while node is not None and node is not self.root_node and node.ref_counter == 0:
+ if len(node.token_mem_index_value) > 0:
+ self.mem_manager.evict_swa(node.token_mem_index_value)
+ if node.token_extra_value is not None:
+ invalidate(node.token_extra_value)
+ if allocator.can_use_mem_size >= target:
+ return
+ node = node.parent
+ return
+
def free_radix_cache_to_get_enough_token(self, need_token_num):
assert self.mem_manager is not None
if need_token_num > self.mem_manager.allocator.can_use_mem_size:
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 5c2d0d45fb..2bf2314185 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -8,7 +8,7 @@
from sortedcontainers import SortedDict
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Callable, Any, Union
-from lightllm.common.req_manager import ReqManager, ReqManagerForMamba
+from lightllm.common.req_manager import DeepseekV4ReqManager, ReqManager, ReqManagerForMamba
from lightllm.utils.infer_utils import mark_start, mark_end
from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode
@@ -50,7 +50,9 @@ def register(
vocab_size: int,
):
self.args = get_env_start_args()
- from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
+ from lightllm.server.router.model_infer.mode_backend.base_backend import (
+ ModeBackend,
+ )
self.backend: ModeBackend = backend
self.req_manager = req_manager
@@ -122,11 +124,20 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
return req_objs
def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
+ is_dsv4_req_manager = hasattr(self.req_manager, "build_prompt_cache_payload")
if self.radix_cache is None:
free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
+ if is_dsv4_req_manager:
+ # 槽位随 full 槽经 mem_manager.free 级联回收。pause 路径不释放 req_idx,
+ # 必须在此复位出窗水位线 + 清 c128 在途状态(恢复命中走 extend,不会再有
+ # restore/zero 时机;c4 状态随 swa 页生灭,无需处理)。
+ self.req_manager.init_compress_state(req.req_idx)
else:
if not self.is_linear_att_mixed_model:
- self._full_att_free_req(free_token_index=free_token_index, req=req)
+ if is_dsv4_req_manager:
+ self._dsv4_full_att_free_req(free_token_index=free_token_index, req=req)
+ else:
+ self._full_att_free_req(free_token_index=free_token_index, req=req)
else:
self._linear_att_free_req(free_token_index=free_token_index, req=req)
assert len(req.linear_att_len_to_big_page_id) == 0
@@ -134,6 +145,11 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
req.shm_req.shm_cur_kv_len = req.cur_kv_len
return
+ def _append_free_token_index(self, free_token_index: List, tensor: torch.Tensor):
+ if tensor.numel() > 0:
+ free_token_index.append(tensor)
+ return
+
def _full_att_free_req(self, free_token_index: List, req: "InferReq"):
input_token_ids = req.get_input_token_ids()
key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
@@ -149,6 +165,68 @@ def _full_att_free_req(self, free_token_index: List, req: "InferReq"):
req.shared_kv_node = None
return
+ def _dsv4_full_att_free_req(self, free_token_index: List, req: "InferReq"):
+ if req.cur_kv_len == 0:
+ free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0:0])
+ return
+
+ old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
+ inserted_len = old_prefix_len
+ duplicate_prefix_len = old_prefix_len
+
+ # The payload is only the per-page bitmap (compressor state lives and dies with swa pages / is zero at boundaries, so it is not in the payload);
+ # any 128-aligned prefix can be inserted, including the generated segment (floor(cur_kv_len) boundary; reclaim keeps the tail page resident).
+ cache_len = self.radix_cache.align_len(req.cur_kv_len)
+ self.req_manager: DeepseekV4ReqManager
+ if cache_len > old_prefix_len:
+ payload = self.req_manager.build_prompt_cache_payload(req.req_idx, cache_len)
+ value = self.req_manager.req_to_token_indexs[req.req_idx][:cache_len].detach().cpu()
+ # The per-page validity bitmap is fixed from the mapping at insert time (the valve can only clear bits
+ # afterwards, never revive them). The watermark is derived on CPU only, avoiding a GPU gather sync on the
+ # router critical path (every insert would wait for all in-flight decode kernels). Insert gate: drop trailing
+ # invalid pages -- unhittable from birth, and they would forever block longer prefixes over the same tokens (a full re-insert keeps the old bitmap since the prefix exists).
+ page_size = self.req_manager.get_prompt_cache_page_size()
+ bitmap = self.req_manager.swa_page_valid_from_watermark(req.req_idx, cache_len)
+ n_pages = int(bitmap.numel())
+ while n_pages > 0 and not bool(bitmap[n_pages - 1]):
+ n_pages -= 1
+ gated_len = n_pages * page_size
+ if gated_len < cache_len:
+ logger.info(
+ f"DeepSeek-V4 prompt cache insert gate: trailing swa pages already evicted, "
+ f"shrink insert {cache_len} -> {gated_len}"
+ )
+ cache_len = gated_len
+ payload.cache_len = cache_len
+ payload.swa_page_valid = bitmap[:n_pages].clone()
+
+ if cache_len > old_prefix_len:
+ input_token_ids = req.get_input_token_ids()
+ key = torch.tensor(input_token_ids[0:cache_len], dtype=torch.int64, device="cpu")
+ duplicate_prefix_len, cache_node = self.radix_cache.insert(key, value[:cache_len], extra_value=payload)
+ inserted_len = 0 if cache_node is None else cache_node.node_prefix_total_len
+ if inserted_len != cache_len:
+ inserted_len = old_prefix_len
+ duplicate_prefix_len = old_prefix_len
+
+ dense_row = self.req_manager.req_to_token_indexs[req.req_idx]
+ self._append_free_token_index(free_token_index, dense_row[old_prefix_len:duplicate_prefix_len])
+ self._append_free_token_index(free_token_index, dense_row[inserted_len : req.cur_kv_len])
+ if len(free_token_index) == 0:
+ free_token_index.append(dense_row[0:0])
+ # Freed full slots cascade-reclaim swa/c4/c128 through mem_manager.free (keyed by the mapping; no need to collect slots).
+
+ # The pause path does not go through req_manager.free/init: reset the out-of-window watermark (a stale one
+ # would break shared-prefix protection of the next prefill) and clear in-flight c128 state (a resumed hit
+ # continues via extend; a leftover half-window aggregate from before the pause would be wrong, and c128 state should be zero at a 128-aligned hit boundary).
+ self.req_manager.init_compress_state(req.req_idx)
+
+ if req.shared_kv_node is not None:
+ assert req.shared_kv_node.node_prefix_total_len <= max(inserted_len, old_prefix_len)
+ self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
+ req.shared_kv_node = None
+ return
+
def _linear_att_free_req(self, free_token_index: List, req: "InferReq"):
assert g_infer_context.is_linear_att_mixed_model is True
args = get_env_start_args()
@@ -325,7 +403,12 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
self.req_manager.free_token(free_token_index)
return self
- def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int):
+ def recover_paused_reqs(
+ self,
+ paused_reqs: List["InferReq"],
+ is_master_in_dp: bool,
+ can_alloc_token_num: int,
+ ):
if paused_reqs:
for req in paused_reqs:
@@ -382,7 +465,9 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
)
big_page_buffer_ids = big_page_buffer_ids.cuda(non_blocking=True)
- from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer
+ from lightllm.common.basemodel.triton_kernel.linear_att_copy import (
+ copy_linear_att_state_to_kv_buffer,
+ )
copy_linear_att_state_to_kv_buffer(
b_req_idx=b_req_idx,
@@ -412,9 +497,10 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, src_buffer_idx, ...]
dst_buffer_idx = req.tail_linear_att_small_page_buffer_id
- dst_conv_state, dst_ssm_state = self.radix_cache.linear_att_small_page_buffers.get_state_cache(
- buffer_idx=dst_buffer_idx
- )
+ (
+ dst_conv_state,
+ dst_ssm_state,
+ ) = self.radix_cache.linear_att_small_page_buffers.get_state_cache(buffer_idx=dst_buffer_idx)
# TODO 对于非连续对象调用 copy_ 效率并不高
dst_conv_state.copy_(gpu_conv_state, non_blocking=True)
dst_ssm_state.copy_(gpu_ssm_state, non_blocking=True)
@@ -591,6 +677,8 @@ def _init_all_state(self):
self.cur_output_len = 0
g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)
+ if hasattr(g_infer_context.req_manager, "init_compress_state"):
+ g_infer_context.req_manager.init_compress_state(req_idx=self.req_idx)
self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
# token healing mode 才被使用的管理对象
@@ -626,6 +714,9 @@ def _match_radix_cache(self):
ready_cache_len = share_node.node_prefix_total_len
# 从 cpu 到 gpu 是流内阻塞操作
g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor
+ # DeepSeek-V4 命中无需任何恢复: 槽位由 full_to_* 映射键控(radix 持有 full 槽即有效,
+ # 命中长度已在 match_prefix 内按 bitmap 裁剪),c4 compressor 状态随 swa 页常驻
+ # (零拷贝续算),c128 状态在 128 对齐边界自然归零(init_compress_state 已清)。
self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换
self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度
@@ -639,7 +730,10 @@ def _linear_match_radix_cache(self):
enable_prompt_cache = (not self.sampling_param.disable_prompt_cache) and g_infer_context.radix_cache is not None
linear_hash_list = self.shm_req.linear_att_token_hash_list.get_all()
linear_att_hash_page_size = self.args.linear_att_hash_page_size
- match_tokens = min(len(linear_hash_list) * linear_att_hash_page_size, self.get_cur_total_len() - 1)
+ match_tokens = min(
+ len(linear_hash_list) * linear_att_hash_page_size,
+ self.get_cur_total_len() - 1,
+ )
match_tokens = max(0, match_tokens)
match_tokens = (match_tokens // linear_att_hash_page_size) * linear_att_hash_page_size
match_block_num = match_tokens // linear_att_hash_page_size
@@ -705,7 +799,8 @@ def _linear_match_radix_cache(self):
# 将 对应的 value_tensors 中的 kv 数据 拷贝到 tail_mems 中对应的数据去
radix_cache.mem_manager.operator.copy_mem_to_mem(
- value_tensor[cur_big_page_tokens:shared_kv_len], tail_mems
+ value_tensor[cur_big_page_tokens:shared_kv_len],
+ tail_mems,
)
self.shared_kv_node = share_node # 只是为了保证 copy_small_page_buffer_to_linear_att_state 正确调用
@@ -736,7 +831,8 @@ def _linear_match_radix_cache(self):
assert self.tail_linear_att_small_page_buffer_id is None
# 恢复linear att 状态
g_infer_context.req_manager.copy_big_page_buffer_to_linear_att_state(
- big_page_buffer_idx=share_node.big_page_buffer_idx, req=self
+ big_page_buffer_idx=share_node.big_page_buffer_idx,
+ req=self,
)
self.shm_req.shm_cur_kv_len = self.cur_kv_len
@@ -792,6 +888,7 @@ def get_input_token_ids(self):
def get_chuncked_input_token_ids(self):
chunked_start = self.cur_kv_len
chunked_end = min(self.get_cur_total_len(), chunked_start + self.args.chunked_prefill_size)
+ chunked_end = self._align_chuncked_end_for_prompt_cache(chunked_start, chunked_end)
return self.shm_req.shm_prompt_ids.arr[0:chunked_end]
def get_chuncked_input_token_ids_for_linear_att(self):
@@ -812,6 +909,23 @@ def get_chuncked_input_token_ids_for_linear_att(self):
def get_chuncked_input_token_len(self):
chunked_start = self.cur_kv_len
chunked_end = min(self.get_cur_total_len(), chunked_start + self.args.chunked_prefill_size)
+ return self._align_chuncked_end_for_prompt_cache(chunked_start, chunked_end)
+
+ def _align_chuncked_end_for_prompt_cache(self, chunked_start: int, chunked_end: int):
+ radix_cache = g_infer_context.radix_cache
+ page_size = getattr(radix_cache, "page_size", 1) if radix_cache is not None else 1
+ if page_size <= 1 or self.sampling_param.disable_prompt_cache:
+ return chunked_end
+ prompt_end = int(self.shm_req.input_len)
+ chunked_start = int(chunked_start)
+ chunked_end = int(chunked_end)
+ if chunked_end >= prompt_end:
+ return chunked_end
+
+ assert self.args.chunked_prefill_size % page_size == 0, (
+ f"chunked_prefill_size={self.args.chunked_prefill_size} must be divisible by "
+ f"prompt-cache page_size={page_size}"
+ )
return chunked_end
def get_chuncked_input_token_len_for_linear_att(self):
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index a65dfb1bbb..1f19b9462a 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -153,6 +153,8 @@ def init_model(self, kvargs):
self.model: TpPartBaseModel = self.model # for easy typing
set_random_seed(2147483647)
self.is_linear_att_mixed_model = isinstance(self.model.req_manager, ReqManagerForMamba)
+ if hasattr(self.model.req_manager, "build_prompt_cache_payload"):
+ self.support_overlap = False
if self.is_linear_att_mixed_model:
self.linear_att_cache_manager = LinearAttCacheManager(
@@ -164,6 +166,7 @@ def init_model(self, kvargs):
if not self.use_dynamic_prompt_cache:
self.radix_cache = None
+ setattr(self.args, "dynamic_prompt_cache_page_size", 1)
else:
if self.is_linear_att_mixed_model:
self.radix_cache = LinearAttPagedRadixCache(
@@ -175,13 +178,25 @@ def init_model(self, kvargs):
kv_cache_mem_manager=self.model.mem_manager,
linear_att_small_page_buffers=self.linear_att_cache_manager,
)
+ setattr(self.args, "dynamic_prompt_cache_page_size", 1)
else:
+ radix_page_size = 1
+ radix_extra_value_ops = None
+ if hasattr(self.model.req_manager, "get_prompt_cache_value_ops"):
+ radix_page_size = self.model.req_manager.get_prompt_cache_page_size()
+ radix_extra_value_ops = self.model.req_manager.get_prompt_cache_value_ops()
+ setattr(self.args, "dynamic_prompt_cache_page_size", radix_page_size)
self.radix_cache = RadixCache(
unique_name=get_unique_server_name(),
total_token_num=self.model.mem_manager.size,
rank_in_node=self.rank_in_node,
mem_manager=self.model.mem_manager,
+ page_size=radix_page_size,
+ extra_value_ops=radix_extra_value_ops,
)
+ if radix_extra_value_ops is not None and hasattr(self.model.mem_manager, "set_swa_pressure_valve"):
+ # swa 页 allocator 触底时让 radix 对 ref==0 节点回收 swa 页(DeepSeek-V4)。
+ self.model.mem_manager.set_swa_pressure_valve(self.radix_cache.reclaim_unreferenced_swa_pages)
if "prompt_cache_kv_buffer" in model_cfg:
assert self.use_dynamic_prompt_cache
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index e1a4e421d1..6aea7cd672 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -90,6 +90,11 @@ def get_tokenizer(
)
logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.")
return DeepSeekV32Tokenizer(hf_tokenizer)
+ if model_type == "deepseek_v4":
+ from ..models.deepseek_v4.model import DeepSeekV4Tokenizer
+
+ logger.info("Using DeepSeek-V4 tokenizer mode with Python-based chat template encoding.")
+ return DeepSeekV4Tokenizer(tokenizer, tokenizer_name)
if model_cfg["architectures"][0] == "TarsierForConditionalGeneration":
from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor
diff --git a/lightllm/third_party/__init__.py b/lightllm/third_party/__init__.py
new file mode 100644
index 0000000000..2adb50db25
--- /dev/null
+++ b/lightllm/third_party/__init__.py
@@ -0,0 +1 @@
+"""Third-party source subsets vendored for LightLLM runtime support."""
diff --git a/lightllm/third_party/sglang_jit/LICENSE b/lightllm/third_party/sglang_jit/LICENSE
new file mode 100755
index 0000000000..9c422689c8
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2023-2024 SGLang Team
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/lightllm/third_party/sglang_jit/README.md b/lightllm/third_party/sglang_jit/README.md
new file mode 100644
index 0000000000..4f68c9cfd8
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/README.md
@@ -0,0 +1,13 @@
+# Vendored SGLang JIT Subset
+
+This directory contains the minimal SGLang JIT source subset needed by the
+DeepSeek-V4 LightLLM implementation.
+
+Source: https://github.com/sgl-project/sglang
+Commit: 8cea0473ea5299bc04885f8f6ba71269415a39b5
+License: Apache License 2.0, copied in `LICENSE`.
+
+Local changes:
+- The Python imports were moved from `sglang.jit_kernel.*` to
+ `lightllm.third_party.sglang_jit.*`.
+- The package exports only the DSv4 functions used by LightLLM.
diff --git a/lightllm/third_party/sglang_jit/__init__.py b/lightllm/third_party/sglang_jit/__init__.py
new file mode 100644
index 0000000000..164d545b4e
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/__init__.py
@@ -0,0 +1 @@
+"""Vendored SGLang JIT kernels used by DeepSeek-V4."""
diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh
new file mode 100644
index 0000000000..3a89e8114c
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh
@@ -0,0 +1,522 @@
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+
+#include
+
+namespace {
+
+using Plan128 = device::compress::PrefillPlan; // per-plan record consumed by the prefill kernel in this file
+using IndiceT = int32_t; // element type of the `indices` / `seq_lens` arrays
+
+/// \brief Number of head_dim elements each thread handles (a warp covers `kTileElements * kWarpThreads` of them)
+constexpr int32_t kTileElements = 2;
+/// \brief Number of the 128 window rows each warp handles
+constexpr int32_t kElementsPerWarp = 8;
+constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; // warps needed so that one block covers all 128 rows
+constexpr uint32_t kBlockSize = device::kWarpThreads * kNumWarps; // threads per block for the forward (compress) kernels
+
+/// \brief Kernel qualifier: at most kBlockSize threads per block, and request 2 resident blocks per SM (caps register usage to raise occupancy)
+#define C128_KERNEL __global__ __launch_bounds__(kBlockSize, 2)
+
+struct Compress128DecodeParams { // argument bundle passed by value to `flash_c128_decode`
+ /**
+ * \brief Shape: `[num_indices, 128, head_dim * 2]` \n
+ * last dimension layout (head_dim elements each):
+ * | kv current | score current |
+ */
+ void* __restrict__ kv_score_buffer;
+ /** \brief Shape: `[batch_size, head_dim * 2]`, same | kv | score | row layout as the buffer */
+ const void* __restrict__ kv_score_input;
+ /** \brief Shape: `[batch_size, head_dim]`; only written for requests whose `seq_len % 128 == 0` */
+ void* __restrict__ kv_compressed_output;
+ /** \brief Shape: `[128, head_dim]` (called `ape`); row j is added to the score of window row j */
+ const void* __restrict__ score_bias;
+ /** \brief Shape: `[batch_size, ]`; first-dimension index into `kv_score_buffer` for each request */
+ const IndiceT* __restrict__ indices;
+ /** \brief Shape: `[batch_size, ]`; the new token is stored at window row `(seq_len + 127) % 128` */
+ const IndiceT* __restrict__ seq_lens;
+ /** \NOTE: `batch_size` <= `num_indices` */
+ uint32_t batch_size;
+};
+
+struct Compress128PrefillParams { // argument bundle passed by value to `flash_c128_prefill`
+ /**
+ * \brief Shape: `[num_indices, 128, head_dim * 2]` \n
+ * last dimension layout (head_dim elements each):
+ * | kv current | score current |
+ */
+ void* __restrict__ kv_score_buffer;
+ /** \brief Shape: `[batch_size, head_dim * 2]`; the kernel addresses rows by the plan's `ragged_id` */
+ const void* __restrict__ kv_score_input;
+ /** \brief Shape: `[batch_size, head_dim]`; the kernel addresses rows by the plan's `ragged_id` */
+ void* __restrict__ kv_compressed_output;
+ /** \brief Shape: `[128, head_dim]` (called `ape`) */
+ const void* __restrict__ score_bias;
+ /** \brief Shape: `[batch_size, ]`; buffer index used by the write pass (`kWrite == true`) */
+ const IndiceT* __restrict__ indices;
+ /** \brief Shape: `[batch_size, ]`; buffer index used by the compress pass (`kWrite == false`) */
+ const int32_t* __restrict__ load_indices;
+ /** \brief The following part is plan info: one plan array (and its length) per pass. */
+ const Plan128* __restrict__ compress_plan;
+ const Plan128* __restrict__ write_plan;
+ uint32_t num_compress; // number of entries in `compress_plan`
+ uint32_t num_write; // number of entries in `write_plan`
+};
+
+struct Compress128SharedBuffer { // per-block scratch: one Storage slot per (warp, lane); instantiated as `__shared__` by c128_forward
+ using Storage = device::AlignedVector;
+ Storage data[kNumWarps][device::kWarpThreads + 1]; // one extra column of padding to avoid bank conflict
+ SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { // whole-slot access
+ return data[warp_id][lane_id];
+ }
+ SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { // single float inside a slot
+ return data[warp_id][lane_id][tile_id];
+ }
+};
+
+template
+SGL_DEVICE void c128_write(
+ T* kv_score_buf, // destination window; row `write_pos` is overwritten
+ const T* kv_score_src, // source row laid out as | kv | score |
+ const int64_t head_dim, // stride between the kv half and the score half of a row
+ const int32_t write_pos, // destination row inside the window
+ const uint32_t lane_id) { // lane index used to build the per-lane memory tile
+ using namespace device;
+
+ using Storage = AlignedVector;
+ const auto element_size = head_dim * 2; // elements per buffer row (kv + score)
+ const auto gmem = tile::Memory{lane_id, kWarpThreads};
+ kv_score_buf += write_pos * element_size; // advance to the destination row
+
+ /// NOTE: Layout | [0] = kv | [1] = score |
+ Storage kv_score[2];
+#pragma unroll
+ for (int32_t i = 0; i < 2; ++i) { // issue both loads first ...
+ kv_score[i] = gmem.load(kv_score_src + head_dim * i);
+ }
+#pragma unroll
+ for (int32_t i = 0; i < 2; ++i) { // ... then both stores
+ gmem.store(kv_score_buf + head_dim * i, kv_score[i]);
+ }
+}
+
+template
+SGL_DEVICE void c128_forward(
+ const InFloat* kv_score_buf, // cached window; rows `[0, window_len)` are read from here
+ const InFloat* kv_score_src, // ragged input; rows `>= window_len` are read relative to this pointer
+ OutFloat* kv_out, // receives one softmax-weighted kv value per head_dim element of the tile
+ const InFloat* score_bias, // `ape`; row j is added to the score of window row j
+ const int64_t head_dim, // elements in the kv half (and in the score half) of a row
+ const int32_t window_len, // number of window rows taken from `kv_score_buf`
+ const uint32_t warp_id, // selects which kElementsPerWarp window rows this warp owns
+ const uint32_t lane_id) { // selects which head_dim elements this thread owns
+ using namespace device;
+
+ const auto element_size = head_dim * 2; // elements per row: | kv | score |
+ const auto score_offset = head_dim; // offset of the score half inside a row
+
+ /// NOTE: part 1: load kv + score
+ using StorageIn = AlignedVector;
+ const auto gmem_in = tile::Memory{lane_id, kWarpThreads};
+ StorageIn kv[kElementsPerWarp];
+ StorageIn score[kElementsPerWarp];
+ StorageIn bias[kElementsPerWarp];
+ const int32_t warp_offset = warp_id * kElementsPerWarp; // first window row owned by this warp
+
+#pragma unroll
+ for (int32_t i = 0; i < 8; ++i) { // NOTE(review): literal 8 equals kElementsPerWarp, which sizes `bias`
+ const int32_t j = i + warp_offset;
+ bias[i] = gmem_in.load(score_bias + j * head_dim);
+ }
+
+#pragma unroll
+ for (int32_t i = 0; i < kElementsPerWarp; ++i) {
+ const int32_t j = i + warp_offset; // window row index
+ const InFloat* src;
+ __builtin_assume(j < 128);
+ if (j < window_len) {
+ src = kv_score_buf + j * element_size; // row comes from the cached window
+ } else {
+ /// NOTE: k in [-127, 0]. We'll load from the ragged `kv_score_src`
+ const int32_t k = j - 127; // row offset relative to `kv_score_src` (window row 127 maps to offset 0)
+ src = kv_score_src + k * element_size;
+ }
+ kv[i] = gmem_in.load(src);
+ score[i] = gmem_in.load(src + score_offset);
+ }
+
+ /// NOTE: part 2: safe online softmax + weighted sum
+ using TmpStorage = typename Compress128SharedBuffer::Storage;
+ __shared__ Compress128SharedBuffer s_local_val_max; // per-thread partial max
+ __shared__ Compress128SharedBuffer s_local_exp_sum; // per-thread partial sum of exp(score - max)
+ __shared__ Compress128SharedBuffer s_local_product; // per-thread partial sum of kv * exp(score - max)
+
+ TmpStorage tmp_val_max;
+ TmpStorage tmp_exp_sum;
+ TmpStorage tmp_product;
+
+#pragma unroll
+ for (int32_t i = 0; i < kTileElements; ++i) { // one pass per head_dim element owned by this thread
+ float score_fp32[kElementsPerWarp];
+
+#pragma unroll
+ for (int32_t j = 0; j < kElementsPerWarp; ++j) {
+ score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); // biased score
+ }
+
+ float max_value = score_fp32[0];
+ float sum_exp_value = 0.0f;
+
+#pragma unroll
+ for (int32_t j = 1; j < kElementsPerWarp; ++j) { // max over this warp's rows
+ const auto fp32_score = score_fp32[j];
+ max_value = fmaxf(max_value, fp32_score);
+ }
+
+ float sum_product = 0.0f;
+#pragma unroll
+ for (int32_t j = 0; j < 8; ++j) { // NOTE(review): literal 8 equals kElementsPerWarp, which sizes `score_fp32` and `kv`
+ const auto fp32_score = score_fp32[j];
+ const auto exp_score = expf(fp32_score - max_value); // shifted by the local max, so the exponent is <= 0
+ sum_product += cast(kv[j][i]) * exp_score;
+ sum_exp_value += exp_score;
+ }
+
+ tmp_val_max[i] = max_value;
+ tmp_exp_sum[i] = sum_exp_value;
+ tmp_product[i] = sum_product;
+ }
+
+ // naturally aligned, so no bank conflict
+ s_local_val_max(warp_id, lane_id) = tmp_val_max;
+ s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum;
+ s_local_product(warp_id, lane_id) = tmp_product;
+
+ __syncthreads(); // every warp's partials must be in shared memory before the cross-warp merge
+
+ /// NOTE: part 3: online softmax
+ /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce
+ /// each reduce will consume `kNumWarps` threads (use partial warp reduction)
+ constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps;
+ constexpr uint32_t kIteration = kReductionCount / kBlockSize; // equals kTileElements, since kBlockSize = kWarpThreads * kNumWarps
+
+#pragma unroll
+ for (uint32_t i = 0; i < kIteration; ++i) {
+ /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)`
+ const uint32_t j = i * kBlockSize + warp_id * kWarpThreads + lane_id;
+ /// NOTE: Range `[0, kNumWarps)`
+ const uint32_t local_warp_id = j % kNumWarps;
+ /// NOTE: Range `[0, kTileElements * kWarpThreads)`
+ const uint32_t local_elem_id = j / kNumWarps;
+ /// NOTE: Range `[0, kTileElements)`
+ const uint32_t local_tile_id = local_elem_id % kTileElements;
+ /// NOTE: Range `[0, kWarpThreads)`
+ const uint32_t local_lane_id = local_elem_id / kTileElements;
+ /// NOTE: each warp will access the whole tile (all `kTileElements`)
+ /// and for different lanes, the memory access only differ in `local_warp_id`
+ /// so there's no bank conflict in shared memory access.
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs");
+ const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id);
+ const auto global_val_max = warp::reduce_max(local_val_max);
+ const auto rescale = expf(local_val_max - global_val_max); // re-base this partial onto the merged max
+ const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale);
+ const auto final_scale = rescale / global_exp_sum; // folds the softmax normalization into the rescale
+ const auto global_product = warp::reduce_sum(local_product * final_scale);
+ kv_out[local_elem_id] = cast(global_product); // presumably every lane of a reduction group holds the same value here -- confirm against warp::reduce_sum
+ }
+}
+
+template
+C128_KERNEL void flash_c128_decode(const __grid_constant__ Compress128DecodeParams params) {
+ using namespace device;
+
+ constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 head_dim elements handled per block
+ constexpr uint32_t kNumSplit = kHeadDim / kTileDim; // blocks per request along head_dim
+ constexpr int64_t kElementSize = kHeadDim * 2; // elements per row: | kv | score |
+ static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim");
+
+ const auto& [
+ _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score
+ indices, seq_lens, batch_size // decode info
+ ] = params;
+ const uint32_t warp_id = threadIdx.x / kWarpThreads;
+ const uint32_t lane_id = threadIdx.x % kWarpThreads;
+
+ const uint32_t global_bid = blockIdx.x / kNumSplit; // batch id
+ const uint32_t global_sid = blockIdx.x % kNumSplit; // split id
+ if (global_bid >= batch_size) return; // surplus blocks exit before touching memory
+
+ const int32_t index = indices[global_bid]; // which window of the buffer this request owns
+ const int32_t seq_len = seq_lens[global_bid];
+ const int64_t split_offset = global_sid * kTileDim; // first head_dim element of this block's tile
+
+ // kv score: this request's 128-row window, shifted to this block's tile
+ const auto kv_score_buffer = static_cast(_kv_score_buffer);
+ const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset;
+
+ // kv input: the single new row for this request
+ const auto kv_score_input = static_cast(_kv_score_input);
+ const auto kv_src = kv_score_input + global_bid * kElementSize + split_offset;
+
+ // kv output: one compressed row per request
+ const auto kv_compressed_output = static_cast(_kv_compressed_output);
+ const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset;
+
+ // score bias (ape)
+ const auto score_bias = static_cast(_score_bias) + split_offset;
+
+ PDLWaitPrimary();
+
+ /// NOTE: the write must be visible to the subsequent c128_forward,
+ /// so only the last warp can write to HBM
+ /// In addition, `position` = `seq_len - 1`. To avoid underflow, we use `seq_len + 127`
+ if (warp_id == kNumWarps - 1) {
+ c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 127) % 128, lane_id);
+ }
+ if (seq_len % 128 == 0) { // the row just written was row 127, so the window is complete: compress it
+ c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, /*window_len=*/128, warp_id, lane_id);
+ }
+
+ PDLTriggerSecondary();
+}
+
+// prefill kernel: `kWrite == true` stores rows into the buffer, `kWrite == false` compresses windows
+template
+C128_KERNEL void flash_c128_prefill(const __grid_constant__ Compress128PrefillParams params) {
+ using namespace device;
+
+ constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 head_dim elements per work item
+ constexpr uint32_t kNumSplit = kHeadDim / kTileDim; // work items per plan along head_dim
+ constexpr int64_t kElementSize = kHeadDim * 2; // elements per row: | kv | score |
+ static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim");
+
+ const auto& [
+ _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score
+ indices, load_indices, compress_plan, write_plan, num_compress, num_write // prefill plan
+ ] = params;
+ const uint32_t warp_id = threadIdx.x / kWarpThreads;
+ const uint32_t lane_id = threadIdx.x % kWarpThreads;
+
+ uint32_t global_id; // index of the (plan, split) work item
+ if constexpr (kWrite) {
+ // for write kernel, we use global warp_id to dispatch work
+ global_id = (blockIdx.x * blockDim.x + threadIdx.x) / kWarpThreads;
+ } else {
+ // for compress kernel, we use block id to dispatch work
+ global_id = blockIdx.x; // block id
+ }
+ const uint32_t global_pid = global_id / kNumSplit; // plan id
+ const uint32_t global_sid = global_id % kNumSplit; // split id
+
+ /// NOTE: compiler can optimize this if-else at compile time
+ const auto num_plans = kWrite ? num_write : num_compress;
+ const auto plan_ptr = kWrite ? write_plan : compress_plan;
+ if (global_pid >= num_plans) return; // surplus warps/blocks have no plan to process
+
+ const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid];
+ const auto indices_ptr = kWrite ? indices : load_indices; // write pass uses `indices`, compress pass uses `load_indices`
+
+ const int64_t split_offset = global_sid * kTileDim; // first head_dim element of this work item's tile
+
+ // kv input: row `ragged_id` of the ragged input
+ const auto kv_score_input = static_cast(_kv_score_input);
+ const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset;
+
+ // kv output: row `ragged_id` of the compressed output
+ const auto kv_compressed_output = static_cast(_kv_compressed_output);
+ const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset;
+
+ // score bias (ape)
+ const auto score_bias = static_cast(_score_bias) + split_offset;
+
+ if (ragged_id == 0xFFFFFFFF) [[unlikely]] // entry is skipped; presumably a padding sentinel -- confirm against the plan builder
+ return;
+
+ const int32_t index = indices_ptr[global_bid];
+ // kv score: this request's 128-row window, shifted to this work item's tile
+ const auto kv_score_buffer = static_cast(_kv_score_buffer);
+ const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset;
+
+ PDLWaitPrimary();
+
+ // each instantiation does exactly one job: the write pass stores a row, the compress pass reduces a window
+ if constexpr (kWrite) {
+ c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 128, lane_id);
+ } else {
+ c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, window_len, warp_id, lane_id);
+ }
+
+ PDLTriggerSecondary();
+}
+
+template
+struct FlashCompress128Kernel { // host-side entry points: validate tensors, pack params, launch the kernels
+ static constexpr auto decode_kernel = flash_c128_decode;
+ template
+ static constexpr auto prefill_kernel = flash_c128_prefill;
+ static constexpr auto prefill_c_kernel = prefill_kernel*kWrite=*/false>;
+ static constexpr auto prefill_w_kernel = prefill_kernel*kWrite=*/true>;
+ static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64, must match the kernels' kTileDim
+ static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; // work items per request/plan along head_dim
+ static constexpr uint32_t kWriteBlockSize = 128; // threads per block for the prefill write pass
+ static constexpr uint32_t kWarpsPerWriteBlock = kWriteBlockSize / device::kWarpThreads; // work items per write block
+
+ static void run_decode(
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView indices,
+ const tvm::ffi::TensorView seq_lens,
+ const tvm::ffi::Optional /* UNUSED */) {
+ using namespace host;
+
+ // NOTE(review): upstream remark "this should not happen in practice" -- no check follows it here
+ auto B = SymbolicSize{"batch_size"};
+ auto device = SymbolicDevice{};
+ device.set_options();
+
+ TensorMatcher({-1, 128, kHeadDim * 2}) // kv score
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_score_buffer);
+ TensorMatcher({B, kHeadDim * 2}) // kv score input
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_score_input);
+ TensorMatcher({B, kHeadDim}) // kv compressed output
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device)
+ .verify(ape);
+ TensorMatcher({B}) // indices
+ .with_dtype()
+ .with_device(device)
+ .verify(indices);
+ TensorMatcher({B}) // seq lens
+ .with_dtype()
+ .with_device(device)
+ .verify(seq_lens);
+
+ const auto batch_size = static_cast(B.unwrap());
+ const auto params = Compress128DecodeParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(),
+ .indices = static_cast(indices.data_ptr()),
+ .seq_lens = static_cast(seq_lens.data_ptr()),
+ .batch_size = batch_size,
+ };
+
+ const uint32_t num_blocks = batch_size * kNumSplit; // one block per (request, head_dim split)
+ LaunchKernel(num_blocks, kBlockSize, device.unwrap()) //
+ .enable_pdl(kUsePDL)(decode_kernel, params);
+ }
+
+ static void run_prefill(
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView indices,
+ const tvm::ffi::TensorView compress_plan,
+ const tvm::ffi::TensorView write_plan,
+ const tvm::ffi::Optional extra) {
+ using namespace host;
+
+ auto B = SymbolicSize{"batch_size"};
+ auto N = SymbolicSize{"num_q_tokens"};
+ auto X = SymbolicSize{"compress_tokens"};
+ auto Y = SymbolicSize{"write_tokens"};
+ auto device_ = SymbolicDevice{};
+ device_.set_options();
+
+ TensorMatcher({-1, 128, kHeadDim * 2}) // kv score
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_buffer);
+ TensorMatcher({N, kHeadDim * 2}) // kv score input
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_input);
+ TensorMatcher({N, kHeadDim}) // kv compressed output
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device_)
+ .verify(ape);
+ TensorMatcher({B}) // indices
+ .with_dtype()
+ .with_device(device_)
+ .verify(indices);
+ TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan
+ .with_dtype()
+ .with_device(device_)
+ .verify(compress_plan);
+ TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan
+ .with_dtype()
+ .with_device(device_)
+ .verify(write_plan);
+
+ // optional `extra` supplies `load_indices` (read by the compress pass of the kernel); falls back to `indices`
+ const auto load_indices = extra.value_or(indices);
+ TensorMatcher({B}) // [read_positions]
+ .with_dtype()
+ .with_device(device_)
+ .verify(load_indices);
+
+ const auto device = device_.unwrap();
+ const auto batch_size = static_cast(B.unwrap());
+ const auto num_q_tokens = static_cast(N.unwrap());
+ const auto num_c = static_cast(X.unwrap());
+ const auto num_w = static_cast(Y.unwrap());
+ const auto params = Compress128PrefillParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(),
+ .indices = static_cast(indices.data_ptr()),
+ .load_indices = static_cast(load_indices.data_ptr()),
+ .compress_plan = static_cast(compress_plan.data_ptr()),
+ .write_plan = static_cast(write_plan.data_ptr()),
+ .num_compress = num_c,
+ .num_write = num_w,
+ };
+ RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size");
+ RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan");
+
+ constexpr auto kBlockSize_C = kBlockSize;
+ constexpr auto kBlockSize_W = kWriteBlockSize;
+ if (const auto num_c_blocks = num_c * kNumSplit) { // compress pass: one block per (plan, split); skipped when there are no plans
+ LaunchKernel(num_c_blocks, kBlockSize_C, device) //
+ .enable_pdl(kUsePDL)(prefill_c_kernel, params);
+ }
+ if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerWriteBlock)) { // write pass: one warp per (plan, split)
+ LaunchKernel(num_w_blocks, kBlockSize_W, device) //
+ .enable_pdl(kUsePDL)(prefill_w_kernel, params);
+ }
+ }
+};
+
+} // namespace
diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh
new file mode 100644
index 0000000000..b497470606
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh
@@ -0,0 +1,726 @@
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace device::compress {
+
+/// \brief Plan entry for online compress 128 prefill.
+/// Each entry describes a contiguous segment of tokens that lies inside a
+/// single 128-chunk. Multiple segments can map to the same batch id when the
+/// extend tokens span chunk boundaries.
+///
+/// **Layout compatibility:** the field order/types match `PrefillPlan` so that
+/// downstream kernels (e.g. `fused_norm_rope` in `CompressExtend` mode) can
+/// consume the compress_plan tensor as if it were a `PrefillPlan` tensor --
+/// they only read `ragged_id` and `position`, both of which carry identical
+/// semantics here (the LAST token of the segment in q-ragged and global
+/// coordinates respectively).
+///
+/// Note that `window_len` here means "number of real tokens in this segment"
+/// (1..128), which differs from `PrefillPlan::window_len`. Downstream kernels
+/// that share the tensor MUST NOT read it under that name.
+struct alignas(16) OnlinePrefillPlan {
+ /// \brief Ragged-q position of the LAST token in this segment.
+ /// Equal to `segment_start_ragged + window_len - 1`.
+ uint32_t ragged_id;
+ /// \brief Index into the `indices` / `load_indices` arrays.
+ uint32_t batch_id;
+ /// \brief Global position of the LAST token in this segment.
+ /// For compress plans, `position % 128 == 127` (chunk-closing); for write
+ /// plans, `position % 128 < 127`.
+ uint32_t position;
+ /// \brief Number of real tokens in this segment (1..128).
+ /// The first segment token sits at `position - window_len + 1` (global) and
+ /// at `ragged_id - window_len + 1` (ragged).
+ uint32_t window_len;
+};
+
+static_assert(alignof(OnlinePrefillPlan) == alignof(PrefillPlan)); // the two plan types must stay interchangeable in memory
+static_assert(sizeof(OnlinePrefillPlan) == sizeof(PrefillPlan));
+
+} // namespace device::compress
+
+namespace host::compress {
+
+using device::compress::OnlinePrefillPlan;
+using OnlinePrefillPlanTensorDtype = uint8_t; // plan tensors are typed as raw bytes
+inline constexpr int64_t kOnlinePrefillPlanDim = 16; // bytes (tensor columns) per plan entry
+
+static_assert(alignof(OnlinePrefillPlan) == sizeof(OnlinePrefillPlan)); // alignment equals size (16), so consecutive entries have no padding
+static_assert(sizeof(OnlinePrefillPlan) == kOnlinePrefillPlanDim * sizeof(OnlinePrefillPlanTensorDtype)); // one tensor row holds exactly one plan entry
+
+} // namespace host::compress
+
+namespace {
+
+using OnlinePlan = device::compress::OnlinePrefillPlan; // plan entry type read by the online prefill kernel
+using IndiceT = int32_t; // element type of the `indices` / `seq_lens` arrays
+
+/// \brief Argument bundle for `flash_c128_online_decode`. NOTE(review): upstream remark here read "Need to reduce register usage to increase occupancy" -- confirm what it refers to
+struct Compress128OnlineDecodeParams {
+ /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]`; running softmax state per slot \n */
+ void* __restrict__ kv_score_buffer;
+ /** \brief Shape: `[batch_size, head_dim * 2]`, row layout | kv | score | */
+ const void* __restrict__ kv_score_input;
+ /** \brief Shape: `[batch_size, head_dim]`; written only when the new token closes a 128-chunk */
+ void* __restrict__ kv_compressed_output;
+ /** \brief Shape: `[128, head_dim]` (called `ape`); the row for the token's position inside its chunk is added to the score */
+ const void* __restrict__ score_bias;
+ /** \brief Shape: `[batch_size, ]`; slot of `kv_score_buffer` used by each request */
+ const IndiceT* __restrict__ indices;
+ /** \brief Shape: `[batch_size, ]`; the new token sits at position `seq_len - 1` */
+ const IndiceT* __restrict__ seq_lens;
+ /** \NOTE: `batch_size` <= `num_indices` */
+ uint32_t batch_size;
+};
+
+/// \brief Argument bundle for `flash_c128_online_prefill`. NOTE(review): upstream remark here read "Need to reduce register usage to increase occupancy" -- confirm what it refers to
+struct Compress128OnlinePrefillParams {
+ /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]`; running softmax state per slot \n */
+ void* __restrict__ kv_score_buffer;
+ /** \brief Shape: `[num_q_tokens, head_dim * 2]`, row layout | kv | score | */
+ const void* __restrict__ kv_score_input;
+ /** \brief Shape: `[num_q_tokens, head_dim]` */
+ void* __restrict__ kv_compressed_output;
+ /** \brief Shape: `[128, head_dim]` (called `ape`) */
+ const void* __restrict__ score_bias;
+ /** \brief Shape: `[batch_size, ]`; slot the write pass stores into */
+ const IndiceT* __restrict__ indices;
+ /** \brief Shape: `[batch_size, ]`; slot prior partial state is loaded from */
+ const IndiceT* __restrict__ load_indices;
+ /// \brief Plan for segments that close a chunk (write to `kv_compressed_output`).
+ /// Shape: `[num_compress, 16]` (uint8).
+ const OnlinePlan* __restrict__ compress_plan;
+ /// \brief Plan for the trailing partial segment of each batch (write back to
+ /// `kv_score_buffer`). Shape: `[num_write, 16]` (uint8).
+ const OnlinePlan* __restrict__ write_plan;
+ uint32_t num_compress; // number of entries in `compress_plan`
+ uint32_t num_write; // number of entries in `write_plan`
+};
+
+// Online decode step: 4 elements per thread, kHeadDim / 4 threads per block, one block per request
+template
+__global__ void flash_c128_online_decode(const __grid_constant__ Compress128OnlineDecodeParams params) {
+ using namespace device;
+ constexpr uint32_t kVecSize = 4;
+ constexpr uint32_t kBlockSize = kHeadDim / kVecSize;
+ using Vec = AlignedVector;
+ const auto gmem = tile::Memory::cta(kBlockSize);
+ const auto batch_id = blockIdx.x; // NOTE(review): no `batch_id < params.batch_size` guard; presumably the grid is exactly batch_size blocks -- confirm at the launch site
+ const auto index = params.indices[batch_id];
+ const auto seq_len = params.seq_lens[batch_id];
+
+ const auto kv_score_buffer = static_cast(params.kv_score_buffer);
+ const auto kv_buf = kv_score_buffer + index * (kHeadDim * 3); // this request's [max, sum, kv] state
+ const auto kv_score_input = static_cast(params.kv_score_input);
+ const auto kv_src = kv_score_input + batch_id * (kHeadDim * 2); // this request's new | kv | score | row
+
+ /// NOTE: kv_score_buffer layout is [max, sum, kv] (slot 0 / 1 / 2). Reads,
+ /// writes, and the prefill kernel must all agree on this order.
+ const auto max_score_vec = gmem.load(kv_buf, 0);
+ const auto sum_score_vec = gmem.load(kv_buf, 1);
+ const auto old_kv_vec = gmem.load(kv_buf, 2);
+
+ /// NOTE: kv_score_input layout is | kv | score | (head_dim each), matching
+ /// the offline c128 kernel and the online prefill kernel.
+ const auto new_kv_vec = gmem.load(kv_src, 0);
+ const auto new_score_raw_vec = gmem.load(kv_src, 1);
+
+ /// NOTE: the new token sits at global position `seq_len - 1`, so its
+ /// position inside the 128-chunk is `(seq_len - 1) % 128`.
+ /// NOTE(review): for `seq_len <= 0` this remainder is negative; the offline
+ /// decode kernel uses `(seq_len + 127) % 128` -- confirm seq_len >= 1 here.
+ const auto pos_in_chunk = (seq_len - 1) % 128;
+ const auto bias_vec = gmem.load(params.score_bias, pos_in_chunk);
+
+ Vec out_kv_vec;
+ Vec out_max_vec;
+ Vec out_sum_vec;
+ if (pos_in_chunk != 0) {
+ // Mid-chunk: combine prior partial state with the new token via online softmax.
+#pragma unroll
+ for (uint32_t i = 0; i < 4; ++i) { // literal 4 equals kVecSize
+ const auto old_max = max_score_vec[i];
+ const auto old_kv = old_kv_vec[i];
+ const auto new_score = new_score_raw_vec[i] + bias_vec[i];
+ const auto new_kv = new_kv_vec[i];
+ const auto new_max = fmax(old_max, new_score);
+ const auto old_sum = sum_score_vec[i] * expf(old_max - new_max); // re-base the stored sum onto the new max
+ const auto new_exp = expf(new_score - new_max);
+ const auto new_sum = old_sum + new_exp;
+ out_kv_vec[i] = (old_kv * old_sum + new_kv * new_exp) / new_sum; // stored kv is already normalized, so re-weight by the sums
+ out_max_vec[i] = new_max;
+ out_sum_vec[i] = new_sum;
+ }
+ } else {
+ // First token of a new 128-chunk: initialize state with this token alone.
+#pragma unroll
+ for (uint32_t i = 0; i < 4; ++i) { // literal 4 equals kVecSize
+ out_kv_vec[i] = new_kv_vec[i];
+ out_max_vec[i] = new_score_raw_vec[i] + bias_vec[i];
+ out_sum_vec[i] = 1.0f; // exp(score - max) with max == score
+ }
+ }
+
+ if (pos_in_chunk == 127) {
+ // Chunk just closed: emit the compressed kv. No need to update the buffer
+ // -- the next chunk's first token will overwrite it.
+ const auto kv_out = static_cast(params.kv_compressed_output) + batch_id * kHeadDim;
+ gmem.store(kv_out, out_kv_vec);
+ } else {
+ // Otherwise persist the running [max, sum, kv] state for the next step.
+ gmem.store(kv_buf, out_max_vec, 0);
+ gmem.store(kv_buf, out_sum_vec, 1);
+ gmem.store(kv_buf, out_kv_vec, 2);
+ }
+}
+
+constexpr int32_t kTileElements = 2; // head_dim elements handled per thread (split along head-dim)
+/// \brief Each warp will handle this many rows of the 128-token chunk (split along softmax-128)
+constexpr int32_t kElementsPerWarp = 8;
+constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; // warps needed so that one block covers all 128 rows
+constexpr uint32_t kPrefillBlockSize = device::kWarpThreads * kNumWarps; // threads per prefill block
+using PrefillStorage = device::AlignedVector; // per-thread register tile passed to c128_prefill_forward
+
+struct Compress128SharedBuffer { // per-block scratch: one Storage slot per (warp, lane); instantiated as `__shared__` by c128_prefill_forward
+ using Storage = device::AlignedVector;
+ Storage data[kNumWarps][device::kWarpThreads + 1]; // one extra column of padding to avoid bank conflict
+ SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { // whole-slot access
+ return data[warp_id][lane_id];
+ }
+ SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { // single float inside a slot
+ return data[warp_id][lane_id][tile_id];
+ }
+};
+
+template
+SGL_DEVICE void c128_prefill_forward(
+ const PrefillStorage (&kv)[kElementsPerWarp], // kv values of the rows owned by this warp
+ const PrefillStorage (&score)[kElementsPerWarp], // scores of the same rows, used as-is (no bias is added in this function)
+ float* kv_out, // receives the softmax-weighted kv per head_dim element
+ float* max_out, // written only when kNeedData: merged max per element
+ float* sum_out, // written only when kNeedData: merged exp-sum per element
+ const uint32_t warp_id,
+ const uint32_t lane_id) {
+ using namespace device;
+
+ /// NOTE: part 2: safe online softmax + weighted sum
+ using TmpStorage = typename Compress128SharedBuffer::Storage;
+ __shared__ Compress128SharedBuffer s_local_val_max; // per-thread partial max
+ __shared__ Compress128SharedBuffer s_local_exp_sum; // per-thread partial sum of exp(score - max)
+ __shared__ Compress128SharedBuffer s_local_product; // per-thread partial sum of kv * exp(score - max)
+
+ TmpStorage tmp_val_max;
+ TmpStorage tmp_exp_sum;
+ TmpStorage tmp_product;
+
+#pragma unroll
+ for (int32_t i = 0; i < kTileElements; ++i) { // one pass per head_dim element owned by this thread
+ float score_fp32[kElementsPerWarp];
+
+#pragma unroll
+ for (int32_t j = 0; j < kElementsPerWarp; ++j) {
+ score_fp32[j] = score[j][i];
+ }
+
+ float max_value = score_fp32[0];
+ float sum_exp_value = 0.0f;
+
+#pragma unroll
+ for (int32_t j = 1; j < kElementsPerWarp; ++j) { // max over this warp's rows
+ const auto fp32_score = score_fp32[j];
+ max_value = fmaxf(max_value, fp32_score);
+ }
+
+ float sum_product = 0.0f;
+#pragma unroll
+ for (int32_t j = 0; j < 8; ++j) { // NOTE(review): literal 8 equals kElementsPerWarp, which sizes `score_fp32` and `kv`
+ const auto fp32_score = score_fp32[j];
+ const auto exp_score = expf(fp32_score - max_value); // shifted by the local max, so the exponent is <= 0
+ sum_product += cast(kv[j][i]) * exp_score;
+ sum_exp_value += exp_score;
+ }
+
+ tmp_val_max[i] = max_value;
+ tmp_exp_sum[i] = sum_exp_value;
+ tmp_product[i] = sum_product;
+ }
+
+ // naturally aligned, so no bank conflict
+ s_local_val_max(warp_id, lane_id) = tmp_val_max;
+ s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum;
+ s_local_product(warp_id, lane_id) = tmp_product;
+
+ __syncthreads(); // every warp's partials must be in shared memory before the cross-warp merge
+
+ /// NOTE: part 3: online softmax
+ /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce
+ /// each reduce will consume `kNumWarps` threads (use partial warp reduction)
+ constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps;
+ constexpr uint32_t kIteration = kReductionCount / kPrefillBlockSize; // equals kTileElements, since kPrefillBlockSize = kWarpThreads * kNumWarps
+
+#pragma unroll
+ for (uint32_t i = 0; i < kIteration; ++i) {
+ /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)`
+ const uint32_t j = i * kPrefillBlockSize + warp_id * kWarpThreads + lane_id;
+ /// NOTE: Range `[0, kNumWarps)`
+ const uint32_t local_warp_id = j % kNumWarps;
+ /// NOTE: Range `[0, kTileElements * kWarpThreads)`
+ const uint32_t local_elem_id = j / kNumWarps;
+ /// NOTE: Range `[0, kTileElements)`
+ const uint32_t local_tile_id = local_elem_id % kTileElements;
+ /// NOTE: Range `[0, kWarpThreads)`
+ const uint32_t local_lane_id = local_elem_id / kTileElements;
+ /// NOTE: each warp will access the whole tile (all `kTileElements`)
+ /// and for different lanes, the memory access only differ in `local_warp_id`
+ /// so there's no bank conflict in shared memory access.
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs");
+ const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id);
+ const auto global_val_max = warp::reduce_max(local_val_max);
+ const auto rescale = expf(local_val_max - global_val_max); // re-base this partial onto the merged max
+ const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale);
+ const auto final_scale = rescale / global_exp_sum; // folds the softmax normalization into the rescale
+ const auto global_product = warp::reduce_sum(local_product * final_scale);
+ kv_out[local_elem_id] = global_product;
+ if constexpr (kNeedData) { // also export the merged softmax statistics
+ max_out[local_elem_id] = global_val_max;
+ sum_out[local_elem_id] = global_exp_sum;
+ }
+ }
+ if constexpr (kNeedData) __syncthreads(); // presumably so the caller can read the outputs block-wide afterwards -- confirm at the call site
+}
+
+/// \brief Sentinel score for padded positions in a 128-segment.
+/// Must be finite so that `score - max` never produces NaN even when an
+/// entire warp has only padded positions (an infinite sentinel would give `-inf - (-inf)`).
+constexpr float kPadScore = -FLT_MAX;
+
+/// \brief Online compress 128 prefill. Two passes share this body:
+/// - `kWrite=false` (compress pass): handles segments that close a chunk.
+/// May load prior partial state from the buffer, but never writes to it,
+/// so concurrent blocks can read the same slot without racing.
+/// - `kWrite=true` (write pass): handles the trailing partial segment of each
+/// batch. Each batch contributes at most one such plan, so concurrent blocks
+/// touch disjoint buffer slots.
+///
+/// The two passes MUST run as separate kernel launches (in stream order) so
+/// that all reads in pass 1 finish before any writes in pass 2 start.
+template
+__global__ __launch_bounds__(kPrefillBlockSize, 2) //
+ void flash_c128_online_prefill(const __grid_constant__ Compress128OnlinePrefillParams params) {
+ using namespace device;
+
+ constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64
+ constexpr uint32_t kNumSplit = kHeadDim / kTileDim;
+ static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim");
+
+ /// NOTE: the compiler folds the if-else at compile time.
+ const auto num_plans = kWrite ? params.num_write : params.num_compress;
+ const auto plan_ptr = kWrite ? params.write_plan : params.compress_plan;
+ const uint32_t global_id = blockIdx.x;
+ const uint32_t global_pid = global_id / kNumSplit; // plan id
+ const uint32_t global_sid = global_id % kNumSplit; // split id
+ if (global_pid >= num_plans) return; // surplus blocks beyond the plan list
+ const auto [ragged_id, batch_id, position, window_len] = plan_ptr[global_pid];
+ if (ragged_id == 0xFFFFFFFFu) [[unlikely]]
+ return; // sentinel plan entry (ragged_id == -1): nothing to do
+
+ const uint32_t warp_id = threadIdx.x / kWarpThreads;
+ const uint32_t lane_id = threadIdx.x % kWarpThreads;
+ const int32_t split_offset = global_sid * kTileDim; // int32 is enough
+
+ const auto kv_score_buffer = static_cast(params.kv_score_buffer);
+ const auto kv_score_input = static_cast(params.kv_score_input);
+ const auto kv_compressed_output = static_cast(params.kv_compressed_output);
+ const auto score_bias_base = static_cast(params.score_bias);
+
+ constexpr int64_t kElementSize = kHeadDim * 2; // | kv | score |
+ const uint32_t chunk_offset = (position % 128u) + 1u - window_len; // first in-chunk slot of the segment
+ const uint32_t window_end = chunk_offset + window_len; // exclusive, in [1, 128]
+ const int32_t segment_start = ragged_id - (position % 128u); // can be negative, but safe
+ const int32_t load_index = chunk_offset != 0 ? params.load_indices[batch_id] : -1; // -1: no prior state to merge
+ const int32_t store_index = kWrite ? params.indices[batch_id] : -1; // -1 in the compress pass (never stored)
+
+ PDLWaitPrimary();
+
+ // 2 * 8 = 16 register per elem. in theory we should consume 48 register here
+ PrefillStorage kv[kElementsPerWarp];
+ PrefillStorage score[kElementsPerWarp];
+ PrefillStorage bias[kElementsPerWarp];
+ const auto warp_offset = warp_id * kElementsPerWarp;
+
+#pragma unroll
+ for (uint32_t i = 0; i < kElementsPerWarp; ++i) {
+ const uint32_t j = i + warp_offset;
+ if (j >= chunk_offset && j < window_end) {
+ const auto kv_src_ptr = kv_score_input + (segment_start + j) * kElementSize + split_offset;
+ const auto score_src_ptr = kv_src_ptr + kHeadDim;
+ const auto bias_src_ptr = score_bias_base + j * kHeadDim + split_offset;
+ kv[i].load(kv_src_ptr, lane_id);
+ score[i].load(score_src_ptr, lane_id);
+ bias[i].load(bias_src_ptr, lane_id);
+ }
+ }
+
+#pragma unroll
+ for (uint32_t i = 0; i < kElementsPerWarp; ++i) {
+ const uint32_t j = i + warp_offset;
+ const bool is_valid = (j >= chunk_offset && j < window_end);
+#pragma unroll
+ for (uint32_t ii = 0; ii < kTileElements; ++ii) {
+ score[i][ii] = is_valid ? score[i][ii] + bias[i][ii] : kPadScore;
+ /// NOTE: must zero out kv on padded slots -- `c128_prefill_forward`
+ /// computes `kv * exp_score` where `exp_score = expf(-FLT_MAX - max) ~= 0`,
+ /// and IEEE-754 makes `NaN * 0 = NaN` / `+-inf * 0 = NaN`. An
+ /// uninitialized register can hold a NaN/inf bit pattern, so without
+ /// this reset a single padded warp can poison the whole softmax.
+ kv[i][ii] = is_valid ? kv[i][ii] : 0.0f;
+ }
+ }
+
+ __shared__ alignas(16) float seg_kv[kTileDim];
+ __shared__ alignas(16) float seg_max[kTileDim];
+ __shared__ alignas(16) float seg_sum[kTileDim];
+
+ c128_prefill_forward(kv, score, seg_kv, seg_max, seg_sum, warp_id, lane_id);
+
+ PDLTriggerSecondary();
+
+ if (warp_id == 0) { // only warp 0 merges with the buffer and emits the result
+ PrefillStorage out_kv_vec, out_max_vec, out_sum_vec;
+ out_kv_vec.load(seg_kv, lane_id);
+ out_max_vec.load(seg_max, lane_id);
+ out_sum_vec.load(seg_sum, lane_id);
+ if (chunk_offset != 0) {
+ /// NOTE: load (max, sum, kv) of the in-progress chunk for this index.
+ /// `load_indices` may differ from `indices` when the prior partial state
+ /// lives on a different slot than the slot we ultimately write to.
+ const auto buf_load = kv_score_buffer + load_index * (kHeadDim * 3) + split_offset;
+ PrefillStorage buf_max_vec, buf_sum_vec, buf_kv_vec;
+ buf_max_vec.load(buf_load + 0 * kHeadDim, lane_id);
+ buf_sum_vec.load(buf_load + 1 * kHeadDim, lane_id);
+ buf_kv_vec.load(buf_load + 2 * kHeadDim, lane_id);
+#pragma unroll
+ for (uint32_t ii = 0; ii < kTileElements; ++ii) {
+ const float m1 = buf_max_vec[ii];
+ const float s1 = buf_sum_vec[ii];
+ const float k1 = buf_kv_vec[ii];
+ const float m2 = out_max_vec[ii];
+ const float s2 = out_sum_vec[ii];
+ const float k2 = out_kv_vec[ii];
+ const float new_max = fmaxf(m1, m2);
+ const float new_s1 = s1 * expf(m1 - new_max);
+ const float new_s2 = s2 * expf(m2 - new_max);
+ const float new_sum = new_s1 + new_s2;
+ const float new_kv = (k1 * new_s1 + k2 * new_s2) / new_sum;
+ out_max_vec[ii] = new_max;
+ out_sum_vec[ii] = new_sum;
+ out_kv_vec[ii] = new_kv;
+ }
+ }
+
+ if constexpr (kWrite) {
+ const auto buf_store = kv_score_buffer + store_index * (kHeadDim * 3) + split_offset;
+ reinterpret_cast(buf_store + 0 * kHeadDim)[lane_id] = out_max_vec;
+ reinterpret_cast(buf_store + 1 * kHeadDim)[lane_id] = out_sum_vec;
+ reinterpret_cast(buf_store + 2 * kHeadDim)[lane_id] = out_kv_vec;
+ } else {
+ const auto out_ptr = kv_compressed_output + ragged_id * kHeadDim + split_offset;
+ reinterpret_cast(out_ptr)[lane_id] = out_kv_vec;
+ }
+ }
+}
+
+template
+struct FlashCompress128OnlineKernel {
+ static constexpr auto decode_kernel = flash_c128_online_decode;
+ template
+ static constexpr auto prefill_kernel = flash_c128_online_prefill;
+ static constexpr auto prefill_c_kernel = prefill_kernel*kWrite=*/false>; // pass 1: compress
+ static constexpr auto prefill_w_kernel = prefill_kernel*kWrite=*/true>; // pass 2: write
+ static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64
+ static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; // blocks launched per plan entry
+ static constexpr uint32_t kDecodeBlockSize = kHeadDim / 4;
+
+ static void run_decode(
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView indices,
+ const tvm::ffi::TensorView seq_lens,
+ const tvm::ffi::Optional /* UNUSED */) {
+ using namespace host;
+
+ auto B = SymbolicSize{"batch_size"};
+ auto device = SymbolicDevice{};
+ device.set_options();
+
+ TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv)
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_score_buffer);
+ TensorMatcher({B, kHeadDim * 2}) // kv score input
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_score_input);
+ TensorMatcher({B, kHeadDim}) // kv compressed output
+ .with_dtype()
+ .with_device(device)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device)
+ .verify(ape);
+ TensorMatcher({B}).with_dtype().with_device(device).verify(indices);
+ TensorMatcher({B}).with_dtype().with_device(device).verify(seq_lens);
+
+ const auto batch_size = static_cast(B.unwrap());
+ const auto params = Compress128OnlineDecodeParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(), // `ape` is consumed by the kernel as the score bias
+ .indices = static_cast(indices.data_ptr()),
+ .seq_lens = static_cast(seq_lens.data_ptr()),
+ .batch_size = batch_size,
+ };
+ LaunchKernel(batch_size, kDecodeBlockSize, device.unwrap()) //
+ .enable_pdl(kUsePDL)(decode_kernel, params);
+ }
+
+ static void run_prefill(
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView indices,
+ const tvm::ffi::TensorView compress_plan,
+ const tvm::ffi::TensorView write_plan,
+ const tvm::ffi::Optional extra) {
+ using namespace host;
+ using host::compress::kOnlinePrefillPlanDim;
+ using host::compress::OnlinePrefillPlanTensorDtype;
+
+ auto B = SymbolicSize{"batch_size"};
+ auto N = SymbolicSize{"num_q_tokens"};
+ auto X = SymbolicSize{"compress_tokens"};
+ auto Y = SymbolicSize{"write_tokens"};
+ auto device_ = SymbolicDevice{};
+ device_.set_options();
+
+ TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv); NOTE(review): note said 2D, matcher is 3-D
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_buffer);
+ TensorMatcher({N, kHeadDim * 2}) // kv score input
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_input);
+ TensorMatcher({N, kHeadDim}) // kv compressed output
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device_)
+ .verify(ape);
+ TensorMatcher({B}) // indices
+ .with_dtype()
+ .with_device(device_)
+ .verify(indices);
+ TensorMatcher({X, kOnlinePrefillPlanDim}) // compress plan
+ .with_dtype()
+ .with_device(device_)
+ .verify(compress_plan);
+ TensorMatcher({Y, kOnlinePrefillPlanDim}) // write plan
+ .with_dtype()
+ .with_device(device_)
+ .verify(write_plan);
+
+ /// NOTE: `extra` is `load_indices`. When the previous partial state lives
+ /// on a slot different from the destination slot (e.g. paged buffers), the
+ /// caller must supply this; otherwise it defaults to `indices`.
+ const auto load_indices = extra.value_or(indices);
+ TensorMatcher({B}).with_dtype().with_device(device_).verify(load_indices);
+
+ const auto device = device_.unwrap();
+ const auto num_c = static_cast(X.unwrap());
+ const auto num_w = static_cast(Y.unwrap());
+ const auto params = Compress128OnlinePrefillParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(),
+ .indices = static_cast(indices.data_ptr()),
+ .load_indices = static_cast(load_indices.data_ptr()),
+ .compress_plan = static_cast(compress_plan.data_ptr()),
+ .write_plan = static_cast(write_plan.data_ptr()),
+ .num_compress = num_c,
+ .num_write = num_w,
+ };
+
+ /// NOTE: pass 1 reads the buffer (for the first segment of each batch
+ /// that started mid-chunk) and writes only to `kv_compressed_output`.
+ /// Pass 2 then writes the trailing partial state of each batch back to
+ /// the buffer. Stream serialization between the two launches enforces
+ /// read-before-write on shared buffer slots.
+ if (const auto num_c_blocks = num_c * kNumSplit) { // no launch when the compress plan is empty
+ LaunchKernel(num_c_blocks, kPrefillBlockSize, device) //
+ .enable_pdl(kUsePDL)(prefill_c_kernel, params);
+ }
+ if (const auto num_w_blocks = num_w * kNumSplit) { // no launch when the write plan is empty
+ LaunchKernel(num_w_blocks, kPrefillBlockSize, device) //
+ .enable_pdl(kUsePDL)(prefill_w_kernel, params);
+ }
+ }
+};
+
+} // namespace
+
+namespace host::compress {
+
+using OnlinePlanResult = tvm::ffi::Tuple; // (compress plan count, write plan count)
+
+struct OnlinePrefillCompressParams {
+ OnlinePrefillPlan* __restrict__ compress_plan; // out: segments that close a 128-chunk
+ OnlinePrefillPlan* __restrict__ write_plan; // out: trailing partial segments
+ const int64_t* __restrict__ seq_lens; // per-batch total length (prefix + extend)
+ const int64_t* __restrict__ extend_lens; // per-batch number of newly added tokens
+ uint32_t batch_size; // number of entries in seq_lens / extend_lens
+ uint32_t num_tokens; // expected sum of extend_lens; also the capacity of each plan
+};
+
+/// \brief Build the compress + write plans for online compress 128 prefill.
+///
+/// Each batch's `[prefix_len, prefix_len + extend_len)` range is split at
+/// 128-aligned boundaries. Every resulting segment falls into one of:
+/// - **compress**: closes a 128-chunk (`chunk_offset + window_len == 128`).
+/// These plans only read the buffer (when starting mid-chunk) and write the
+/// compressed kv to `kv_compressed_output`.
+/// - **write**: trailing partial of the batch (`chunk_offset + window_len < 128`).
+/// May read the buffer and always writes the new partial state back to it.
+/// Each batch produces at most one such plan.
+///
+/// The two plans MUST be dispatched as separate kernel launches in stream
+/// order so that pass-1 reads of a buffer slot complete before any pass-2
+/// write of the same slot.
+inline OnlinePlanResult plan_online_prefill_host(const OnlinePrefillCompressParams& params, const bool use_cuda_graph) {
+ const auto& [compress_plan, write_plan, seq_lens, extend_lens, batch_size, num_tokens] = params;
+
+ uint32_t counter = 0; // ragged offset: extend tokens consumed by earlier batches
+ uint32_t compress_count = 0;
+ uint32_t write_count = 0;
+ for (const auto i : irange(batch_size)) {
+ const uint32_t seq_len = static_cast(seq_lens[i]); // NOTE(review): narrows int64 input -- assumes it fits
+ const uint32_t extend_len = static_cast(extend_lens[i]);
+ RuntimeCheck(0 < extend_len && extend_len <= seq_len); // also keeps prefix_len below from underflowing
+ const uint32_t prefix_len = seq_len - extend_len;
+ const uint32_t end_pos = prefix_len + extend_len; // == seq_len
+ /// NOTE: split the extend range into per-128-chunk segments. Each segment
+ /// stays inside one chunk, so the kernel can decide load/store from
+ /// `chunk_offset` and `window_len` alone.
+ uint32_t pos = prefix_len;
+ while (pos < end_pos) {
+ const uint32_t chunk_start = (pos / 128u) * 128u;
+ const uint32_t seg_end = std::min(end_pos, chunk_start + 128u); // exclusive
+ const uint32_t seg_len = seg_end - pos;
+ const uint32_t chunk_off = pos - chunk_start;
+ /// NOTE: store last-token coordinates so that downstream consumers
+ /// (e.g. `fused_norm_rope`) can read `ragged_id` and `position` with the
+ /// same semantics as `PrefillPlan`. The segment start is recoverable as
+ /// `ragged_id - window_len + 1` and `position - window_len + 1`.
+ const uint32_t last_pos = seg_end - 1;
+ const uint32_t last_ragged = counter + (last_pos - prefix_len);
+ const auto plan = OnlinePrefillPlan{
+ .ragged_id = last_ragged,
+ .batch_id = i,
+ .position = last_pos,
+ .window_len = seg_len,
+ };
+ if (chunk_off + seg_len == 128u) {
+ // full chunk, must be complete, maybe read the buffer, no write
+ RuntimeCheck(compress_count < num_tokens);
+ compress_plan[compress_count++] = plan;
+ } else {
+ // last chunk, must be incomplete, maybe read the buffer, must write
+ RuntimeCheck(write_count < num_tokens);
+ write_plan[write_count++] = plan;
+ }
+ pos = seg_end;
+ }
+ counter += extend_len;
+ }
+ RuntimeCheck(counter == num_tokens, "input size ", counter, " != num_q_tokens ", num_tokens);
+ if (!use_cuda_graph) return OnlinePlanResult{compress_count, write_count}; // eager: real counts
+ /// NOTE: pad both plans with sentinel entries so cuda-graph runs always see
+ /// the same number of blocks. The kernel skips plans whose `ragged_id` is -1.
+ constexpr auto kInvalid = static_cast(-1);
+ constexpr auto kInvalidPlan = OnlinePrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid};
+ for (const auto i : irange(compress_count, num_tokens)) {
+ compress_plan[i] = kInvalidPlan;
+ }
+ for (const auto i : irange(write_count, num_tokens)) {
+ write_plan[i] = kInvalidPlan;
+ }
+ return OnlinePlanResult{num_tokens, num_tokens}; // graph mode: both plans padded to capacity
+}
+
+inline OnlinePlanResult plan_online_prefill(
+ const tvm::ffi::TensorView extend_lens, // [batch_size]
+ const tvm::ffi::TensorView seq_lens, // [batch_size]
+ const tvm::ffi::TensorView compress_plan, // [num_tokens, kOnlinePrefillPlanDim], filled in place
+ const tvm::ffi::TensorView write_plan, // [num_tokens, kOnlinePrefillPlanDim], filled in place
+ const bool use_cuda_graph) {
+ auto N = SymbolicSize{"batch_size"};
+ auto M = SymbolicSize{"num_tokens"};
+ auto device = SymbolicDevice{};
+ /// NOTE: only host (CPU/cuda-host) planning is implemented for now.
+ device.set_options();
+ TensorMatcher({N}) //
+ .with_dtype()
+ .with_device(device)
+ .verify(extend_lens)
+ .verify(seq_lens);
+ TensorMatcher({M, kOnlinePrefillPlanDim}) //
+ .with_dtype()
+ .with_device(device)
+ .verify(compress_plan)
+ .verify(write_plan);
+ const auto params = OnlinePrefillCompressParams{
+ .compress_plan = static_cast(compress_plan.data_ptr()),
+ .write_plan = static_cast(write_plan.data_ptr()),
+ .seq_lens = static_cast(seq_lens.data_ptr()),
+ .extend_lens = static_cast(extend_lens.data_ptr()),
+ .batch_size = static_cast(N.unwrap()),
+ .num_tokens = static_cast(M.unwrap()),
+ };
+ return plan_online_prefill_host(params, use_cuda_graph); // validation done; planning runs on the host
+}
+
+} // namespace host::compress
+
+namespace {
+
+[[maybe_unused]]
+constexpr auto& plan_compress_online_prefill = host::compress::plan_online_prefill; // file-local alias of the host planner
+
+} // namespace
diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh
new file mode 100644
index 0000000000..71e600dc39
--- /dev/null
+++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh
@@ -0,0 +1,875 @@
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace {
+
+using PlanD = device::compress::DecodePlan;
+using PlanC = device::compress::CompressPlan;
+
+// ---------------------------------------------------------------------------
+// Decode kernel: 1 token / batch. Each block handles one batch.
+// 4 elements per thread -> kBlockSize = head_dim / 4.
+// ---------------------------------------------------------------------------
+
+struct Compress128OnlineDecodeParams {
+ void* __restrict__ kv_score_buffer; // [num_slots, 1, head_dim * 3]
+ const void* __restrict__ kv_score_input; // [batch_size, head_dim * 2]
+ void* __restrict__ kv_compressed_output; // [batch_size, head_dim]
+ const void* __restrict__ score_bias; // [128, head_dim]
+ const PlanD* __restrict__ plan_d; // one decode plan per batch entry, indexed by blockIdx.x
+ uint32_t batch_size; // blocks with blockIdx.x >= batch_size exit immediately
+};
+
+template
+__global__ void flash_c128_online_decode_v2(const __grid_constant__ Compress128OnlineDecodeParams params) {
+ using namespace device;
+ constexpr uint32_t kVecSize = 4;
+ constexpr uint32_t kBlockSize = kHeadDim / kVecSize;
+ using Vec = AlignedVector;
+ const auto gmem = tile::Memory::cta(kBlockSize);
+ const auto batch_id = blockIdx.x; // one block per batch entry
+ if (batch_id >= params.batch_size) return;
+
+ // Wait for the plan-finalize kernel to publish `plan.read_page_0 / write_loc`
+ // before reading the plan. The plan kernel runs on the same stream and does
+ // NOT issue a PDL trigger, so launching this kernel with PDL means our
+ // pre-wait global reads can race with the plan kernel's writes.
+ PDLWaitPrimary();
+
+ const auto plan = params.plan_d[batch_id];
+ const auto pos_in_chunk = (plan.seq_len - 1) % 128; // 0-based slot of the new token inside its 128-chunk
+
+ const auto kv_score_buffer = static_cast(params.kv_score_buffer);
+ const auto kv_score_input = static_cast(params.kv_score_input);
+ const auto kv_load_buf = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3);
+ const auto kv_store_buf = kv_score_buffer + plan.write_loc * (kHeadDim * 3);
+ const auto kv_src = kv_score_input + batch_id * (kHeadDim * 2);
+
+ // Buffer layout: [max | sum | kv] (slot 0 / 1 / 2 of the head_dim*3 row).
+ const auto new_kv_vec = gmem.load(kv_src, 0); // tile 0 of the input row; presumably the kv half -- TODO confirm
+ const auto new_score_raw_vec = gmem.load(kv_src, 1); // tile 1; presumably the score half -- TODO confirm
+ const auto bias_vec = gmem.load(params.score_bias, pos_in_chunk); // bias tile picked by in-chunk position
+
+ Vec out_kv_vec;
+ Vec out_max_vec;
+ Vec out_sum_vec;
+ if (pos_in_chunk != 0) {
+ // Mid-chunk: combine prior partial state with the new token.
+ const auto max_score_vec = gmem.load(kv_load_buf, 0);
+ const auto sum_score_vec = gmem.load(kv_load_buf, 1);
+ const auto old_kv_vec = gmem.load(kv_load_buf, 2);
+#pragma unroll
+ for (uint32_t i = 0; i < kVecSize; ++i) {
+ const auto old_max = max_score_vec[i];
+ const auto old_kv = old_kv_vec[i];
+ const auto new_score = new_score_raw_vec[i] + bias_vec[i];
+ const auto new_kv = new_kv_vec[i];
+ const auto new_max = fmaxf(old_max, new_score);
+ const auto old_sum = sum_score_vec[i] * expf(old_max - new_max); // prior sum rescaled to the new max
+ const auto new_exp = expf(new_score - new_max);
+ const auto new_sum = old_sum + new_exp;
+ out_kv_vec[i] = (old_kv * old_sum + new_kv * new_exp) / new_sum; // stored kv stays normalized
+ out_max_vec[i] = new_max;
+ out_sum_vec[i] = new_sum;
+ }
+ } else {
+ // First token of a new chunk: state == this token alone.
+#pragma unroll
+ for (uint32_t i = 0; i < kVecSize; ++i) {
+ out_kv_vec[i] = new_kv_vec[i];
+ out_max_vec[i] = new_score_raw_vec[i] + bias_vec[i];
+ out_sum_vec[i] = 1.0f; // exp(score - max) with max == score
+ }
+ }
+
+ if (pos_in_chunk == 127) {
+ // Chunk just closed: emit compressed kv, no buffer update.
+ const auto kv_out = static_cast(params.kv_compressed_output) + batch_id * kHeadDim;
+ gmem.store(kv_out, out_kv_vec);
+ } else {
+ gmem.store(kv_store_buf, out_max_vec, 0); // chunk still open: persist (max, sum, kv) at write_loc
+ gmem.store(kv_store_buf, out_sum_vec, 1);
+ gmem.store(kv_store_buf, out_kv_vec, 2);
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Prefill kernel: 1 segment / block. Two passes (compress + write) share the
+// kernel template, parameterized by `kWrite`.
+// 16 warps per block; each warp handles 8 of the 128 chunk positions.
+// ---------------------------------------------------------------------------
+
+constexpr int32_t kTileElements = 2; // split along head-dim
+constexpr int32_t kElementsPerWarp = 8; // split along the 128-chunk
+constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; // 16
+constexpr uint32_t kPrefillBlockSize = device::kWarpThreads * kNumWarps; // threads per prefill block (16 warps)
+using PrefillStorage = device::AlignedVector;
+
+struct Compress128OnlinePrefillParams {
+ void* __restrict__ kv_score_buffer; // [num_slots, 1, head_dim * 3]
+ const void* __restrict__ kv_score_input; // [num_q_tokens, head_dim * 2]
+ void* __restrict__ kv_compressed_output; // [num_compress, head_dim]
+ const void* __restrict__ score_bias; // [128, head_dim]
+ const PlanC* __restrict__ plan_c; // close-chunk segments
+ const PlanC* __restrict__ plan_w; // trailing partial segments
+ uint32_t num_compress; // number of entries in plan_c
+ uint32_t num_write; // number of entries in plan_w
+};
+
+struct Compress128SharedBuffer {
+ using Storage = device::AlignedVector;
+ Storage data[kNumWarps][device::kWarpThreads + 1]; // +1 to avoid bank conflict
+ SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { // whole vector of one (warp, lane)
+ return data[warp_id][lane_id];
+ }
+ SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { // one float of that vector
+ return data[warp_id][lane_id][tile_id];
+ }
+};
+
+/// \brief Sentinel score for padded positions in a 128-segment.
+constexpr float kPadScore = -FLT_MAX; // finite (not -inf), so `score - max` cannot become NaN
+
+[[maybe_unused]]
+SGL_DEVICE void c128_prefill_segment_softmax(
+ const PrefillStorage (&kv)[kElementsPerWarp], // kv values of this warp's chunk positions
+ const PrefillStorage (&score)[kElementsPerWarp], // matching scores (bias already added by the caller)
+ float* seg_kv, // out: normalized softmax-weighted kv per tile slot
+ float* seg_max, // out: max score per tile slot
+ float* seg_sum, // out: sum of exp(score - max) per tile slot
+ const uint32_t warp_id,
+ const uint32_t lane_id) {
+ using namespace device;
+
+ // Per-warp running state (max, sum, kv) for kTileElements head-dim slots.
+ using TmpStorage = typename Compress128SharedBuffer::Storage;
+ __shared__ Compress128SharedBuffer s_local_val_max;
+ __shared__ Compress128SharedBuffer s_local_exp_sum;
+ __shared__ Compress128SharedBuffer s_local_product;
+
+ TmpStorage tmp_val_max;
+ TmpStorage tmp_exp_sum;
+ TmpStorage tmp_product;
+
+#pragma unroll
+ for (int32_t i = 0; i < kTileElements; ++i) {
+ float score_fp32[kElementsPerWarp];
+#pragma unroll
+ for (int32_t j = 0; j < kElementsPerWarp; ++j) {
+ score_fp32[j] = score[j][i];
+ }
+ float max_value = score_fp32[0];
+#pragma unroll
+ for (int32_t j = 1; j < kElementsPerWarp; ++j) {
+ max_value = fmaxf(max_value, score_fp32[j]);
+ }
+ float sum_exp_value = 0.0f;
+ float sum_product = 0.0f;
+#pragma unroll
+ for (int32_t j = 0; j < kElementsPerWarp; ++j) {
+ const auto exp_score = expf(score_fp32[j] - max_value);
+ sum_product += kv[j][i] * exp_score; // unnormalized; divided by the global sum further down
+ sum_exp_value += exp_score;
+ }
+ tmp_val_max[i] = max_value;
+ tmp_exp_sum[i] = sum_exp_value;
+ tmp_product[i] = sum_product;
+ }
+
+ // Aligned writes (no bank conflict thanks to `+1` padding).
+ s_local_val_max(warp_id, lane_id) = tmp_val_max;
+ s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum;
+ s_local_product(warp_id, lane_id) = tmp_product;
+
+ __syncthreads(); // all per-warp partials must be in shared memory before the cross-warp pass
+
+ // Cross-warp reduction. Same recipe as c128_online.cuh: each block-thread
+ // pair reduces a (tile_id, lane_id) slot using a kNumWarps-wide warp shuffle.
+ constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps;
+ constexpr uint32_t kIteration = kReductionCount / kPrefillBlockSize; // == kTileElements
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs");
+
+#pragma unroll
+ for (uint32_t i = 0; i < kIteration; ++i) {
+ const uint32_t j = i * kPrefillBlockSize + warp_id * kWarpThreads + lane_id;
+ const uint32_t local_warp_id = j % kNumWarps; // which warp's partial this thread picks up
+ const uint32_t local_elem_id = j / kNumWarps; // output slot in [0, kTileElements * kWarpThreads)
+ const uint32_t local_tile_id = local_elem_id % kTileElements;
+ const uint32_t local_lane_id = local_elem_id / kTileElements;
+ const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id);
+ const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id);
+ const auto global_val_max = warp::reduce_max(local_val_max);
+ const auto rescale = expf(local_val_max - global_val_max); // bring this partial onto the global max
+ const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale);
+ const auto final_scale = rescale / global_exp_sum;
+ const auto global_product = warp::reduce_sum(local_product * final_scale);
+ seg_kv[local_elem_id] = global_product;
+ seg_max[local_elem_id] = global_val_max;
+ seg_sum[local_elem_id] = global_exp_sum;
+ }
+ __syncthreads(); // seg_* must be complete before any thread reads them back
+}
+
+/// \brief Online compress 128 prefill v2.
+///
+/// `kWrite=false` (compress pass): handles segments that close a 128-chunk.
+/// Reads optional prior state from `read_page_0` (-1 = none), emits compressed
+/// kv to `kv_compressed_output[plan_id]` (compact).
+/// `kWrite=true` (write pass) : handles trailing partial segments.
+/// Reads optional prior state from `read_page_0` (-1 = none), writes new
+/// running state back to the same `read_page_0` slot.
+template
+__global__ __launch_bounds__(kPrefillBlockSize, 2) //
+ void flash_c128_online_prefill_v2(const __grid_constant__ Compress128OnlinePrefillParams params) {
+ using namespace device;
+
+ constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64
+ constexpr uint32_t kNumSplit = kHeadDim / kTileDim;
+ static_assert(kHeadDim % kTileDim == 0);
+
+ // Compile-time fold to the right plan list.
+ const auto num_plans = kWrite ? params.num_write : params.num_compress;
+ const auto plan_ptr = kWrite ? params.plan_w : params.plan_c;
+ const uint32_t global_id = blockIdx.x;
+ const uint32_t global_pid = global_id / kNumSplit; // plan id
+ const uint32_t global_sid = global_id % kNumSplit; // head-dim split handled by this block
+ if (global_pid >= num_plans) return;
+
+ const uint32_t warp_id = threadIdx.x / kWarpThreads;
+ const uint32_t lane_id = threadIdx.x % kWarpThreads;
+ const int32_t split_offset = global_sid * kTileDim;
+
+ // The previous kernel (plan-finalize stage 1) does NOT issue a PDL trigger,
+ // so PDLWaitPrimary effectively waits for stage 1 to complete. Read the plan
+ // AFTER the wait so the freshly-written `read_page_0` (= state-pool slot) is
+ // visible. Reading it before the wait is a real race -- with PDL enabled the
+ // kernel can begin executing before stage 1's stores propagate, and we'd see
+ // the stage-0 batch_id placeholder in `read_page_0` instead of the slot.
+ PDLWaitPrimary();
+
+ const auto plan = plan_ptr[global_pid];
+ if (plan.is_invalid()) [[unlikely]]
+ return;
+
+ const auto kv_score_buffer = static_cast(params.kv_score_buffer);
+ const auto kv_score_input = static_cast(params.kv_score_input);
+ const auto kv_compressed_output = static_cast(params.kv_compressed_output);
+ const auto score_bias_base = static_cast(params.score_bias);
+
+ constexpr int64_t kElementSize = kHeadDim * 2; // | kv | score |
+
+ // The plan stores last-token coordinates; segment start is recoverable as
+ // ragged_id - window_len + 1.
+ const uint32_t window_len = plan.buffer_len;
+ const uint32_t position = plan.seq_len - 1;
+ const uint32_t pos_in_chunk_end = (position % 128u) + 1u; // exclusive, in [1, 128]
+ const uint32_t chunk_offset = pos_in_chunk_end - window_len; // in [0, 127]
+ const int32_t segment_start_ragged = static_cast(plan.ragged_id) - static_cast(position % 128u);
+
+ // --- Stage 1: load kv / score / bias for this warp's 8 chunk positions.
+ PrefillStorage kv[kElementsPerWarp];
+ PrefillStorage score[kElementsPerWarp];
+ PrefillStorage bias[kElementsPerWarp];
+ const uint32_t warp_offset = warp_id * kElementsPerWarp;
+
+#pragma unroll
+ for (uint32_t i = 0; i < kElementsPerWarp; ++i) {
+ const uint32_t j = i + warp_offset;
+ if (j >= chunk_offset && j < pos_in_chunk_end) {
+ const auto kv_src_ptr = kv_score_input + (segment_start_ragged + j) * kElementSize + split_offset;
+ const auto score_src_ptr = kv_src_ptr + kHeadDim;
+ const auto bias_src_ptr = score_bias_base + j * kHeadDim + split_offset;
+ kv[i].load(kv_src_ptr, lane_id);
+ score[i].load(score_src_ptr, lane_id);
+ bias[i].load(bias_src_ptr, lane_id);
+ }
+ }
+
+ // --- Stage 2: pad invalid positions. score = -FLT_MAX, kv = 0 (so that
+ // kv * exp(score-max) is exactly 0 for padded slots, never NaN/inf).
+#pragma unroll
+ for (uint32_t i = 0; i < kElementsPerWarp; ++i) {
+ const uint32_t j = i + warp_offset;
+ const bool is_valid = (j >= chunk_offset && j < pos_in_chunk_end);
+#pragma unroll
+ for (uint32_t ii = 0; ii < kTileElements; ++ii) {
+ score[i][ii] = is_valid ? score[i][ii] + bias[i][ii] : kPadScore;
+ kv[i][ii] = is_valid ? kv[i][ii] : 0.0f;
+ }
+ }
+
+ // --- Stage 3: warp-tile online softmax over the 128-position chunk.
+ __shared__ alignas(16) float seg_kv[kTileDim];
+ __shared__ alignas(16) float seg_max[kTileDim];
+ __shared__ alignas(16) float seg_sum[kTileDim];
+ c128_prefill_segment_softmax(kv, score, seg_kv, seg_max, seg_sum, warp_id, lane_id);
+
+ PDLTriggerSecondary();
+
+ // --- Stage 4: warp 0 folds with prior partial state (if any) and writes.
+ if (warp_id == 0) {
+ PrefillStorage out_kv_vec, out_max_vec, out_sum_vec;
+ out_kv_vec.load(seg_kv, lane_id);
+ out_max_vec.load(seg_max, lane_id);
+ out_sum_vec.load(seg_sum, lane_id);
+
+ if (chunk_offset != 0 && plan.read_page_0 >= 0) {
+ // Combine with prior partial state for this slot.
+ const auto buf_load = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3) + split_offset;
+ PrefillStorage buf_max_vec, buf_sum_vec, buf_kv_vec;
+ buf_max_vec.load(buf_load + 0 * kHeadDim, lane_id);
+ buf_sum_vec.load(buf_load + 1 * kHeadDim, lane_id);
+ buf_kv_vec.load(buf_load + 2 * kHeadDim, lane_id);
+#pragma unroll
+ for (uint32_t ii = 0; ii < kTileElements; ++ii) {
+ const float m1 = buf_max_vec[ii];
+ const float s1 = buf_sum_vec[ii];
+ const float k1 = buf_kv_vec[ii];
+ const float m2 = out_max_vec[ii];
+ const float s2 = out_sum_vec[ii];
+ const float k2 = out_kv_vec[ii];
+ const float new_max = fmaxf(m1, m2);
+ const float new_s1 = s1 * expf(m1 - new_max);
+ const float new_s2 = s2 * expf(m2 - new_max);
+ const float new_sum = new_s1 + new_s2;
+ const float new_kv = (k1 * new_s1 + k2 * new_s2) / new_sum;
+ out_max_vec[ii] = new_max;
+ out_sum_vec[ii] = new_sum;
+ out_kv_vec[ii] = new_kv;
+ }
+ }
+
+ if constexpr (kWrite) {
+ // Trailing-partial segments store to their own chunk slot, `read_page_0`
+ // (a request keeps one in-progress chunk state at any time). NOTE(review):
+ // unlike the load above, this store is not guarded by `read_page_0 >= 0`.
+ const auto buf_store = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3) + split_offset;
+ reinterpret_cast(buf_store + 0 * kHeadDim)[lane_id] = out_max_vec;
+ reinterpret_cast(buf_store + 1 * kHeadDim)[lane_id] = out_sum_vec;
+ reinterpret_cast(buf_store + 2 * kHeadDim)[lane_id] = out_kv_vec;
+ } else {
+ // Compact output: one row per compress plan, indexed by `global_pid`.
+ const auto out_ptr = kv_compressed_output + global_pid * kHeadDim + split_offset;
+ reinterpret_cast(out_ptr)[lane_id] = out_kv_vec;
+ }
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Host wrapper: matches the c128_v2 / c4_v2 host API style (run_decode /
+// run_prefill methods on a kernel-class template). We only expose `kHeadDim`
+// + `kUsePDL`; the dtype is fixed to fp32 for the online state pool.
+// ---------------------------------------------------------------------------
+
+template
+struct FlashCompress128OnlineKernel {  // host-side launcher: verifies the tensor arguments, then launches the decode / prefill kernels
+ static constexpr auto decode_kernel = flash_c128_online_decode_v2;
+ template
+ static constexpr auto prefill_kernel = flash_c128_online_prefill_v2;
+ static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64
+ static constexpr uint32_t kNumSplit = kHeadDim / kTileDim;  // prefill launches kNumSplit blocks per plan entry
+ static constexpr uint32_t kDecodeBlockSize = kHeadDim / 4;  // passed as the block-size argument of the decode launch
+
+ static void run_decode(  // verify shapes/dtypes/devices, then launch one decode block per batch entry
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView plan_d_) {
+ using namespace host;
+
+ auto B = SymbolicSize{"batch_size"};  // shared by the input, output and plan checks below
+ auto device_ = SymbolicDevice{};  // every tensor is matched against this one symbolic device
+ device_.set_options();
+
+ TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv)
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_buffer);
+ TensorMatcher({B, kHeadDim * 2}) // kv score input
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_input);
+ TensorMatcher({B, kHeadDim}) // kv compressed output (sparse by batch_id)
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device_)
+ .verify(ape);
+
+ const auto plan_d = compress::verify_plan_d(plan_d_, B, device_);  // plan is checked against the same batch size and device
+ const auto batch_size = static_cast(B.unwrap());
+ if (batch_size == 0) return;  // empty batch: nothing to launch
+ const auto params = Compress128OnlineDecodeParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(),  // `ape` is handed to the kernel as its score bias
+ .plan_d = plan_d,
+ .batch_size = batch_size,
+ };
+ LaunchKernel(batch_size, kDecodeBlockSize, device_.unwrap()) //
+ .enable_pdl(kUsePDL)(decode_kernel, params);
+ }
+
+ static void run_prefill(  // verify the arguments, then launch the compress pass followed by the write pass
+ const tvm::ffi::TensorView kv_score_buffer,
+ const tvm::ffi::TensorView kv_score_input,
+ const tvm::ffi::TensorView kv_compressed_output,
+ const tvm::ffi::TensorView ape,
+ const tvm::ffi::TensorView plan_c_,
+ const tvm::ffi::TensorView plan_w_) {
+ using namespace host;
+
+ auto N = SymbolicSize{"num_q_tokens"};  // rows of the ragged kv score input
+ auto C = SymbolicSize{"num_c_plans"};  // rows of plan_c and of the compact output
+ auto W = SymbolicSize{"num_w_plans"};  // rows of plan_w
+ auto device_ = SymbolicDevice{};
+ device_.set_options();
+
+ TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_buffer);
+ TensorMatcher({N, kHeadDim * 2}) // kv score input (ragged)
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_score_input);
+ TensorMatcher({C, kHeadDim}) // kv compressed output (compact, by plan_c index)
+ .with_dtype()
+ .with_device(device_)
+ .verify(kv_compressed_output);
+ TensorMatcher({128, kHeadDim}) // ape
+ .with_dtype()
+ .with_device(device_)
+ .verify(ape);
+
+ // Both compress and write segments use the PlanC layout, hence verify_plan_c for both.
+ // NOTE(review): `read_page_1` appears unused by either pass (the write pass stores via `read_page_0`) -- confirm.
+ const auto plan_c = compress::verify_plan_c(plan_c_, C, device_);
+ const auto plan_w = compress::verify_plan_c(plan_w_, W, device_);
+ const auto device = device_.unwrap();
+ const auto num_q_tokens = static_cast(N.unwrap());
+ const auto num_c = static_cast(C.unwrap());
+ const auto num_w = static_cast(W.unwrap());
+ RuntimeCheck(num_q_tokens >= num_w, "invalid prefill plan: num_q < num_w");
+ const auto params = Compress128OnlinePrefillParams{
+ .kv_score_buffer = kv_score_buffer.data_ptr(),
+ .kv_score_input = kv_score_input.data_ptr(),
+ .kv_compressed_output = kv_compressed_output.data_ptr(),
+ .score_bias = ape.data_ptr(),
+ .plan_c = plan_c,
+ .plan_w = plan_w,
+ .num_compress = num_c,
+ .num_write = num_w,
+ };
+
+ // The two passes MUST be serialized in stream order: pass 1 reads slots
+ // that pass 2 may write to; running them in parallel would race.
+ if (const auto num_c_blocks = num_c * kNumSplit) {  // launch is skipped when there are no compress plans
+ LaunchKernel(num_c_blocks, kPrefillBlockSize, device) //
+ .enable_pdl(kUsePDL)(prefill_kernel*kWrite=*/false>, params);
+ }
+ if (const auto num_w_blocks = num_w * kNumSplit) {  // launch is skipped when there are no write plans
+ LaunchKernel(num_w_blocks, kPrefillBlockSize, device) //
+ .enable_pdl(kUsePDL)(prefill_kernel*kWrite=*/true>, params);
+ }
+ }
+};
+
+} // namespace
+
+// ===========================================================================
+// Plan builders. Mirrors the offline v2 pattern (`c_plan.cuh`):
+// - Decode: a single GPU kernel reads seq_lens / req_to_token /
+// req_pool_indices on device and emits the final PlanD tensor in one go.
+// - Prefill: stage 0 (host, on CPU pinned memory) splits each batch's
+// extend range into per-chunk segments and emits PlanC entries with the
+// batch_id stashed in `read_page_0` as a placeholder. Stage 1 is a tiny
+// GPU kernel that finalizes `read_page_0` to `req_to_token[rid][chunk_start]`,
+// so the slot tensors never leave GPU memory. The online state pool keeps
+// a single in-progress chunk per request, so each segment's load and
+// store slot collapse to one value (the slot for the segment's own chunk),
+// and `read_page_1` is unused.
+// ===========================================================================
+
+namespace host::compress {
+
+using device::compress::CompressPlan;  // element type of the prefill plan_c / plan_w arrays built below
+using device::compress::DecodePlan;  // element type of the decode plan_d array built below
+
+// ---------------------------------------------------------------------------
+// Decode plan builder.
+// ---------------------------------------------------------------------------
+
+struct OnlineDecodePlanParams {  // by-value argument bundle for plan_c128_online_decode_kernel
+ DecodePlan* __restrict__ plan_d;  // output: the kernel writes plan_d[idx] for each batch entry
+ const int64_t* __restrict__ seq_lens;  // read as seq_lens[idx]
+ const int64_t* __restrict__ req_pool_indices;  // batch entry -> row id (rid) into req_to_token
+ const int32_t* __restrict__ req_to_token;  // read as req_to_token[rid * stride_r2t + chunk_start]
+ const int64_t* __restrict__ full_to_swa; // (full_cache_size,) int64
+ int64_t stride_r2t;  // row stride of req_to_token (host passes req_to_token.stride(0))
+ int32_t swa_page_size;  // divisor that turns a swa location into a slot index
+ uint32_t batch_size;  // number of plans to emit; threads with idx >= batch_size exit
+};
+
+__global__ void plan_c128_online_decode_kernel(const OnlineDecodePlanParams params) {  // one thread fills one DecodePlan
+ const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx >= params.batch_size) return;  // surplus threads of the last block do nothing
+ const auto seq_len = static_cast(params.seq_lens[idx]);
+ const auto rid = params.req_pool_indices[idx];
+ const int32_t chunk_start = static_cast((seq_len - 1u) / 128u * 128u);  // last position (seq_len - 1) rounded down to a multiple of 128
+ const int32_t full_loc = params.req_to_token[rid * params.stride_r2t + chunk_start];  // entry for the chunk's first position
+ const int32_t swa_loc = static_cast(params.full_to_swa[full_loc]);  // translate through the full_to_swa table
+ const int32_t slot = swa_loc / params.swa_page_size;
+ params.plan_d[idx] = DecodePlan{
+ .seq_len = seq_len,
+ .write_loc = slot,  // write and read target the same slot
+ .read_page_0 = slot,
+ .read_page_1 = -1,  // no second read page is set here
+ };
+}
+
+/// \brief Build the decode plan tensor. Caller (Python) pre-allocates
+/// `plan_d_dev` as a `(batch_size, 16)` device uint8 tensor; this routine
+/// only fills it. See `plan_online_prefill` for the rationale (avoid
+/// `ffi::empty` + dlpack roundtrip / PyTorch caching-allocator stream
+/// tracking issue that surfaces as IMA in unrelated downstream kernels).
+inline void plan_online_decode(
+ const tvm::ffi::TensorView seq_lens,
+ const tvm::ffi::TensorView req_pool_indices,
+ const tvm::ffi::TensorView req_to_token,
+ const tvm::ffi::TensorView full_to_swa,
+ const tvm::ffi::TensorView plan_d_dev_,
+ const int32_t swa_page_size) {
+ auto B = SymbolicSize{"batch_size"};  // shared by seq_lens, req_pool_indices and plan_d_dev_
+ auto device_ = SymbolicDevice{};  // all five tensors are matched against this one device
+ device_.set_options();
+
+ auto seq_dtype = SymbolicDType{};  // NOTE(review): matched symbolically, yet the kernel params read seq_lens as int64_t -- confirm allowed dtypes
+ TensorMatcher({B}) //
+ .with_dtype(seq_dtype)
+ .with_device(device_)
+ .verify(seq_lens);
+ TensorMatcher({B}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(req_pool_indices);
+ TensorMatcher({-1, -1}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(req_to_token);
+ TensorMatcher({-1}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(full_to_swa);
+ TensorMatcher({B, sizeof(DecodePlan)}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(plan_d_dev_);
+ RuntimeCheck(swa_page_size > 0);  // the kernel divides by swa_page_size
+
+ const auto batch_size = static_cast(B.unwrap());
+ if (batch_size == 0) return;  // empty batch: nothing to launch
+
+ const auto device = device_.unwrap();
+ constexpr uint32_t kBlockSize = 256;
+ const uint32_t num_blocks = host::div_ceil(batch_size, kBlockSize);  // enough blocks for one thread per batch entry
+ const auto stride_r2t = req_to_token.stride(0);  // row stride, so req_to_token need not be densely packed along dim 0
+ const auto params = OnlineDecodePlanParams{
+ .plan_d = static_cast(plan_d_dev_.data_ptr()),
+ .seq_lens = static_cast(seq_lens.data_ptr()),
+ .req_pool_indices = static_cast(req_pool_indices.data_ptr()),
+ .req_to_token = static_cast(req_to_token.data_ptr()),
+ .full_to_swa = static_cast(full_to_swa.data_ptr()),
+ .stride_r2t = stride_r2t,
+ .swa_page_size = swa_page_size,
+ .batch_size = batch_size,
+ };
+ LaunchKernel(num_blocks, kBlockSize, device)(plan_c128_online_decode_kernel, params);
+}
+
+// ---------------------------------------------------------------------------
+// Prefill plan builder: host stage 0 + GPU stage 1.
+// ---------------------------------------------------------------------------
+
+struct OnlinePrefillStage0Params {  // inputs/outputs of the host-side pass _plan_prefill_partial
+ CompressPlan* __restrict__ plan_c;  // output: segments that end exactly on a 128-token chunk boundary
+ CompressPlan* __restrict__ plan_w;  // output: all other (trailing partial) segments
+ const int64_t* __restrict__ seq_lens;  // read as seq_lens[i] per request
+ const int64_t* __restrict__ extend_lens;  // read as extend_lens[i] per request
+ uint32_t batch_size;  // number of requests iterated
+ uint32_t num_q_tokens;  // must equal the sum of extend_lens; also the upper bound on each plan count
+};
+
+inline std::tuple _plan_prefill_partial(const OnlinePrefillStage0Params& p) {  // host stage 0: split each extend range into per-chunk segments; returns {compress_count, write_count}
+ uint32_t counter = 0;  // running sum of extend_len over the requests processed so far
+ uint32_t compress_count = 0;  // entries written to p.plan_c
+ uint32_t write_count = 0;  // entries written to p.plan_w
+ for (const auto i : irange(p.batch_size)) {
+ const uint32_t seq_len = static_cast(p.seq_lens[i]);
+ const uint32_t extend_len = static_cast(p.extend_lens[i]);
+ RuntimeCheck(0 < extend_len && extend_len <= seq_len);  // also keeps prefix_len below from underflowing
+ const uint32_t prefix_len = seq_len - extend_len;  // first position covered by this extend
+ const uint32_t end_pos = prefix_len + extend_len;  // exclusive end; equals seq_len
+
+ uint32_t pos = prefix_len;
+ while (pos < end_pos) {  // one iteration per segment; a segment never crosses a 128-token chunk boundary
+ const uint32_t chunk_start = (pos / 128u) * 128u;  // pos rounded down to a multiple of 128
+ const uint32_t seg_end = std::min(end_pos, chunk_start + 128u); // exclusive
+ const uint32_t seg_len = seg_end - pos;
+ const uint32_t chunk_off = pos - chunk_start;  // offset of the segment's first token within its chunk
+ const uint32_t last_pos = seg_end - 1;
+ const uint32_t last_ragged = counter + (last_pos - prefix_len);  // offset of the segment's last token across all requests' extend tokens
+ RuntimeCheck(last_ragged < (1u << 16), "PlanC.ragged_id is uint16; ragged ", last_ragged, " overflows");
+ RuntimeCheck(seg_len <= 128u);
+ // Stash batch_id in `read_page_0` for stage 1 to translate. A
+ // chunk-aligned segment never loads, so we still need stage 1 to fill
+ // a slot in -- the kernel keys the load on `chunk_offset != 0`.
+ const auto plan = CompressPlan{
+ .seq_len = last_pos + 1u,  // equals seg_end
+ .ragged_id = static_cast(last_ragged),
+ .buffer_len = static_cast(seg_len),  // number of tokens in this segment
+ .read_page_0 = static_cast(i), // batch_id placeholder
+ .read_page_1 = -1, // unused, kept so MSB layout is stable
+ };
+ if (chunk_off + seg_len == 128u) {  // the segment reaches the end of its chunk
+ // close-chunk segment
+ RuntimeCheck(compress_count < p.num_q_tokens);  // bounds check before the write into plan_c
+ p.plan_c[compress_count++] = plan;
+ } else {
+ // trailing partial segment
+ RuntimeCheck(write_count < p.num_q_tokens);  // bounds check before the write into plan_w
+ p.plan_w[write_count++] = plan;
+ }
+ pos = seg_end;
+ }
+ counter += extend_len;
+ }
+ RuntimeCheck(counter == p.num_q_tokens, "input size ", counter, " != num_q_tokens ", p.num_q_tokens);
+ return std::tuple{compress_count, write_count};
+}
+
+struct OnlinePrefillStage1Params {  // by-value argument bundle for plan_c128_online_prefill_kernel
+ CompressPlan* __restrict__ plan_c;  // updated in place: read_page_0 is rewritten from batch id to slot
+ CompressPlan* __restrict__ plan_w;  // updated in place, same as plan_c
+ const int64_t* __restrict__ req_pool_indices; // (batch_size,)
+ const int32_t* __restrict__ req_to_token; // (num_reqs, max_tokens)
+ const int64_t* __restrict__ full_to_swa; // (full_cache_size,)
+ int64_t stride_r2t;  // multiplied by rid to reach a row of req_to_token
+ int32_t swa_page_size;  // divisor that turns a swa location into a slot index
+ uint32_t num_c;  // number of plan_c entries to finalize
+ uint32_t num_w;  // number of plan_w entries to finalize
+};
+
+__global__ void plan_c128_online_prefill_kernel(const OnlinePrefillStage1Params params) {  // stage 1: one thread finalizes one plan entry in place
+ const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const uint32_t total = params.num_c + params.num_w;  // plan_c entries come first in the thread index space, then plan_w
+ if (idx >= total) return;  // surplus threads of the last block do nothing
+
+ const bool is_compress = idx < params.num_c;
+ CompressPlan* const plan_ptr = is_compress ? &params.plan_c[idx] : &params.plan_w[idx - params.num_c];
+ auto plan = *plan_ptr;  // work on a local copy, written back once at the end
+ const auto batch_id = plan.read_page_0;  // stage 0 stashed the batch id in this field
+ const auto rid = params.req_pool_indices[batch_id];
+ const int32_t position = static_cast<int32_t>(plan.seq_len - 1u);  // last token position covered by the segment
+ const int32_t chunk_start = (position / 128) * 128;  // position rounded down to a multiple of 128
+ const int32_t full_loc = params.req_to_token[rid * params.stride_r2t + chunk_start];
+ const int32_t swa_loc = static_cast<int32_t>(params.full_to_swa[full_loc]);  // translate through the full_to_swa table
+ plan.read_page_0 = swa_loc / params.swa_page_size;  // replace the placeholder with the slot index
+ *plan_ptr = plan;
+}
+
+using OnlinePrefillPlan = tvm::ffi::Tuple;  // return type of plan_online_prefill
+
+inline OnlinePrefillPlan plan_online_prefill(
+ const tvm::ffi::TensorView seq_lens,
+ const tvm::ffi::TensorView extend_lens,
+ const tvm::ffi::TensorView req_pool_indices,
+ const tvm::ffi::TensorView req_to_token,
+ const tvm::ffi::TensorView full_to_swa,
+ const tvm::ffi::TensorView plan_c_pin,
+ const tvm::ffi::TensorView plan_w_pin,
+ const tvm::ffi::TensorView plan_c_dev_,
+ const tvm::ffi::TensorView plan_w_dev_,
+ const int32_t swa_page_size) {
+ auto B = SymbolicSize{"batch_size"};
+ auto N = SymbolicSize{"num_q_tokens"};
+ auto cpu = SymbolicDevice{};
+ auto device_ = SymbolicDevice{};
+ cpu.set_options();
+ device_.set_options();
+
+ TensorMatcher({B}) //
+ .with_dtype()
+ .with_device(cpu)
+ .verify(seq_lens)
+ .verify(extend_lens);
+ TensorMatcher({B}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(req_pool_indices);
+ TensorMatcher({-1, -1}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(req_to_token);
+ TensorMatcher({-1}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(full_to_swa);
+ TensorMatcher({N, sizeof(CompressPlan)}) //
+ .with_dtype()
+ .with_device(cpu)
+ .verify(plan_c_pin)
+ .verify(plan_w_pin);
+ TensorMatcher({N, sizeof(CompressPlan)}) //
+ .with_dtype()
+ .with_device(device_)
+ .verify(plan_c_dev_)
+ .verify(plan_w_dev_);
+
+ const auto stage0_params = OnlinePrefillStage0Params{
+ .plan_c = static_cast(plan_c_pin.data_ptr()),
+ .plan_w = static_cast