diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py index 539ade769e..dc18ecf4ba 100644 --- a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py @@ -9,6 +9,60 @@ from lightllm.common.basemodel.infer_struct import InferStateInfo +# this flash_mla extra-cache fork only instantiates h_q in {64, 128}; pad TP-split q heads up +# to the nearest supported count (zero heads are discarded from the output slice). +FLASHMLA_SUPPORTED_HEADS = (64, 128) + + +def _pad_q_heads(q_4d: torch.Tensor, attn_sink: torch.Tensor): + h_q = q_4d.shape[2] + if h_q in FLASHMLA_SUPPORTED_HEADS: + return q_4d, attn_sink, h_q + target = next((h for h in FLASHMLA_SUPPORTED_HEADS if h >= h_q), None) + assert target is not None, f"num q heads {h_q} exceeds flash_mla support {FLASHMLA_SUPPORTED_HEADS}" + q_pad = torch.nn.functional.pad(q_4d, (0, 0, 0, target - h_q)) + sink_pad = torch.nn.functional.pad(attn_sink, (0, target - h_q)) + return q_pad, sink_pad, h_q + + +def _view_dsv4_flashmla_cache(layer_buffer: torch.Tensor, page_size: int) -> torch.Tensor: + from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_MLA_BYTES_PER_TOKEN + + usable = page_size * DSV4_MLA_BYTES_PER_TOKEN + return layer_buffer[:, :usable].view(layer_buffer.shape[0], page_size, 1, DSV4_MLA_BYTES_PER_TOKEN) + + +@dataclasses.dataclass +class _Dsv4Metadata: + swa_indices: torch.Tensor + swa_lengths: torch.Tensor + extra_cache: torch.Tensor = None + extra_indices: torch.Tensor = None + extra_lengths: torch.Tensor = None + + +def _metadata_from_dict(infer_state, nsa_dict: dict) -> "_Dsv4Metadata": + """Bundle the model-built FINAL index tensors (carried in nsa_dict by DeepseekV4IndexInfer) with + the layer-keyed fp8 extra-cache byte view. 
The cache view is data-independent (a fixed per-layer + buffer slice), so it is built here -- a genuine flash_mla ABI concern -- rather than on the model + side; only the index/length tensors cross the att_control boundary.""" + from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_C128_PAGE_SIZE, DSV4_C4_PAGE_SIZE + + ratio = nsa_dict["compress_ratio"] + extra_cache = None + if ratio: + page = DSV4_C4_PAGE_SIZE if ratio == 4 else DSV4_C128_PAGE_SIZE + extra_buffer = infer_state.mem_manager.get_compressed_kv_buffer(nsa_dict["layer_index"]) + extra_cache = _view_dsv4_flashmla_cache(extra_buffer, page) + return _Dsv4Metadata( + swa_indices=nsa_dict["swa_indices"], + swa_lengths=nsa_dict["swa_lengths"], + extra_cache=extra_cache, + extra_indices=nsa_dict.get("extra_indices"), + extra_lengths=nsa_dict.get("extra_lengths"), + ) + + class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend): def __init__(self, model): super().__init__(model=model) @@ -62,6 +116,12 @@ def prefill_att( ) -> torch.Tensor: assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" + if att_control.nsa_prefill_dict.get("flashmla_kvcache"): + return self._flashmla_kvcache_prefill_att( + q=q, + packed_kv=k, + nsa_dict=att_control.nsa_prefill_dict, + ) return self._nsa_prefill_att(q=q, packed_kv=k, att_control=att_control) def _nsa_prefill_att( @@ -78,6 +138,8 @@ def _nsa_prefill_att( kv_lora_rank = nsa_dict["kv_lora_rank"] topk_mem_indices = nsa_dict["topk_mem_indices"] prefill_cache_kv = nsa_dict["prefill_cache_kv"] + attn_sink = nsa_dict.get("attn_sink") + topk_length = nsa_dict.get("topk_length") if self.infer_state.prefix_total_token_num > 0: # 当前推理生成的token kv部分从 prefill_cache_kv 中获取,历史 @@ -101,9 +163,51 @@ def _nsa_prefill_att( indices=topk_indices, sm_scale=softmax_scale, d_v=kv_lora_rank, + attn_sink=attn_sink, + topk_length=topk_length, ) return mla_out + def 
_flashmla_kvcache_prefill_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor: + attn_sink = nsa_dict["attn_sink"].to(torch.float32).contiguous() + metadata = _metadata_from_dict(self.infer_state, nsa_dict) + return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict) + + def _flashmla_kvcache_att( + self, + q: torch.Tensor, + packed_kv: torch.Tensor, + metadata: _Dsv4Metadata, + attn_sink: torch.Tensor, + nsa_dict: dict, + ) -> torch.Tensor: + import flash_mla + from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE + + q_4d = q.unsqueeze(1).contiguous() + q_4d, attn_sink, num_real_heads = _pad_q_heads(q_4d, attn_sink) + k_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE) + sched_meta, _ = flash_mla.get_mla_metadata() + out, _ = flash_mla.flash_mla_with_kvcache( + q=q_4d, + k_cache=k_cache, + block_table=None, + cache_seqlens=None, + head_dim_v=nsa_dict["head_dim_v"], + tile_scheduler_metadata=sched_meta, + num_splits=None, + softmax_scale=nsa_dict["softmax_scale"], + causal=False, + is_fp8_kvcache=True, + indices=metadata.swa_indices, + attn_sink=attn_sink, + topk_length=metadata.swa_lengths, + extra_k_cache=metadata.extra_cache, + extra_indices_in_kvcache=metadata.extra_indices, + extra_topk_length=metadata.extra_lengths, + ) + return out[:, 0, :num_real_heads].contiguous() + @dataclasses.dataclass class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState): @@ -143,7 +247,18 @@ def init_state(self): ) import flash_mla - self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata() + # one sched_meta per layer type: the lazy config locks extra-cache geometry (page size, + # presence) on first invocation, so swa-only/c4/c128 layers must not share one object. 
+ self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)} + return + + def reset_sched_meta_for_capture(self): + # cuda-graph capture hook: the warmup pass already locked/stored sched meta on this + # (shared) state object; reset so the capture pass re-plans INSIDE the graph and every + # replay re-plans from the live tensors instead of binding warmup leftovers. + import flash_mla + + self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)} return def decode_att( @@ -156,6 +271,12 @@ def decode_att( ) -> torch.Tensor: assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" + if att_control.nsa_decode_dict.get("flashmla_kvcache"): + return self._flashmla_kvcache_decode_att( + q=q, + packed_kv=k, + nsa_dict=att_control.nsa_decode_dict, + ) return self._nsa_decode_att(q=q, packed_kv=k, att_control=att_control) def _nsa_decode_att( @@ -170,6 +291,11 @@ def _nsa_decode_att( topk_mem_indices = nsa_dict["topk_mem_indices"] softmax_scale = nsa_dict["softmax_scale"] kv_lora_rank = nsa_dict["kv_lora_rank"] + attn_sink = nsa_dict.get("attn_sink") + topk_length = nsa_dict.get("topk_length") + extra_k_cache = nsa_dict.get("extra_k_cache") + extra_indices = nsa_dict.get("extra_indices_in_kvcache") + extra_topk_length = nsa_dict.get("extra_topk_length") if topk_mem_indices.ndim == 2: topk_mem_indices = topk_mem_indices.unsqueeze(1) @@ -189,10 +315,54 @@ def _nsa_decode_att( block_table=None, cache_seqlens=None, head_dim_v=kv_lora_rank, - tile_scheduler_metadata=self.flashmla_sched_meta, + tile_scheduler_metadata=self.flashmla_sched_meta[0], softmax_scale=softmax_scale, causal=False, is_fp8_kvcache=True, indices=topk_mem_indices, + attn_sink=attn_sink, + topk_length=topk_length, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices, + extra_topk_length=extra_topk_length, ) return 
o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d] + + def _flashmla_kvcache_decode_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor: + attn_sink = nsa_dict["attn_sink"].to(torch.float32).contiguous() + metadata = _metadata_from_dict(self.infer_state, nsa_dict) + return self._flashmla_kvcache_att(q, packed_kv, metadata, attn_sink, nsa_dict) + + def _flashmla_kvcache_att( + self, + q: torch.Tensor, + packed_kv: torch.Tensor, + metadata: _Dsv4Metadata, + attn_sink: torch.Tensor, + nsa_dict: dict, + ) -> torch.Tensor: + import flash_mla + from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE + + q_4d = q.unsqueeze(1).contiguous() + q_4d, attn_sink, num_real_heads = _pad_q_heads(q_4d, attn_sink) + k_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE) + out, _ = flash_mla.flash_mla_with_kvcache( + q=q_4d, + k_cache=k_cache, + block_table=None, + cache_seqlens=None, + head_dim_v=nsa_dict["head_dim_v"], + tile_scheduler_metadata=self.flashmla_sched_meta[nsa_dict["compress_ratio"]], + num_splits=None, + softmax_scale=nsa_dict["softmax_scale"], + causal=False, + is_fp8_kvcache=True, + indices=metadata.swa_indices, + attn_sink=attn_sink, + topk_length=metadata.swa_lengths, + extra_k_cache=metadata.extra_cache, + extra_indices_in_kvcache=metadata.extra_indices, + extra_topk_length=metadata.extra_lengths, + ) + return out[:, 0, :num_real_heads].contiguous() diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 94f9d4c1a2..5825d9b45f 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -291,6 +291,11 @@ def _init_custom(self): @torch.no_grad() def forward(self, model_input: ModelInput): + # decode 槽位 prep: 放在 to_cuda 前 (b_req_idx/b_seq_len/mem_indexes_cpu 还是原生 CPU 张量), + # 且此刻已在 forward 的 CUDA stream 上 -> 与后续 attention 同流, 无跨流竞态、无 D2H。 + # mem_indexes_cpu is None 时跳过: cudagraph warmup 的输入全在 CUDA 且 b_req_idx 
全为 HOLD, prep 本就是 no-op。 + if not model_input.is_prefill and model_input.mem_indexes_cpu is not None: + self.req_manager.prepare_decode(model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes_cpu) model_input.to_cuda() assert model_input.mem_indexes.is_cuda @@ -521,6 +526,18 @@ def _prefill( alloc_mem_index=infer_state.mem_index, max_q_seq_len=infer_state.max_q_seq_len, ) + if hasattr(self.req_manager, "prepare_prefill_swa"): + self.req_manager.prepare_prefill_swa( + b_req_idx=infer_state.b_req_idx, + b_ready_cache_len=infer_state.b_ready_cache_len, + b_seq_len=infer_state.b_seq_len, + ) + if hasattr(self.req_manager, "prepare_prefill_compress_slots"): + self.req_manager.prepare_prefill_compress_slots( + b_req_idx=infer_state.b_req_idx, + b_ready_cache_len=infer_state.b_ready_cache_len, + b_seq_len=infer_state.b_seq_len, + ) prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() @@ -741,6 +758,18 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod alloc_mem_index=infer_state0.mem_index, max_q_seq_len=infer_state0.max_q_seq_len, ) + if hasattr(self.req_manager, "prepare_prefill_swa"): + self.req_manager.prepare_prefill_swa( + b_req_idx=infer_state0.b_req_idx, + b_ready_cache_len=infer_state0.b_ready_cache_len, + b_seq_len=infer_state0.b_seq_len, + ) + if hasattr(self.req_manager, "prepare_prefill_compress_slots"): + self.req_manager.prepare_prefill_compress_slots( + b_req_idx=infer_state0.b_req_idx, + b_ready_cache_len=infer_state0.b_ready_cache_len, + b_seq_len=infer_state0.b_seq_len, + ) infer_state0.init_some_extra_state(self) infer_state0.init_att_state() @@ -754,6 +783,18 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod alloc_mem_index=infer_state1.mem_index, max_q_seq_len=infer_state1.max_q_seq_len, ) + if hasattr(self.req_manager, "prepare_prefill_swa"): + self.req_manager.prepare_prefill_swa( + b_req_idx=infer_state1.b_req_idx, + 
b_ready_cache_len=infer_state1.b_ready_cache_len, + b_seq_len=infer_state1.b_seq_len, + ) + if hasattr(self.req_manager, "prepare_prefill_compress_slots"): + self.req_manager.prepare_prefill_compress_slots( + b_req_idx=infer_state1.b_req_idx, + b_ready_cache_len=infer_state1.b_ready_cache_len, + b_seq_len=infer_state1.b_seq_len, + ) infer_state1.init_some_extra_state(self) infer_state1.init_att_state() @@ -781,6 +822,10 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod @torch.no_grad() def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput): + # decode 槽位 prep: 在 to_cuda 前 (原生 CPU 张量)、且已在 forward 的 CUDA stream 上 (见 forward 注释)。 + for mi in (model_input0, model_input1): + if mi.mem_indexes_cpu is not None: + self.req_manager.prepare_decode(mi.b_req_idx, mi.b_seq_len, mi.mem_indexes_cpu) model_input0.to_cuda() model_input1.to_cuda() assert self.args.enable_tpsp_mix_mode diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 782150661e..e1d96b744e 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -14,6 +14,20 @@ logger = init_logger(__name__) +def _reset_att_state_sched_meta(infer_state: InferStateInfo): + # capture 前调用: warmup 趟用 copy.copy 浅拷贝共享 decode_att_state,其内部惰性初始化的 + # 调度对象(如 FlashMLASchedMeta,首次内核调用时按当时数据规划并回写)会被 warmup 的 + # dummy 负载锁定;若不重置,捕获趟将绑定为 dummy 规划的调度张量,所有 replay 都用错误 + # 的 tile schedule(DSV4 实测 gsm8k 0.96 -> 0.74)。重置后规划发生在捕获区内,随 replay 重算。 + for att_state in (infer_state.decode_att_state, infer_state.decode_att_state1): + if att_state is None: + continue + reset_fn = getattr(att_state, "reset_sched_meta_for_capture", None) + if reset_fn is not None: + reset_fn() + return + + class CudaGraph: # CudaGraph forward pass for the decoding stage. 
@@ -94,6 +108,8 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): if param_name not in pure_para_set: delattr(infer_state, param_name) + _reset_att_state_sched_meta(infer_state) + with torch.cuda.graph(graph_obj, pool=self.mempool): model_output = decode_func(infer_state) self.graph[batch_size] = (graph_obj, infer_state, model_output) @@ -128,6 +144,9 @@ def _capture_decode_overlap( if para_name not in pure_para_set1: delattr(infer_state1, para_name) + _reset_att_state_sched_meta(infer_state) + _reset_att_state_sched_meta(infer_state1) + with torch.cuda.graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(infer_state, infer_state1) self.graph[batch_size] = ( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index fca9b80fcf..24842ed383 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -68,12 +68,13 @@ def __init__( auto_update_redundancy_expert=self.auto_update_redundancy_expert, ) self.lock = threading.Lock() + self._moe_weight_finalized = False self._create_weight() def _init_config(self, network_config: Dict[str, Any]): self.n_group = network_config.get("n_group", 0) self.use_grouped_topk = self.n_group > 0 - self.norm_topk_prob = network_config["norm_topk_prob"] + self.norm_topk_prob = network_config.get("norm_topk_prob", False) self.topk_group = network_config.get("topk_group", 0) self.num_experts_per_tok = network_config["num_experts_per_tok"] self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) @@ -136,6 +137,7 @@ def experts( is_prefill: Optional[bool] = None, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" + self._finalize_moe_weight() return self.fuse_moe_impl( 
input_tensor=input_tensor, router_logits=router_logits, @@ -152,6 +154,25 @@ def experts( per_expert_scale=self.per_expert_scale, ) + def experts_with_preselected( + self, + input_tensor: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_prefill: Optional[bool] = None, + clamp_limit: Optional[float] = None, + ) -> torch.Tensor: + self._finalize_moe_weight() + return self.fuse_moe_impl.fused_experts_with_topk( + input_tensor=input_tensor, + w13=self.w13, + w2=self.w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + is_prefill=is_prefill, + clamp_limit=clamp_limit, + ) + def low_latency_dispatch( self, hidden_states: torch.Tensor, @@ -280,7 +301,18 @@ def verify_load(self): e_score_correction_bias_load_ok = ( True if self.e_score_correction_bias is None else getattr(self.e_score_correction_bias, "load_ok", False) ) - return weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok + load_ok = weight_load_ok and per_expert_scale_load_ok and e_score_correction_bias_load_ok + if load_ok: + self._finalize_moe_weight() + return load_ok + + def _finalize_moe_weight(self): + if self._moe_weight_finalized: + return + finalize = getattr(self.quant_method, "finalize_moe_weight", None) + if finalize is not None: + finalize(self) + self._moe_weight_finalized = True def _create_weight(self): intermediate_size = self.split_inter_size diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py index 67bb90e4ef..282c0abdce 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py @@ -2,9 +2,15 @@ from .triton_impl import FuseMoeTriton from .marlin_impl import FuseMoeMarlin from .deepgemm_impl import FuseMoeDeepGEMM +from .mxfp4_impl import FuseMoeMXFP4 def select_fuse_moe_impl(quant_method: 
QuantizationMethod, enable_ep_moe: bool): + if quant_method.method_name == "marlin-mxfp4w4a16-b32": + if enable_ep_moe: + raise RuntimeError("marlin-mxfp4w4a16-b32 does not support enable_ep_moe yet") + return FuseMoeMXFP4 + if enable_ep_moe: return FuseMoeDeepGEMM diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index 4d4614c007..72acf2430a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -76,7 +76,9 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + clamp_limit: Optional[float] = None, ): + assert clamp_limit is None, "EP deepgemm fused MoE does not support clamp_limit yet" output = fused_experts( hidden_states=input_tensor, w13=w13, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 0094b09b1c..1fdfd94d0d 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -30,7 +30,9 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + clamp_limit: Optional[float] = None, ): + assert clamp_limit is None, "awq_marlin fused MoE does not support clamp_limit yet" w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py 
b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py new file mode 100644 index 0000000000..97cf238116 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/mxfp4_impl.py @@ -0,0 +1,44 @@ +import torch +from typing import Optional + +from lightllm.common.quantization.quantize_method import WeightPack +from .triton_impl import FuseMoeTriton + + +class FuseMoeMXFP4(FuseMoeTriton): + def create_workspace(self): + return None + + def _fused_experts( + self, + input_tensor: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + router_logits: Optional[torch.Tensor] = None, + is_prefill: Optional[bool] = None, + clamp_limit: Optional[float] = None, + ): + try: + from vllm.model_executor.layers.fused_moe.activation import MoEActivation + from vllm.model_executor.layers.fused_moe.experts.marlin_moe import fused_marlin_moe + from vllm.scalar_type import scalar_types + except Exception as e: + raise RuntimeError(f"MXFP4 fused MoE requires vLLM fused kernels, error={repr(e)}") from e + + return fused_marlin_moe( + hidden_states=input_tensor.contiguous(), + w1=w13.weight, + w2=w2.weight, + bias1=None, + bias2=None, + w1_scale=w13.weight_scale, + w2_scale=w2.weight_scale, + topk_weights=topk_weights.to(torch.float32).contiguous(), + topk_ids=topk_ids.to(torch.long).contiguous(), + quant_type_id=scalar_types.float4_e2m1f.id, + global_num_experts=self.n_routed_experts, + activation=MoEActivation.SILU, + clamp_limit=clamp_limit, + ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index a0d30547a3..8967dda34e 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -94,6 +94,7 @@ def _fused_experts( 
topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: bool = False, + clamp_limit: Optional[float] = None, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale @@ -111,9 +112,30 @@ def _fused_experts( use_fp8_w8a8=use_fp8_w8a8, w1_scale=w13_scale, w2_scale=w2_scale, + limit=clamp_limit, ) return input_tensor + def fused_experts_with_topk( + self, + input_tensor: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_prefill: Optional[bool] = None, + clamp_limit: Optional[float] = None, + ): + return self._fused_experts( + input_tensor=input_tensor, + w13=w13, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + is_prefill=is_prefill, + clamp_limit=clamp_limit, + ) + def __call__( self, input_tensor: torch.Tensor, diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index cb2e370cb9..28fe6e4304 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -49,7 +49,7 @@ def check_ep_expert_dtype(quant_method: Any): "EP MoE requires --expert_dtype to be one of ['fp8', 'fp4'], " f"but the resolved fused_moe quant method is `{expert_dtype}`. " "Please start with --expert_dtype fp8 or --expert_dtype fp4. " - "Note that --expert_dtype fp4 is only supported on SM100 GPUs." + "Note that --expert_dtype fp4 with EP MoE is only supported on SM100 GPUs." 
) if expert_dtype == "deepgemm-fp4fp8-b32" and not is_sm100_gpu(): raise RuntimeError( diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py index 45c7ea73c6..82fc9131c1 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py @@ -24,6 +24,7 @@ def _silu_and_mul_kernel_fast( NEED_MASK: tl.constexpr, layout: tl.constexpr = "blocked", # "blocked" or "interleaved" USE_LIMIT_AND_ALPHA: tl.constexpr = False, + USE_LIMIT_ONLY: tl.constexpr = False, USE_TANH_APPROXIMATE_GELU: tl.constexpr = False, ): stride_input_m = tl.cast(stride_input_m, dtype=tl.int64) @@ -76,6 +77,11 @@ def _silu_and_mul_kernel_fast( mask=mask, ) else: + if USE_LIMIT_ONLY: + # clamped swiglu (DeepSeek-V4 swiglu_limit): clamp 后接标准 silu, + # 无 gpt-oss 的 alpha 缩放与 (up+1)。 + gate = tl.minimum(gate, limit) + up = tl.minimum(tl.maximum(up, -limit), limit) if USE_TANH_APPROXIMATE_GELU: # tanh-approx GELU, matching Gemma's gelu_pytorch_tanh MLP. 
gate_cubed = gate * gate * gate @@ -124,7 +130,8 @@ def silu_and_mul_fwd( ): assert input.is_contiguous() assert output.is_contiguous() - assert (limit is None and alpha is None) or (limit is not None and alpha is not None) + # limit+alpha: gpt-oss 语义 (up+1)*silu(alpha*gate); 仅 limit: clamp 后标准 silu (DeepSeek-V4) + assert alpha is None or limit is not None stride_input_m = input.stride(0) stride_input_n = input.stride(1) @@ -147,6 +154,7 @@ def silu_and_mul_fwd( while triton.cdiv(size_m, BLOCK_M) > 8192: BLOCK_M *= 2 USE_LIMIT_AND_ALPHA = limit is not None and alpha is not None + USE_LIMIT_ONLY = limit is not None and alpha is None grid = ( triton.cdiv(size_n, BLOCK_N), @@ -171,6 +179,7 @@ def silu_and_mul_fwd( num_warps=num_warps, layout=layout, USE_LIMIT_AND_ALPHA=USE_LIMIT_AND_ALPHA, + USE_LIMIT_ONLY=USE_LIMIT_ONLY, USE_TANH_APPROXIMATE_GELU=ffn_use_tanh_approximate_gelu(), ) return diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 05544e149a..95f7e8ab76 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -4,6 +4,7 @@ from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager from .deepseek3_2mem_manager import Deepseek3_2MemoryManager +from .deepseek4_mem_manager import DeepseekV4MemoryManager from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager @@ -17,6 +18,7 @@ "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", "Deepseek3_2MemoryManager", + "DeepseekV4MemoryManager", "FP8PerTokenGroupQuantDeepseek3_2MemoryManager", "FP8StaticPerHeadQuantMemManager", "FP8StaticPerTensorQuantMemManager", diff --git 
a/lightllm/common/kv_cache_mem_manager/allocator.py b/lightllm/common/kv_cache_mem_manager/allocator.py index 850c158778..0179ed2714 100644 --- a/lightllm/common/kv_cache_mem_manager/allocator.py +++ b/lightllm/common/kv_cache_mem_manager/allocator.py @@ -3,13 +3,13 @@ from lightllm.utils.dist_utils import get_current_rank_in_node from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.log_utils import init_logger -from typing import Union, List +from typing import Union, List, Optional logger = init_logger(__name__) class KvCacheAllocator: - def __init__(self, size: int) -> None: + def __init__(self, size: int, shared_name: Optional[str] = None) -> None: self.size = size self.mem_state = torch.arange( 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True @@ -26,9 +26,11 @@ def __init__(self, size: int) -> None: rank_in_node = get_current_rank_in_node() # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) + # shared_name 为 None 时使用主 kv 池的默认名(router 调度据此估算);DeepSeek-V4 的压缩子池等 + # 需要各自独立的计数器,传入区别于主池的唯一名,避免多个 allocator 写同一个共享计数器。 + if shared_name is None: + shared_name = f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + self.shared_can_use_token_num = SharedInt(shared_name) self.shared_can_use_token_num.set_value(self.can_use_mem_size) return diff --git a/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py new file mode 100644 index 0000000000..cfb149dcec --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/deepseek4_mem_manager.py @@ -0,0 +1,872 @@ +import torch +import torch.distributed as dist +from typing import List, Optional, Union +from .mem_manager import MemoryManager +from .operator import DeepseekV4MemOperator +from .allocator import KvCacheAllocator +from 
class PackedPagePool:
    """Page-slab byte storage in the fp8_ds_mla style.

    Inside every page the per-token *data* bytes are laid out back to back at
    the front of the page and the per-token *scale* bytes follow at
    ``scale_offset_in_page``.  Addressing is by plain token slot
    (``page = slot // page_size``); pages are only a physical packing device
    for the scale tail and for alignment, nothing is allocated per page here.

    ``write`` / ``read`` are the torch reference implementation (used as the
    unit-test oracle); production writes go through the triton packed writers,
    and kernels consume ``buffer`` directly.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        layer_num: int,
        data_bytes: int,
        scale_bytes: int,
        align_bytes: int = 1,
        device: str = "cuda",
    ):
        token_bytes = data_bytes + scale_bytes
        raw_page_bytes = page_size * token_bytes

        self.size = size
        self.page_size = page_size
        self.layer_num = layer_num
        self.data_bytes_per_token = data_bytes
        self.scale_bytes_per_token = scale_bytes
        self.bytes_per_token = token_bytes
        # size + 1 slots: slot index ``size`` is kept as the HOLD slot (see below).
        self.num_pages = (size + 1 + page_size - 1) // page_size
        # Round the page up to a multiple of ``align_bytes``.
        self.bytes_per_page = (raw_page_bytes + align_bytes - 1) // align_bytes * align_bytes
        self.scale_offset_in_page = page_size * data_bytes
        self.buffer = torch.zeros(
            (layer_num, self.num_pages, self.bytes_per_page),
            dtype=torch.uint8,
            device=device,
        )
        self.HOLD_TOKEN_MEMINDEX = size

    def get_layer_buffer(self, layer_index: int) -> torch.Tensor:
        """Return the ``[num_pages, bytes_per_page]`` uint8 slab of one layer."""
        return self.buffer[layer_index]

    def _loc_offsets(self, loc: torch.Tensor):
        """Map token slots to flat byte offsets (data start, scale start) within one layer."""
        slot = loc.long()
        page_start = torch.div(slot, self.page_size, rounding_mode="floor") * self.bytes_per_page
        in_page = slot % self.page_size
        data_start = page_start + in_page * self.data_bytes_per_token
        scale_start = page_start + self.scale_offset_in_page + in_page * self.scale_bytes_per_token
        return data_start, scale_start

    def _byte_index(self, loc: torch.Tensor):
        """Expand per-token start offsets into ``[N, width]`` byte index matrices."""
        data_start, scale_start = self._loc_offsets(loc)
        data_cols = torch.arange(self.data_bytes_per_token, device=loc.device)
        scale_cols = torch.arange(self.scale_bytes_per_token, device=loc.device)
        return data_start[:, None] + data_cols[None, :], scale_start[:, None] + scale_cols[None, :]

    def write(self, layer_index: int, loc: torch.Tensor, packed: torch.Tensor) -> None:
        """Scatter ``packed`` rows (data bytes followed by scale bytes) into the slots ``loc``."""
        if loc.numel() == 0:
            return
        loc = loc.reshape(-1)
        rows = packed.reshape(-1, self.bytes_per_token).contiguous()
        layer_bytes = self.buffer[layer_index].view(-1)
        data_idx, scale_idx = self._byte_index(loc)
        split = self.data_bytes_per_token
        layer_bytes[data_idx] = rows[:, :split]
        layer_bytes[scale_idx] = rows[:, split:]
        return

    def read(self, layer_index: int, loc: torch.Tensor) -> torch.Tensor:
        """Gather the slots ``loc`` back into ``[N, bytes_per_token]`` packed rows."""
        loc = loc.reshape(-1)
        if loc.numel() == 0:
            return torch.empty((0, self.bytes_per_token), dtype=torch.uint8, device=self.buffer.device)
        layer_bytes = self.buffer[layer_index].view(-1)
        data_idx, scale_idx = self._byte_index(loc)
        pieces = [layer_bytes[data_idx], layer_bytes[scale_idx]]
        return torch.cat(pieces, dim=1).contiguous()
def __init__(
    self,
    size,
    dtype,
    head_num,
    head_dim,
    layer_num,
    compress_rates: List[int],
    indexer_head_dim: int = 128,
    max_request_num: Optional[int] = None,
    sliding_window: Optional[int] = None,
    swa_extra_token_num: int = 0,
    swa_full_tokens_ratio: float = DSV4_SWA_FULL_TOKENS_RATIO,
    always_copy=False,
    mem_fraction=0.9,
):
    """Validate the DeepSeek-V4 cache geometry, record sizing knobs, then defer to the base manager.

    ``compress_rates`` holds one entry per layer, each 0 (no compression), 4 or 128.
    All attributes are set *before* ``super().__init__`` is called.
    """
    expected_head_dim = self.mla_head_dim
    expected_indexer_dim = self.indexer_head_dim_default
    assert head_num == 1, "DeepSeek-V4 是 MLA(MQA),dense latent 的 head_num 必须为 1"
    assert head_dim == expected_head_dim, f"DeepSeek-V4 packed KV 期望 head_dim={expected_head_dim}"
    assert (
        indexer_head_dim == expected_indexer_dim
    ), f"DeepSeek-V4 packed indexer-K 期望 indexer_head_dim={expected_indexer_dim}"
    assert len(compress_rates) == layer_num, f"compress_rates 长度 {len(compress_rates)} 必须等于 layer_num {layer_num}"
    assert all(r in (0, 4, 128) for r in compress_rates), "compress_rates 取值只能是 0/4/128"

    rates = list(compress_rates)
    c4_layers = [lid for lid, rate in enumerate(rates) if rate == 4]
    c128_layers = [lid for lid, rate in enumerate(rates) if rate == 128]

    self.compress_rates = rates
    self.n_c4 = len(c4_layers)
    self.n_c128 = len(c128_layers)
    self.indexer_head_dim = indexer_head_dim
    self.max_request_num = max_request_num
    self.sliding_window = sliding_window
    # Headroom on top of the active windows (max_request_num * sliding_window):
    # transient occupancy of in-flight prefill chunks (out-of-window slots are only
    # reclaimed at the next prep) plus window tails held by the radix cache.
    self.swa_extra_token_num = int(swa_extra_token_num)
    self.swa_full_tokens_ratio = float(swa_full_tokens_ratio)

    # Global layer id -> compacted layer id inside each compressed pool.
    self.layer_to_c4_idx = {lid: local for local, lid in enumerate(c4_layers)}
    self.layer_to_c128_idx = {lid: local for local, lid in enumerate(c128_layers)}

    super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
@staticmethod
def _slab_bytes_per_slot(page_size: int, data_bytes: int, scale_bytes: int, align_bytes: int = 1) -> float:
    """Return the average bytes one token slot costs in a page slab.

    The page holds ``page_size`` tokens of ``data_bytes + scale_bytes`` each and
    is rounded up to a multiple of ``align_bytes``; the padded page size is then
    spread evenly over its slots.
    """
    raw_page_bytes = page_size * (data_bytes + scale_bytes)
    aligned_units = (raw_page_bytes + align_bytes - 1) // align_bytes
    padded_page_bytes = aligned_units * align_bytes
    return padded_page_bytes / page_size
def get_cell_size(self):
    """Return the estimated cache bytes charged to one full token.

    The compressed-pool share is always counted in full.  The swa share is
    counted in full while ``self.size`` is still unknown, otherwise it is
    scaled by the planned swa-pool size relative to the full token count.
    """
    compressed_bytes = self._compressed_cell_size()
    swa_bytes = self._swa_slot_bytes()
    if self.size is None:
        return swa_bytes + compressed_bytes
    swa_share = self._planned_swa_size(self.size) / max(1, self.size)
    return swa_bytes * swa_share + compressed_bytes
def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
    """Allocate every pool, allocator, mapping table and compressor-state buffer.

    ``size`` is the number of full token slots.  The swa pool is created for all
    ``layer_num`` layers; the c4 / c128 pools only exist when ``self.n_c4`` /
    ``self.n_c128`` is positive and are sized per compacted layer.  ``dtype``,
    ``head_num`` and ``head_dim`` are not read in this body (``self.head_dim`` is
    used instead — presumably set by the base class; verify).  Tensors created
    directly here are placed on ``"cuda"``.
    """
    rank_in_node = get_current_rank_in_node()
    server = get_unique_server_name()

    self.swa_size = self._planned_swa_size(size)
    assert self.swa_size % DSV4_SWA_PAGE_SIZE == 0
    self.swa_pool = PackedPagePool(
        size=self.swa_size,
        page_size=DSV4_SWA_PAGE_SIZE,
        layer_num=layer_num,
        data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN,
        scale_bytes=self.mla_scale_bytes,
        align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES,
    )
    # NOTE: this alias is page-indexed ([layer, num_pages, bytes_per_page]), not
    # token-indexed.  Only consumers of get_att_input_params may use it; the inherited
    # token-indexed interfaces are explicitly fenced off.
    self.kv_buffer = self.swa_pool.buffer
    # Page-granular allocation (one page = DSV4_SWA_PAGE_SIZE position-aligned slots).
    # Slot invariant: slot(p) = page_base + p % page_size.
    # swa_size is page-aligned, so the HOLD slot (index swa_size) sits alone on the
    # pool's last physical page, which is outside the allocator's range.
    self.swa_num_pages = self.swa_size // DSV4_SWA_PAGE_SIZE
    self.swa_page_allocator = KvCacheAllocator(
        self.swa_num_pages, shared_name=f"{server}_dsv4_swa_can_use_page_num_{rank_in_node}"
    )
    # Per-page live count = number of valid full_to_swa rows pointing into the page.
    # A page goes back to the allocator when its count reaches 0, so a partially
    # evicted page (count > 0) stays protected.  The index range includes the HOLD
    # page, which is only read, never incremented or decremented.
    self.swa_page_live_count = torch.zeros((self.swa_pool.num_pages,), dtype=torch.int32, device="cuda")
    # Optional swa pressure valve: callback invoked when the page allocator runs low
    # (radix cache reclaims swa pages of ref==0 nodes).  Injected by the backend after
    # the radix cache exists; the allocator's own assert remains the last line of defence.
    self._swa_pressure_valve = None
    # full token slot -> swa slot; -1 means "not mapped".  Row ``size`` is the HOLD row.
    self.full_to_swa_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda")
    self.full_to_swa_indexs[size] = self.swa_pool.HOLD_TOKEN_MEMINDEX

    self.c4_size = _ceil_div(size, 4)
    self.c128_size = _ceil_div(size, 128)
    self.c4_pool: Optional[PackedPagePool] = None
    self.c4_indexer_pool: Optional[PackedPagePool] = None
    # NOTE(review): c4_allocator is declared here but never assigned in this method;
    # c4 allocation goes through c4_page_allocator.
    self.c4_allocator: Optional[KvCacheAllocator] = None
    self.c4_page_allocator: Optional[KvCacheAllocator] = None
    self.c4_page_live_count: Optional[torch.Tensor] = None
    self.c128_pool: Optional[PackedPagePool] = None
    self.c128_allocator: Optional[KvCacheAllocator] = None
    self.c4_state_buffer: Optional[torch.Tensor] = None
    self.c4_indexer_state_buffer: Optional[torch.Tensor] = None
    self.c128_state_buffer: Optional[torch.Tensor] = None
    # Compressed-slot mappings: key = full slot of the group's last token, value = slot
    # in the compressed pool.  Same shape as full_to_swa_indexs: while the radix cache
    # holds the full slot the row stays alive, and ``free`` reclaims it in cascade.
    self.full_to_c4_indexs: Optional[torch.Tensor] = None
    self.full_to_c128_indexs: Optional[torch.Tensor] = None
    if self.n_c4 > 0:
        self.c4_pool = PackedPagePool(
            size=self.c4_size,
            page_size=DSV4_C4_PAGE_SIZE,
            layer_num=self.n_c4,
            data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN,
            scale_bytes=self.mla_scale_bytes,
            align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES,
        )
        self.c4_indexer_pool = PackedPagePool(
            size=self.c4_size,
            page_size=DSV4_C4_PAGE_SIZE,
            layer_num=self.n_c4,
            data_bytes=self.indexer_head_dim,
            scale_bytes=DSV4_INDEXER_SCALE_BYTES,
        )
        # Floor division: only whole pages are handed to the page allocator, so a
        # trailing partial page of the pool is never allocated.
        self.c4_num_pages = self.c4_size // DSV4_C4_PAGE_SIZE
        assert self.c4_num_pages > 0, "DeepSeek-V4 c4 pool must have at least one usable full page"
        self.c4_page_allocator = KvCacheAllocator(
            self.c4_num_pages, shared_name=f"{server}_dsv4_c4_can_use_page_num_{rank_in_node}"
        )
        self.c4_page_live_count = torch.zeros((self.c4_pool.num_pages,), dtype=torch.int32, device="cuda")
        self.full_to_c4_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda")
        self.full_to_c4_indexs[size] = self.c4_pool.HOLD_TOKEN_MEMINDEX
        # In-flight c4 compressor state (attention + indexer), addressed by rows derived
        # from swa pages so it lives and dies with the swa page (lets a radix hit resume
        # without copying).  Row count = pages * ring + ring (HOLD page) + 1 (sentinel),
        # rounded up to the ratio.  The last row is a sentinel with kv=0 / score=-inf;
        # the other rows are overwritten by the kernel at group start, so no per-page
        # clearing is needed.  last_dim = 2 * coff * head_dim with overlap coff = 2.
        state_rows = self._paged_state_rows(self.swa_num_pages, DSV4_C4_STATE_RING, 4)
        self.c4_state_buffer = torch.zeros(
            (self.n_c4, state_rows, 4 * self.head_dim), dtype=torch.float32, device="cuda"
        )
        self.c4_indexer_state_buffer = torch.zeros(
            (self.n_c4, state_rows, 4 * self.indexer_head_dim), dtype=torch.float32, device="cuda"
        )
        for buf in (self.c4_state_buffer, self.c4_indexer_state_buffer):
            self._init_state_sentinel(buf)
    if self.n_c128 > 0:
        self.c128_pool = PackedPagePool(
            size=self.c128_size,
            page_size=DSV4_C128_PAGE_SIZE,
            layer_num=self.n_c128,
            data_bytes=DSV4_MLA_DATA_BYTES_PER_TOKEN,
            scale_bytes=self.mla_scale_bytes,
            align_bytes=DSV4_MLA_PAGE_ALIGN_BYTES,
        )
        # Unlike c4, the c128 allocator is sized in slots (c128_size), not pages.
        self.c128_allocator = KvCacheAllocator(
            self.c128_size, shared_name=f"{server}_dsv4_c128_can_use_token_num_{rank_in_node}"
        )
        self.full_to_c128_indexs = torch.full((size + 1,), -1, dtype=torch.int32, device="cuda")
        self.full_to_c128_indexs[size] = self.c128_pool.HOLD_TOKEN_MEMINDEX
        # In-flight c128 compressor state: rows are derived from full->swa like c4, but
        # with ring = 128 and no overlap.  last_dim = 2 * head_dim; the last row is the
        # sentinel read when the swa slot is missing or out of window.
        state_rows = self._paged_state_rows(self.swa_num_pages, DSV4_C128_STATE_RING, 128)
        self.c128_state_buffer = torch.zeros(
            (self.n_c128, state_rows, 2 * self.head_dim), dtype=torch.float32, device="cuda"
        )
        self._init_state_sentinel(self.c128_state_buffer)

    logger.info(
        f"DeepseekV4MemoryManager pools: full_tokens={size} swa={self.swa_size}({self.swa_num_pages}p) "
        f"c4={self.c4_size}(L={self.n_c4}) c128={self.c128_size}(L={self.n_c128}) "
        f"packed_kv_bytes={self.mla_bytes_per_token} indexer_bytes={self.indexer_bytes_per_token}"
    )
def _pool_and_local_layer(self, layer_index: int):
    """Return ``(pool, compacted_layer_index)`` for a compressed layer.

    Layers with rate 4 map to the c4 pool, rate 128 to the c128 pool.

    Raises:
        AssertionError: the layer's compress rate is neither 4 nor 128.
    """
    rate = self.compress_rates[layer_index]
    if rate not in (4, 128):
        raise AssertionError(f"layer {layer_index} (rate {rate}) 不是压缩层,没有压缩池")
    if rate == 4:
        pool, local_index = self.c4_pool, self.layer_to_c4_idx
    else:
        pool, local_index = self.c128_pool, self.layer_to_c128_idx
    return pool, local_index[layer_index]
def _count_swa_pages(self, swa_slots: torch.Tensor, delta: int) -> torch.Tensor:
    """Add ``delta`` to the live count of the page holding each slot.

    Every element of ``swa_slots`` contributes once, so a page hit by k slots
    moves by ``k * delta``.  Returns the distinct page ids that were touched.
    """
    page_ids = swa_slots.long().div(DSV4_SWA_PAGE_SIZE, rounding_mode="floor")
    per_slot_delta = torch.full(page_ids.shape, delta, dtype=torch.int32, device=page_ids.device)
    self.swa_page_live_count.index_add_(0, page_ids, per_slot_delta)
    return page_ids.unique()
def alloc_swa_decode(
    self,
    b_req_idx_cpu: torch.Tensor,
    b_seq_len_cpu: torch.Tensor,
    mem_indexes: torch.Tensor,
    req_to_token_indexs: torch.Tensor,
) -> None:
    """Decode prep: assign the swa slot of this step's token (position ``seq_len - 1``).

    A token that starts a page gets a freshly allocated page; any other token takes
    the previous token's slot + 1 (the position-alignment invariant keeps both on the
    same page).  The mapping is scattered through ``mem_indexes`` because
    ``req_to_token_indexs`` has not been written for this step yet.

    NOTE: a continuation slot is derived from the previous position's mapping, so
    several rows of the same request in one batch (MTP, multiple tokens per step) are
    not supported; the original author notes that DSV4 launch arguments reject MTP and
    that support would need in-step ordered derivation.

    Args:
        b_req_idx_cpu: request id per batch row (read via ``tolist()``).
        b_seq_len_cpu: sequence length per batch row, including this step's token.
        mem_indexes: full token slot of this step's token per batch row.
        req_to_token_indexs: table indexed ``[request, position]`` giving full slots.
    """
    page = DSV4_SWA_PAGE_SIZE
    # Rows whose request id equals max_request_num are padding rows and are skipped.
    hold_req_id = self.max_request_num
    req_list = b_req_idx_cpu.tolist()
    seq_list = b_seq_len_cpu.tolist()
    # cont_rows: batch rows continuing a page; cont_prev_pos: their previous position;
    # new_rows: batch rows whose token is the first slot of a page.
    cont_rows, cont_prev_pos, new_rows = [], [], []
    for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)):
        req_idx, seq_len = int(req_idx), int(seq_len)
        if req_idx == hold_req_id or seq_len <= 0:
            continue
        if (seq_len - 1) % page == 0:
            new_rows.append(i)
        else:
            cont_rows.append(i)
            cont_prev_pos.append(seq_len - 2)
    mem_indexes = mem_indexes.cuda().long().reshape(-1)
    if cont_rows:
        req_rows = torch.tensor([req_list[i] for i in cont_rows], dtype=torch.long, device="cuda")
        prev_full = req_to_token_indexs[req_rows, torch.tensor(cont_prev_pos, device="cuda")].long()
        prev_slots = self.full_to_swa_indexs[prev_full]
        # Continuation invariant guard: the previous position must still be mapped.
        # The bool() forces a device sync; the original author considers that
        # negligible because the prep stage synchronises anyway.
        assert bool((prev_slots >= 0).all())
        slots = prev_slots + 1
        self.full_to_swa_indexs[mem_indexes[cont_rows]] = slots
        self._count_swa_pages(slots, 1)
    if new_rows:
        # One new page per row; the token occupies the page's first slot.
        pages = self._alloc_swa_pages(len(new_rows)).cuda(non_blocking=True).long()
        slots = (pages * page).to(torch.int32)
        self.full_to_swa_indexs[mem_indexes[new_rows]] = slots
        self._count_swa_pages(slots, 1)
    return
def evict_c4(self, full_slots: torch.Tensor) -> None:
    """Release the c4 slots mapped from the given full token slots.

    Full slots without a mapping (-1) and the HOLD slot are skipped.  Mapped
    rows are reset to -1, the per-page live counts are decremented, and every
    page whose count drops to 0 is returned to the c4 page allocator.  Does
    nothing when there is no c4 page allocator or ``full_slots`` is empty.
    """
    allocator = self.c4_page_allocator
    if allocator is None or full_slots.numel() == 0:
        return
    candidates = full_slots.cuda().long().reshape(-1)
    # De-duplicate so a repeated full slot cannot decrement a page twice.
    candidates = candidates[candidates != self.HOLD_TOKEN_MEMINDEX].unique()
    if candidates.numel() == 0:
        return
    mapped = self.full_to_c4_indexs[candidates]
    is_mapped = mapped >= 0
    live_c4_slots = mapped[is_mapped]
    if live_c4_slots.numel() == 0:
        return
    self.full_to_c4_indexs[candidates[is_mapped]] = -1
    touched_pages = self.count_c4_slots(live_c4_slots, -1)
    drained = touched_pages[self.c4_page_live_count[touched_pages] == 0]
    if drained.numel() > 0:
        allocator.free(drained.to(torch.int32))
    return
def free_all(self):
    """Reset the base manager and every derived pool to the empty state.

    For each pool that exists: the allocator is fully released, the page live
    counts (swa and c4 only) are zeroed, and the full->pool mapping is reset to
    -1 with the HOLD row pointing back at the pool's HOLD slot.
    """
    hold_row = self.HOLD_TOKEN_MEMINDEX

    def _reset_mapping(mapping, pool):
        mapping.fill_(-1)
        mapping[hold_row] = pool.HOLD_TOKEN_MEMINDEX

    super().free_all()

    self.swa_page_allocator.free_all()
    self.swa_page_live_count.zero_()
    _reset_mapping(self.full_to_swa_indexs, self.swa_pool)

    if self.c4_page_allocator is not None:
        self.c4_page_allocator.free_all()
        self.c4_page_live_count.zero_()
        _reset_mapping(self.full_to_c4_indexs, self.c4_pool)

    if self.c128_allocator is not None:
        self.c128_allocator.free_all()
        _reset_mapping(self.full_to_c128_indexs, self.c128_pool)
    return
nope_fp8 = torch.clamp(nope / scale.unsqueeze(-1), -DSV4_FP8_E4M3_MAX, DSV4_FP8_E4M3_MAX).to( + torch.float8_e4m3fn + ) + out[:, : self.mla_nope_dim].copy_(nope_fp8.reshape(-1, self.mla_nope_dim).view(dtype=torch.uint8)) + rope_start = self.mla_nope_dim + rope_end = rope_start + self.mla_rope_dim * 2 + rope = kv[:, self.mla_nope_dim : self.mla_head_dim].contiguous().to(torch.bfloat16) + out[:, rope_start:rope_end].copy_(rope.view(dtype=torch.uint8).reshape(-1, self.mla_rope_dim * 2)) + scale_start = rope_end + scale_end = scale_start + self.mla_scale_bytes - 1 + out[:, scale_start:scale_end].copy_((scale_exp + 127).to(torch.uint8)) + out[:, scale_end].zero_() + return out + + def _unpack_mla_kv(self, packed: torch.Tensor) -> torch.Tensor: + packed = packed.reshape(-1, self.mla_bytes_per_token) + if packed.shape[0] == 0: + return torch.empty((0, self.mla_head_dim), dtype=self.dtype, device=packed.device) + nope_fp8 = packed[:, : self.mla_nope_dim].view(dtype=torch.float8_e4m3fn).float() + nope_fp8 = nope_fp8.reshape(-1, self.mla_scale_bytes - 1, self.mla_quant_group_size) + rope_start = self.mla_nope_dim + rope_end = rope_start + self.mla_rope_dim * 2 + scale_start = rope_end + scale_end = scale_start + self.mla_scale_bytes - 1 + scale_exp = packed[:, scale_start:scale_end].to(torch.int32) - 127 + scale = torch.exp2(scale_exp.float()) + nope = (nope_fp8 * scale.reshape(-1, self.mla_scale_bytes - 1, 1)).reshape(-1, self.mla_nope_dim) + rope = packed[:, rope_start:rope_end].view(dtype=torch.bfloat16) + return torch.cat([nope.to(self.dtype), rope.to(self.dtype)], dim=-1) + + def _pack_indexer_k(self, indexer_k: torch.Tensor) -> torch.Tensor: + indexer_k = indexer_k.reshape(-1, self.indexer_head_dim) + out = torch.empty( + (indexer_k.shape[0], self.indexer_bytes_per_token), + dtype=torch.uint8, + device=indexer_k.device, + ) + k_float = indexer_k.float() + scale = torch.clamp( + k_float.abs().amax(dim=-1, keepdim=True) / DSV4_FP8_E4M3_MAX, + min=DSV4_FP8_SCALE_MIN, + ) 
+ k_fp8 = torch.clamp(k_float / scale, -DSV4_FP8_E4M3_MAX, DSV4_FP8_E4M3_MAX).to(torch.float8_e4m3fn) + out[:, : self.indexer_head_dim].copy_(k_fp8.view(dtype=torch.uint8)) + out[:, self.indexer_head_dim :].copy_(scale.view(dtype=torch.uint8).reshape(-1, DSV4_INDEXER_SCALE_BYTES)) + return out + + def _unpack_indexer_k(self, packed: torch.Tensor) -> torch.Tensor: + packed = packed.reshape(-1, self.indexer_bytes_per_token) + if packed.shape[0] == 0: + return torch.empty((0, self.indexer_head_dim), dtype=self.dtype, device=packed.device) + k_fp8 = packed[:, : self.indexer_head_dim].view(dtype=torch.float8_e4m3fn).float() + scale = packed[:, self.indexer_head_dim :].view(dtype=torch.float32) + return (k_fp8 * scale).to(self.dtype) + + # ------------------------------------------------------------------ cache write paths + def pack_mla_kv_to_cache(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """标准 operator 写入路径。要求本步已对 mem_index 调过 ``alloc_swa``(prep 阶段); + HOLD/padding 槽位映射到 swa HOLD 槽,写入无害。""" + if kv.shape[0] == 0: + return + from lightllm.models.deepseek_v4.triton_kernel.destindex_copy_kv_flashmla_dsv4 import ( + destindex_copy_kv_flashmla_dsv4, + ) + + swa_slots = self.full_to_swa_indexs[mem_index.cuda().long().reshape(-1)] + destindex_copy_kv_flashmla_dsv4( + kv.reshape(-1, self.mla_head_dim), + swa_slots, + self.swa_pool.get_layer_buffer(layer_index), + self.swa_pool.page_size, + ) + return + + def pack_mla_kv_to_cache_fused_norm_rope( + self, + layer_index: int, + mem_index: torch.Tensor, + kv: torch.Tensor, + kv_weight: torch.Tensor, + eps: float, + freqs_cis: torch.Tensor, + positions: torch.Tensor, + ): + """同 pack_mla_kv_to_cache,但 rmsnorm + 尾部交错 rope 融合进写入 kernel + (sglang fused_k_norm_rope_flashmla,即 sglang _compute_kv_to_cache 的池侧), + 省掉 bf16 kv 中间量。kv 为 wkv 投影原始输出 [T, head_dim+rope_dim]。""" + if kv.shape[0] == 0: + return + from sglang.jit_kernel.dsv4 import fused_k_norm_rope_flashmla + + swa_slots = 
def pack_compressed_kv_to_cache(self, layer_index: int, slots: torch.Tensor, comp: torch.Tensor):
    """Write compressed latents ``comp`` into the compressed pool of ``layer_index``.

    ``slots`` are slots of that layer's compressed pool (c4 or c128, chosen by
    ``_pool_and_local_layer``); they are moved to ``comp``'s device before the
    packed writer kernel runs.  A zero-row ``comp`` is a no-op.
    """
    if comp.shape[0] == 0:
        return
    # Imported lazily, as in the rest of this class, so the kernel module is only
    # loaded when a write actually happens.
    from lightllm.models.deepseek_v4.triton_kernel.destindex_copy_kv_flashmla_dsv4 import (
        destindex_copy_kv_flashmla_dsv4,
    )

    target_pool, compact_layer = self._pool_and_local_layer(layer_index)
    latents = comp.reshape(-1, self.mla_head_dim)
    dest_slots = slots.to(comp.device)
    layer_slab = target_pool.get_layer_buffer(compact_layer)
    destindex_copy_kv_flashmla_dsv4(latents, dest_slots, layer_slab, target_pool.page_size)
def get_index_kv_buffer(self, index):
    """Fenced inherited API: always raises.

    ``kv_buffer`` is a page-indexed uint8 slab here, so the base class's
    token-indexed access would silently corrupt data; it is rejected explicitly.

    Raises:
        NotImplementedError: always.
    """
    reason = "DeepSeek-V4 packed page-slab cache does not support token-indexed kv_buffer io"
    raise NotImplementedError(reason)
raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented") + + def receive_from_prefill_node_p2p(self, *args, **kwargs): + raise NotImplementedError("DeepSeek-V4 packed/composite KV transfer is not implemented") diff --git a/lightllm/common/kv_cache_mem_manager/operator/__init__.py b/lightllm/common/kv_cache_mem_manager/operator/__init__.py index 85c37ad39b..442c2e300e 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/operator/__init__.py @@ -5,6 +5,7 @@ from .deepseek import ( Deepseek2MemOperator, Deepseek3_2MemOperator, + DeepseekV4MemOperator, FP8PerTokenGroupQuantDeepseek3_2MemOperator, ) from .fp8_quant import ( diff --git a/lightllm/common/kv_cache_mem_manager/operator/deepseek.py b/lightllm/common/kv_cache_mem_manager/operator/deepseek.py index 6e05b96e10..0725ce9b93 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/deepseek.py +++ b/lightllm/common/kv_cache_mem_manager/operator/deepseek.py @@ -8,7 +8,9 @@ class Deepseek2MemOperator(NormalMemOperator): def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): - from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager + from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import ( + Deepseek2MemoryManager, + ) mem_manager: Deepseek2MemoryManager = self.mem_manager @@ -30,7 +32,9 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: class Deepseek3_2MemOperator(Deepseek2MemOperator): def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): - from lightllm.common.kv_cache_mem_manager.deepseek3_2mem_manager import Deepseek3_2MemoryManager + from lightllm.common.kv_cache_mem_manager.deepseek3_2mem_manager import ( + Deepseek3_2MemoryManager, + ) mem_manager: Deepseek3_2MemoryManager = self.mem_manager from ...basemodel.triton_kernel.kv_copy.mla_copy_kv import 
class DeepseekV4MemOperator(BaseMemManagerOperator):
    """Memory operator that routes KV writes to the DeepSeek-V4 memory manager."""

    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
        """Forward ``kv`` for ``layer_index`` / ``mem_index`` to ``pack_mla_kv_to_cache``.

        The manager type is imported locally and used only to annotate
        ``self.mem_manager``; the actual packing is done by the manager.
        """
        from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import (
            DeepseekV4MemoryManager,
        )

        v4_manager: DeepseekV4MemoryManager = self.mem_manager
        v4_manager.pack_mla_kv_to_cache(layer_index, mem_index, kv)
@QUANTMETHODS.register(["marlin-mxfp4w4a16-b32"], platform="cuda")
class MXFP4MoEQuantizationMethod(QuantizationMethod):
    """Loader for pre-packed MXFP4 MoE expert weights, repacked to the marlin layout.

    Only the fused-MoE path is supported: ``quantize`` and ``apply`` raise
    ``NotImplementedError``. Raw weights are staged on CPU by ``_create_weight``
    and repacked into pre-allocated CUDA buffers by ``finalize_moe_weight``.
    """

    def __init__(self):
        super().__init__()
        self.block_size = 32
        self.weight_suffix = "weight"
        self.weight_scale_suffix = "scale"
        self.weight_zero_point_suffix = None
        self.has_weight_scale = True
        self.has_weight_zero_point = False

    @property
    def method_name(self):
        """Registry name of this quantization method."""
        return "marlin-mxfp4w4a16-b32"

    def quantize(self, weight: torch.Tensor, output: WeightPack):
        """Unsupported: this method only loads weights that are already packed."""
        raise NotImplementedError("marlin-mxfp4w4a16-b32 only loads pre-packed MXFP4 expert weights")

    def apply(
        self,
        input_tensor: torch.Tensor,
        weight_pack: "WeightPack",
        out: Optional[torch.Tensor] = None,
        workspace: Optional[torch.Tensor] = None,
        use_custom_tensor_mananger: bool = True,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unsupported: there is no dense-linear implementation."""
        raise NotImplementedError("marlin-mxfp4w4a16-b32 is only implemented for fused MoE expert weights")

    def _probe_marlin_layout(self, size_n: int, size_k: int, dtype: torch.dtype, device_id: int):
        """Return ``((weight_shape, weight_dtype), (scale_shape, scale_dtype))`` of the marlin layout.

        Zero tensors are pushed through vLLM's repack and scale-processing functions
        so the shapes come from vLLM itself rather than from a copied formula.
        Results are memoised per ``(size_n, size_k, dtype)`` on the instance.
        """
        if getattr(self, "_marlin_layout_cache", None) is None:
            self._marlin_layout_cache = {}
        known_layouts = self._marlin_layout_cache
        probe_key = (size_n, size_k, dtype)
        if probe_key in known_layouts:
            return known_layouts[probe_key]

        import vllm._custom_ops as ops
        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
            get_marlin_input_dtype,
            marlin_permute_scales,
        )
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
            mxfp4_marlin_process_scales,
        )

        input_dtype = get_marlin_input_dtype()
        is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
        device = f"cuda:{device_id}"

        # Two 4-bit values per int8 byte; reinterpret as int32 and transpose for the repack op.
        packed_zeros = torch.zeros((size_n, size_k // 2), dtype=torch.int8, device=device)
        probe_qweight = packed_zeros.view(torch.int32).T.contiguous()
        repacked_qweight = ops.gptq_marlin_repack(
            b_q_weight=probe_qweight,
            perm=torch.empty(0, dtype=torch.int, device=device),
            size_k=size_k,
            size_n=size_n,
            num_bits=4,
            is_a_8bit=is_a_8bit,
        )

        probe_scale = torch.zeros((size_k // self.block_size, size_n), dtype=dtype, device=device)
        permuted_scale = marlin_permute_scales(
            s=probe_scale, size_k=size_k, size_n=size_n, group_size=self.block_size, is_a_8bit=is_a_8bit
        )
        final_scale = mxfp4_marlin_process_scales(permuted_scale, input_dtype=input_dtype)

        weight_layout = (tuple(repacked_qweight.shape), repacked_qweight.dtype)
        scale_layout = (tuple(final_scale.shape), final_scale.dtype)
        layout = (weight_layout, scale_layout)
        known_layouts[probe_key] = layout
        return layout

    def _create_weight(
        self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1
    ) -> Tuple[WeightPack, List[WeightPack]]:
        """Allocate CPU staging tensors and the final CUDA marlin buffers.

        The CPU ``weight`` / ``weight_scale`` receive the raw pre-packed MXFP4 data
        at load time. The CUDA ``marlin_weight`` / ``marlin_weight_scale`` are
        allocated here, so device memory is reserved at construction; the original
        note says this lets memory profiling see the real weight footprint.
        """
        out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims
        assert in_dim % self.block_size == 0, "MXFP4 scale dimension must be divisible by block_size"

        # CPU staging: the expert axis is only added when there is more than one expert.
        leading = (num_experts,) if num_experts > 1 else ()
        staged_weight = torch.empty(leading + (out_dim, in_dim // 2), dtype=torch.int8, device="cpu")
        staged_scale = torch.empty(
            leading + (out_dim, in_dim // self.block_size), dtype=torch.float8_e8m0fnu, device="cpu"
        )
        mm_param = WeightPack(weight=staged_weight, weight_scale=staged_scale)

        # CUDA destination buffers always carry the expert axis.
        weight_layout, scale_layout = self._probe_marlin_layout(out_dim, in_dim, dtype, device_id)
        w_shape, w_dtype = weight_layout
        s_shape, s_dtype = scale_layout
        mm_param.marlin_weight = torch.empty((num_experts,) + w_shape, dtype=w_dtype, device=f"cuda:{device_id}")
        mm_param.marlin_weight_scale = torch.empty((num_experts,) + s_shape, dtype=s_dtype, device=f"cuda:{device_id}")

        mm_param_list = self._split_weight_pack(
            mm_param,
            weight_out_dims=out_dims,
            weight_split_dim=-2,
            weight_scale_out_dims=out_dims,
            weight_scale_split_dim=-2,
        )
        return mm_param, mm_param_list

    def finalize_moe_weight(self, moe_weight):
        """Repack the staged w13/w2 weights with vLLM and install the marlin buffers.

        Raises ``RuntimeError`` (chained) when the vLLM MXFP4 utilities cannot be
        imported. After repacking, results are copied into the pre-allocated marlin
        buffers, and ``weight`` / ``weight_scale`` of both packs are rebound to them.
        """
        try:
            from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
                prepare_moe_mxfp4_layer_for_marlin,
            )
        except Exception as e:
            raise RuntimeError(f"marlin-mxfp4w4a16-b32 requires vLLM MXFP4 packing utilities, error={repr(e)}") from e

        class _MXFP4Layer:
            pass

        device = torch.device("cuda", moe_weight.device_id_)
        layer = _MXFP4Layer()
        layer.params_dtype = moe_weight.data_type_

        def _to_device(staged):
            return staged.to(device=device, non_blocking=True).contiguous()

        w13_raw = _to_device(moe_weight.w13.weight.view(torch.uint8))
        w2_raw = _to_device(moe_weight.w2.weight.view(torch.uint8))
        w13_raw_scale = _to_device(moe_weight.w13.weight_scale)
        w2_raw_scale = _to_device(moe_weight.w2.weight_scale)

        (
            w13_packed,
            w2_packed,
            w13_packed_scale,
            w2_packed_scale,
            _,
            _,
        ) = prepare_moe_mxfp4_layer_for_marlin(layer, w13_raw, w2_raw, w13_raw_scale, w2_raw_scale, None, None)

        # Copy first (a shape mismatch fails here), then rebind.
        repacked = (
            (moe_weight.w13, w13_packed, w13_packed_scale),
            (moe_weight.w2, w2_packed, w2_packed_scale),
        )
        for pack, packed_weight, packed_scale in repacked:
            pack.marlin_weight.copy_(packed_weight)
            pack.marlin_weight_scale.copy_(packed_scale)
        for pack, _, _ in repacked:
            pack.weight = pack.marlin_weight
            pack.weight_scale = pack.marlin_weight_scale
@dataclass
class DeepseekV4PromptCachePayload:
    """Prompt-cache payload: a length plus an optional per-page SWA validity bitmap.

    Per the original notes, slots and compressor state are deliberately not part
    of the payload; they are addressed through the memory manager's mappings.

    ``swa_page_valid`` is a bool tensor with one entry per prompt-cache page
    (described upstream as a CPU tensor of length ``cache_len // page``). The
    matching logic uses it to trim hits to a boundary whose last page is valid,
    and it is cleared when the node's SWA pages are reclaimed.
    """

    # Number of tokens covered by this payload.
    cache_len: int
    # Per-page validity bitmap; None when it has not been filled in.
    swa_page_valid: Optional[torch.Tensor] = None


class DeepseekV4PromptCacheValueOps:
    """Payload operations for the prompt cache; ``slice``, ``concat`` and the page size come from the request manager."""

    def __init__(self, req_manager: "DeepseekV4ReqManager"):
        self.req_manager = req_manager

    def slice(self, payload: DeepseekV4PromptCachePayload, start: int, end: int):
        """Return the request manager's slice of ``payload`` for ``[start, end)``."""
        manager = self.req_manager
        return manager.slice_prompt_cache_payload(payload, start, end)

    def concat(self, payloads: List[DeepseekV4PromptCachePayload]):
        """Return the request manager's concatenation of ``payloads``."""
        manager = self.req_manager
        return manager.concat_prompt_cache_payloads(payloads)

    def free(self, payload: DeepseekV4PromptCachePayload):
        """No-op: the payload owns no resources of its own."""
        return None

    def invalidate_swa_pages(self, payload: DeepseekV4PromptCachePayload) -> None:
        """Clear the whole validity bitmap in place; ``None`` payloads and bitmaps are ignored."""
        if payload is None:
            return
        bitmap = payload.swa_page_valid
        if bitmap is not None:
            bitmap.fill_(False)

    def valid_match_length(self, payload: Optional[DeepseekV4PromptCachePayload], natural_len: int) -> int:
        """Trim a match of ``natural_len`` tokens to the last valid page boundary.

        Only pages that lie fully inside ``natural_len`` and are covered by the
        bitmap are considered. Returns ``(index of last valid page + 1) * page``,
        or 0 when there is no payload, no bitmap, or no valid page. Invalid pages
        in the middle do not block a later valid page.
        """
        page_size = self.req_manager.get_prompt_cache_page_size()
        bitmap = None if payload is None else payload.swa_page_valid
        if bitmap is None:
            return 0
        candidate_pages = min(natural_len // page_size, int(bitmap.numel()))
        valid_pages = torch.nonzero(bitmap[:candidate_pages]).reshape(-1)
        if valid_pages.numel() == 0:
            return 0
        last_valid = int(valid_pages[-1])
        return (last_valid + 1) * page_size
class DeepseekV4ReqManager(ReqManager):
    """Request-level manager for DeepSeek-V4.

    Adds V4-specific per-request bookkeeping on top of ``ReqManager``. Per the
    original note, the object is created before the memory manager is profiled,
    so construction only needs the config-derived compress_rates / head_dim /
    indexer_head_dim / sliding_window; the real memory manager is attached later
    through ``bind_mem_manager()``.

    * Compressed slots are not stored here: the ``full_to_c4_indexs`` /
      ``full_to_c128_indexs`` mappings live on the memory manager. This class only
      allocates and scatters them in the prep stage
      (``prepare_prefill_compress_slots`` / ``prepare_decode_compress_slots``).
      Original note: this must happen before attention metadata is built or a
      graph is captured; the entry contents are written elsewhere.
    * In-flight compressor state is not stored here either (original note: it
      lives in memory-manager pools addressed through SWA pages).
    * SWA slot allocation and out-of-window reclaim (``prepare_prefill_swa`` /
      ``prepare_decode_swa``): each prep step evicts positions that fell behind
      the per-request watermark ``_swa_evict_marks`` and then allocates slots via
      ``mem_manager.alloc_swa_prefill`` / ``alloc_swa_decode``. The watermark is
      first set to the ``ready_cache_len`` of the request's first chunk, so
      positions below it are never evicted by this request.
    """

    def __init__(
        self,
        max_request_num,
        max_sequence_length,
        mem_manager: Optional[DeepseekV4MemoryManager] = None,
        compress_rates: Optional[List[int]] = None,
        head_dim: Optional[int] = None,
        indexer_head_dim: Optional[int] = None,
        sliding_window: Optional[int] = None,
    ):
        """Build the per-request state; values left as None are taken from ``mem_manager`` when given.

        NOTE(review): ``compress_rates`` must end up non-None (explicitly or via
        ``mem_manager``); otherwise ``list(compress_rates)`` raises TypeError.
        """
        super().__init__(max_request_num, max_sequence_length, mem_manager)

        self.mem_manager = mem_manager
        if mem_manager is not None:
            if compress_rates is None:
                compress_rates = mem_manager.compress_rates
            if head_dim is None:
                head_dim = mem_manager.head_dim
            if indexer_head_dim is None:
                indexer_head_dim = mem_manager.indexer_head_dim
            if sliding_window is None:
                sliding_window = mem_manager.sliding_window
        self.sliding_window = sliding_window
        # Out-of-window eviction watermark, one per request index (plus one extra entry).
        # -1 means the request has not seen a prefill chunk yet; the first chunk's
        # ready_cache_len then becomes the lower bound that eviction never goes below.
        self._swa_evict_marks = [-1 for _ in range(max_request_num + 1)]
        self.compress_rates = list(compress_rates)
        self.n_c4 = sum(1 for r in self.compress_rates if r == 4)
        self.n_c128 = sum(1 for r in self.compress_rates if r == 128)
        self.head_dim = head_dim
        self.indexer_head_dim = indexer_head_dim
        # Global layer id -> dense index among the c4 / c128 layers.
        self.layer_to_c4_idx = {}
        self.layer_to_c128_idx = {}
        c4 = c128 = 0
        for lid, r in enumerate(self.compress_rates):
            if r == 4:
                self.layer_to_c4_idx[lid] = c4
                c4 += 1
            elif r == 128:
                self.layer_to_c128_idx[lid] = c128
                c128 += 1

        return

    def bind_mem_manager(self, mem_manager: DeepseekV4MemoryManager):
        """Attach the real memory manager after asserting its config matches this object.

        ``sliding_window`` is adopted from the manager when unset here; otherwise the
        two must agree unless the manager's value is None.
        """
        assert isinstance(mem_manager, DeepseekV4MemoryManager)
        assert self.compress_rates == mem_manager.compress_rates
        assert self.head_dim == mem_manager.head_dim
        assert self.indexer_head_dim == mem_manager.indexer_head_dim
        if self.sliding_window is None:
            self.sliding_window = mem_manager.sliding_window
        else:
            assert mem_manager.sliding_window is None or self.sliding_window == mem_manager.sliding_window
        self.mem_manager = mem_manager
        return

    # ------------------------------------------------------------------ swa slot prep (per step)
    def _swa_retain_len(self) -> int:
        """Number of trailing positions kept by eviction: window + one prompt-cache page.

        Original rationale: the extra page keeps the last page of the most recent
        completed prompt-cache boundary resident; keeping only the window would
        leave that page partially evicted and block prompt-cache inserts.
        """
        return int(self.sliding_window) + self.get_prompt_cache_page_size()

    def prepare_prefill_swa(
        self,
        b_req_idx: torch.Tensor,
        b_ready_cache_len: torch.Tensor,
        b_seq_len: torch.Tensor,
    ) -> None:
        """Prefill prep: evict out-of-window SWA slots, then allocate slots for this chunk.

        For a chunk starting at L = ready_cache_len, positions below
        ``L - retain + 1`` (and at or above the watermark) are evicted. Eviction
        runs before allocation. Original note: must be called after
        ``init_req_to_token_indexes`` because slots are read from ``req_to_token_indexs``.
        """
        assert self.mem_manager is not None
        if self.sliding_window is not None:
            retain = self._swa_retain_len()
            evict_slots = []
            # Host copies of the batch tensors (this forces a device-to-host transfer).
            req_list = b_req_idx.detach().cpu().tolist()
            ready_list = b_ready_cache_len.detach().cpu().tolist()
            for req_idx, ready_len in zip(req_list, ready_list):
                req_idx = int(req_idx)
                if req_idx == self.HOLD_REQUEST_ID:
                    continue
                ready_len = int(ready_len)
                mark = self._swa_evict_marks[req_idx]
                if mark < 0:
                    # First chunk: [0, ready_len) is the shared (radix) prefix; its SWA slots
                    # are not owned by this request, so only record the watermark.
                    self._swa_evict_marks[req_idx] = ready_len
                    continue
                evict_end = ready_len - retain + 1
                if evict_end > mark:
                    evict_slots.append(self.req_to_token_indexs[req_idx, mark:evict_end])
                    self._swa_evict_marks[req_idx] = evict_end
            if evict_slots:
                self.mem_manager.evict_swa(torch.cat(evict_slots))
        self.mem_manager.alloc_swa_prefill(b_req_idx, b_ready_cache_len, b_seq_len, self.req_to_token_indexs)
        return

    def prepare_decode(self, b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu):
        """Per-step decode prep: SWA slots first, then compressed slots.

        Original note: called before ``to_cuda`` with CPU tensors, on the forward CUDA stream.
        """
        self.prepare_decode_swa(b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu)
        self.prepare_decode_compress_slots(b_req_idx_cpu, b_seq_len_cpu, mem_indexes_cpu)
        return

    def prepare_decode_swa(
        self,
        b_req_idx_cpu: torch.Tensor,
        b_seq_len_cpu: torch.Tensor,
        mem_indexes: torch.Tensor,
    ) -> None:
        """Decode prep: evict out-of-window SWA slots, then allocate the slot for this step's token.

        Positions below ``seq_len - retain`` (and at or above the watermark) are
        evicted. Eviction runs before allocation. ``req_idx`` / ``seq_len`` are read
        from the CPU tensors, and the watermark list is host state.
        """
        assert self.mem_manager is not None
        if self.sliding_window is not None:
            retain = self._swa_retain_len()
            evict_slots = []
            req_list = b_req_idx_cpu.tolist()
            seq_list = b_seq_len_cpu.tolist()
            for req_idx, seq_len in zip(req_list, seq_list):
                if req_idx == self.HOLD_REQUEST_ID:
                    continue
                mark = self._swa_evict_marks[req_idx]
                if mark < 0:
                    # Conservative path for a request that never went through prefill prep:
                    # evict nothing, only advance the watermark.
                    self._swa_evict_marks[req_idx] = max(0, seq_len - retain)
                    continue
                evict_end = seq_len - retain
                if evict_end > mark:
                    evict_slots.append(self.req_to_token_indexs[req_idx, mark:evict_end])
                    self._swa_evict_marks[req_idx] = evict_end
            if evict_slots:
                self.mem_manager.evict_swa(torch.cat(evict_slots))
        self.mem_manager.alloc_swa_decode(b_req_idx_cpu, b_seq_len_cpu, mem_indexes, self.req_to_token_indexs)
        return

    def init_compress_state(self, req_idx: int):
        """Reset the runtime watermark when a new request starts.

        Original note: compressor state is addressed through SWA pages and
        overwritten by the kernels, so nothing is zeroed per request here.
        """
        self.clear_runtime_state(req_idx)
        return

    # ------------------------------------------------------------------ compress slot prep (per step)
    def _compress_mapping_alloc(self, ratio: int):
        """Return ``(mapping tensor, allocator)`` for ratio 128; raise AssertionError otherwise.

        Ratio 4 is rejected on purpose: c4 uses the page-aware paths below.
        """
        assert self.mem_manager is not None, "DeepSeek-V4 mem manager is not bound yet"
        if ratio == 4:
            raise AssertionError("DeepSeek-V4 c4 uses page-safe allocation")
        if ratio == 128:
            return self.mem_manager.full_to_c128_indexs, self.mem_manager.alloc_c128
        raise AssertionError(f"invalid DeepSeek-V4 compress ratio {ratio}")

    def _scatter_c4_prefill_slots_slow(self, req_idx: int, first: int, last: int) -> None:
        """Idempotence fallback for overlapped/repeated c4 prep.

        Walks the logical c4 entries ``[first, last)`` one logical page at a time
        and fills in only the entries that are still unmapped (< 0).
        """
        page = DSV4_C4_PAGE_SIZE
        mapping = self.mem_manager.full_to_c4_indexs
        for page_base in range((first // page) * page, last, page):
            # Entries of this logical page that fall inside [first, last).
            e0 = max(first, page_base)
            e1 = min(last, page_base + page)
            entries = torch.arange(e0, e1, dtype=torch.long, device="cuda")
            # Entry e is keyed by the full slot of its last token, position e * 4 + 3.
            full_slots = self.req_to_token_indexs[req_idx, entries * 4 + 3].long()
            existing = mapping[full_slots]
            missing = existing < 0
            if not bool(missing.any()):
                continue

            # Determine the physical base slot of this logical page.
            mapped = torch.nonzero(existing >= 0, as_tuple=False)
            if mapped.numel() > 0:
                # Derive it from an entry of this page that is already mapped.
                j = int(mapped[0].item())
                base = int(existing[j].item()) - ((e0 + j) % page)
            elif e0 > page_base:
                # The page was started earlier: derive it from the preceding entry.
                prev_full = self.req_to_token_indexs[req_idx, e0 * 4 - 1].long()
                prev_slot = int(mapping[prev_full].item())
                assert prev_slot >= 0 and prev_slot % page == (e0 - 1) % page
                base = prev_slot - ((e0 - 1) % page)
            else:
                # Fresh logical page: allocate a new physical page.
                base = int(self.mem_manager.alloc_c4_pages(1)[0].item()) * page

            slots = (base + entries % page).to(torch.int32)
            if mapped.numel() > 0:
                # Already-mapped entries must agree with the derived layout.
                assert bool((existing[existing >= 0] == slots[existing >= 0]).all())
            mapping[full_slots[missing]] = slots[missing]
            self.mem_manager.count_c4_slots(slots[missing], 1)
        return

    def _scatter_c4_prefill_slots(self, req_idx: int, first: int, last: int) -> None:
        """Allocate page-safe c4 slots for logical c4 entries ``[first, last)``.

        Invariant: logical entry e maps to ``physical_page * page + e % page`` with
        ``page = DSV4_C4_PAGE_SIZE`` (the original note writes 64), and entries of
        one logical page share a physical page. Original note: consumers read the
        page table directly and rely on this.

        Fast path when every entry is unmapped; nothing to do when all are mapped;
        a mixed state is delegated to ``_scatter_c4_prefill_slots_slow``.
        """
        if last <= first:
            return
        page = DSV4_C4_PAGE_SIZE
        mapping = self.mem_manager.full_to_c4_indexs
        entries = torch.arange(first, last, dtype=torch.long, device="cuda")
        full_slots = self.req_to_token_indexs[req_idx, entries * 4 + 3].long()
        need = mapping[full_slots] < 0
        if not bool(need.any()):
            return
        if not bool(need.all()):
            self._scatter_c4_prefill_slots_slow(req_idx, first, last)
            return

        first_page = first // page
        last_page = (last - 1) // page
        n_pages = last_page - first_page + 1
        # Physical base slot of each logical page touched by [first, last).
        bases = torch.empty((n_pages,), dtype=torch.long, device="cuda")

        base_start = 0
        if first % page != 0:
            # The first logical page is already partially filled: continue in its physical page.
            prev_full = self.req_to_token_indexs[req_idx, first * 4 - 1].long()
            prev_slot = int(mapping[prev_full].item())
            assert prev_slot >= 0 and prev_slot % page == (first - 1) % page
            bases[0] = prev_slot - ((first - 1) % page)
            base_start = 1

        new_page_count = n_pages - base_start
        if new_page_count > 0:
            new_pages = self.mem_manager.alloc_c4_pages(new_page_count).cuda(non_blocking=True).long()
            bases[base_start:] = new_pages * page

        page_local = torch.div(entries, page, rounding_mode="floor") - first_page
        slots = (bases[page_local] + entries % page).to(torch.int32)
        mapping[full_slots] = slots
        self.mem_manager.count_c4_slots(slots, 1)
        return

    def _scatter_c4_decode_slots(
        self,
        b_req_idx_cpu: torch.Tensor,
        b_seq_len_cpu: torch.Tensor,
        mem_indexes: torch.Tensor,
    ) -> None:
        """Decode-time c4 slot assignment for rows whose ``seq_len`` is a multiple of 4.

        A row whose new entry starts a logical page gets a fresh physical page; any
        other row continues right after the slot of its previous entry. The mapping
        is keyed by this step's ``mem_indexes`` row.
        """
        page = DSV4_C4_PAGE_SIZE
        mapping = self.mem_manager.full_to_c4_indexs
        req_list = b_req_idx_cpu.tolist()
        seq_list = b_seq_len_cpu.tolist()
        mem_indexes = mem_indexes.cuda().long().reshape(-1)

        # Rows continuing an existing page vs rows opening a new page.
        cont_rows, cont_prev_pos, cont_offsets = [], [], []
        new_rows = []
        for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list)):
            req_idx, seq_len = int(req_idx), int(seq_len)
            if req_idx == self.HOLD_REQUEST_ID or seq_len <= 0 or seq_len % 4 != 0:
                continue
            entry = seq_len // 4 - 1
            offset = entry % page
            if offset == 0:
                new_rows.append(i)
            else:
                cont_rows.append(i)
                # Last token position of the previous entry.
                cont_prev_pos.append(entry * 4 - 1)
                cont_offsets.append(offset)

        if cont_rows:
            req_rows = torch.tensor([req_list[i] for i in cont_rows], dtype=torch.long, device="cuda")
            prev_pos = torch.tensor(cont_prev_pos, dtype=torch.long, device="cuda")
            prev_full = self.req_to_token_indexs[req_rows, prev_pos].long()
            prev_slots = mapping[prev_full]
            offsets = torch.tensor(cont_offsets, dtype=torch.int32, device="cuda")
            # The previous entry must be mapped and sit exactly one offset before the new one.
            assert bool((prev_slots >= 0).all())
            assert bool(((prev_slots % page) == (offsets - 1)).all())
            slots = (prev_slots + 1).to(torch.int32)
            mapping[mem_indexes[cont_rows]] = slots
            self.mem_manager.count_c4_slots(slots, 1)

        if new_rows:
            pages = self.mem_manager.alloc_c4_pages(len(new_rows)).cuda(non_blocking=True).long()
            slots = (pages * page).to(torch.int32)
            mapping[mem_indexes[new_rows]] = slots
            self.mem_manager.count_c4_slots(slots, 1)
        return

    def _scatter_compress_slots(self, ratio: int, full_slots: torch.Tensor) -> None:
        """Allocate compressed slots for group-end full slots and write the mapping.

        Rows that are already mapped (>= 0) are skipped, so repeated prep is idempotent.
        """
        if full_slots.numel() == 0:
            return
        mapping, alloc = self._compress_mapping_alloc(ratio)
        full_slots = full_slots.cuda().long().reshape(-1)
        # De-duplicate: with repeated keys a later write would overwrite an earlier one and
        # orphan the compressed slot allocated first (allocator leak).
        need = torch.unique(full_slots[mapping[full_slots] < 0])
        if need.numel() == 0:
            return
        new_slots = alloc(need.numel()).cuda(non_blocking=True).to(torch.int32)
        mapping[need] = new_slots
        return

    def prepare_prefill_compress_slots(
        self,
        b_req_idx: torch.Tensor,
        b_ready_cache_len: torch.Tensor,
        b_seq_len: torch.Tensor,
    ) -> None:
        """Prefill prep: allocate compressed slots for group-end tokens inside this chunk.

        A group-end token sits at position ``(g + 1) * ratio - 1`` within
        ``[ready, seq)``. Original note: must run after ``init_req_to_token_indexes``
        (group-end full slots are read from ``req_to_token_indexs``) and before
        attention metadata is built.
        """
        if self.n_c4 == 0 and self.n_c128 == 0:
            return
        req_list = b_req_idx.detach().cpu().tolist()
        ready_list = b_ready_cache_len.detach().cpu().tolist()
        seq_list = b_seq_len.detach().cpu().tolist()
        if self.n_c4 > 0:
            for req_idx, ready_len, seq_len in zip(req_list, ready_list, seq_list):
                req_idx = int(req_idx)
                if req_idx == self.HOLD_REQUEST_ID:
                    continue
                self._scatter_c4_prefill_slots(req_idx, int(ready_len) // 4, int(seq_len) // 4)

        if self.n_c128 > 0:
            ratio = 128
            end_slots = []
            for req_idx, ready_len, seq_len in zip(req_list, ready_list, seq_list):
                req_idx = int(req_idx)
                if req_idx == self.HOLD_REQUEST_ID:
                    continue
                first, last = int(ready_len) // ratio, int(seq_len) // ratio
                if last > first:
                    # Full slots of every group-end position up to `last`; keep the new groups only.
                    ends = self.req_to_token_indexs[req_idx, ratio - 1 : last * ratio : ratio]
                    end_slots.append(ends[first:])
            if end_slots:
                self._scatter_compress_slots(ratio, torch.cat(end_slots))
        return

    def prepare_decode_compress_slots(
        self,
        b_req_idx_cpu: torch.Tensor,
        b_seq_len_cpu: torch.Tensor,
        mem_indexes: torch.Tensor,
    ) -> None:
        """Decode prep: allocate a compressed slot when this step's token closes a group.

        A group closes when ``seq_len % ratio == 0``. The group-end full slot is this
        step's ``mem_indexes`` row (original note: ``req_to_token_indexs`` does not
        hold it yet). ``req_idx`` / ``seq_len`` are read from the CPU tensors.
        """
        if self.n_c4 == 0 and self.n_c128 == 0:
            return
        req_list = b_req_idx_cpu.tolist()
        seq_list = b_seq_len_cpu.tolist()
        if self.n_c4 > 0:
            self._scatter_c4_decode_slots(b_req_idx_cpu, b_seq_len_cpu, mem_indexes)

        if self.n_c128 > 0:
            ratio = 128
            rows = [
                i
                for i, (req_idx, seq_len) in enumerate(zip(req_list, seq_list))
                if req_idx != self.HOLD_REQUEST_ID and seq_len > 0 and seq_len % ratio == 0
            ]
            if rows:
                self._scatter_compress_slots(ratio, mem_indexes.reshape(-1)[rows])
        return

    def alloc(self):
        """Allocate a request index via the base class and reset its runtime state."""
        req_idx = super().alloc()
        if req_idx is not None:
            self.init_compress_state(req_idx)
        return req_idx

    def clear_runtime_state(self, req_idx: int):
        """Reset the eviction watermark of ``req_idx`` to "not seen yet" (-1)."""
        # Original note: the SWA slots themselves are released by mem_manager.free together
        # with the full slots; only the watermark is reset here.
        self._swa_evict_marks[req_idx] = -1
        return

    def get_prompt_cache_value_ops(self):
        """Return a ``DeepseekV4PromptCacheValueOps`` bound to this manager."""
        return DeepseekV4PromptCacheValueOps(self)

    def get_prompt_cache_page_size(self):
        """Return the prompt-cache page size, ``DSV4_PROMPT_CACHE_PAGE_SIZE``."""
        return DSV4_PROMPT_CACHE_PAGE_SIZE

    def compute_swa_page_valid(self, full_slots: torch.Tensor) -> torch.Tensor:
        """Per-page validity from the current ``full_to_swa_indexs`` mapping.

        ``full_slots`` must hold a whole number of pages. Returns a CPU bool tensor
        with one entry per page, True only when every token of the page is mapped
        (>= 0). Performs a GPU gather plus a copy to CPU; the original note reserves
        it for tests/validation and points to ``swa_page_valid_from_watermark`` for
        the insert hot path.
        """
        page = self.get_prompt_cache_page_size()
        assert full_slots.numel() % page == 0
        if full_slots.numel() == 0:
            return torch.zeros((0,), dtype=torch.bool)
        swa = self.mem_manager.full_to_swa_indexs[full_slots.cuda().long().reshape(-1)]
        return (swa.view(-1, page) >= 0).all(dim=1).cpu()

    def swa_page_valid_from_watermark(self, req_idx: int, cache_len: int) -> torch.Tensor:
        """Per-page validity at insert time, computed on CPU from the eviction watermark.

        Page p is reported valid iff its start ``p * page`` is >= the watermark
        (a negative watermark counts as 0). Original rationale: a request's own SWA
        mappings are only removed by watermark eviction, so this matches
        ``compute_swa_page_valid`` for own tokens without a GPU gather/sync; rows
        for the borrowed prefix are discarded later and their value does not matter.

        NOTE(review): the original note wrote the page start as ``128p`` while other
        notes in this file say the page is 256 tokens; the code uses
        ``get_prompt_cache_page_size()`` -- confirm the constant.
        """
        page = self.get_prompt_cache_page_size()
        mark = max(0, self._swa_evict_marks[req_idx])
        n_pages = int(cache_len) // page
        return torch.arange(n_pages, dtype=torch.long) * page >= mark

    def slice_prompt_cache_payload(self, payload: DeepseekV4PromptCachePayload, start: int, end: int):
        """Return a new payload covering ``[start, end)`` with a cloned bitmap slice."""
        start = int(start)
        end = int(end)
        page = self.get_prompt_cache_page_size()
        # Original note: split points are page-aligned, so the bitmap can be cut on whole pages.
        return DeepseekV4PromptCachePayload(
            cache_len=end - start,
            swa_page_valid=payload.swa_page_valid[start // page : end // page].clone()
            if payload.swa_page_valid is not None
            else None,
        )

    def concat_prompt_cache_payloads(self, payloads: List[DeepseekV4PromptCachePayload]):
        """Concatenate payloads; returns None for an empty list.

        The resulting bitmap is None as soon as any input bitmap is None.
        """
        if len(payloads) == 0:
            return None
        bitmaps = [p.swa_page_valid for p in payloads]
        return DeepseekV4PromptCachePayload(
            cache_len=sum(p.cache_len for p in payloads),
            swa_page_valid=torch.cat(bitmaps, dim=0) if all(b is not None for b in bitmaps) else None,
        )

    def build_prompt_cache_payload(
        self,
        req_idx: int,
        cache_len: int,
    ) -> DeepseekV4PromptCachePayload:
        """Build an insert payload carrying only ``cache_len``.

        ``swa_page_valid`` is intentionally left unset; per the original note the
        caller fills it right before insert so it reflects the mapping at that time.
        ``req_idx`` is not used here.

        NOTE(review): the original note speaks of "any 128-aligned prefix" while
        other notes say 256-token alignment -- confirm.
        """
        assert self.mem_manager is not None
        return DeepseekV4PromptCachePayload(cache_len=int(cache_len))

    def free(self, free_req_indexes, free_token_index):
        """Reset the watermark of each freed request, then delegate to the base ``free``.

        Original note: dense/SWA/compressed slots are all released through
        ``mem_manager.free(free_token_index)``.
        """
        for req_index in free_req_indexes:
            self.clear_runtime_state(req_index)
        super().free(free_req_indexes, free_token_index)
        return

    def free_req(self, free_req_index: int):
        """Reset the request's watermark and delegate to the base ``free_req``."""
        self.clear_runtime_state(free_req_index)
        return super().free_req(free_req_index)

    def free_all(self):
        """Delegate to the base ``free_all`` and reset every watermark to -1."""
        super().free_all()
        self._swa_evict_marks = [-1 for _ in range(self.max_request_num + 1)]
        return
lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel from lightllm.models.internvl.model import ( InternVLLlamaTpPartModel, diff --git a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py index 30e5a59248..a8f851de2a 100644 --- a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py +++ b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py @@ -29,6 +29,8 @@ def _rotary_kernel( BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, NUM_STAGE: tl.constexpr, + HAS_K: tl.constexpr, + INVERSE: tl.constexpr, ): head_start_index = tl.program_id(0) seq_block_index = tl.program_id(1) @@ -44,6 +46,8 @@ def _rotary_kernel( off_dimcos_sin = seq_index * stride_cosbs + cos_range * stride_cosd cos = tl.load(Cos + off_dimcos_sin) sin = tl.load(Sin + off_dimcos_sin) + if INVERSE: + sin = -sin if HEAD_PARALLEL_NUM == 1: for q_head_index in tl.static_range(0, HEAD_Q, step=1): @@ -56,18 +60,19 @@ def _rotary_kernel( tl.store(Q + off_q0, out_q0) tl.store(Q + off_q1, out_q1) - for k_head_index in tl.static_range(0, HEAD_K, step=1): - off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd - off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd + if HAS_K: + for k_head_index in tl.static_range(0, HEAD_K, step=1): + off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd + off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd - k0 = tl.load(K + off_k0) - k1 = tl.load(K + off_k1) + k0 = tl.load(K + off_k0) + k1 = tl.load(K + off_k1) - out_k0 = k0 * cos - k1 * sin - out_k1 = k0 * sin + k1 * cos + out_k0 = k0 * cos - k1 * sin + out_k1 = k0 * sin + k1 * cos - tl.store(K + off_k0, out_k0) - tl.store(K + off_k1, out_k1) + tl.store(K + off_k0, out_k0) + tl.store(K + off_k1, out_k1) else: for q_head_index in tl.range(head_start_index, HEAD_Q, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE): off_q0 = seq_index * 
stride_qbs + q_head_index * stride_qh + dim_range0 * stride_qd @@ -79,18 +84,19 @@ def _rotary_kernel( tl.store(Q + off_q0, out_q0) tl.store(Q + off_q1, out_q1) - for k_head_index in tl.range(head_start_index, HEAD_K, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE): - off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd - off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd + if HAS_K: + for k_head_index in tl.range(head_start_index, HEAD_K, step=HEAD_PARALLEL_NUM, num_stages=NUM_STAGE): + off_k0 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range0 * stride_kd + off_k1 = seq_index * stride_kbs + k_head_index * stride_kh + dim_range1 * stride_kd - k0 = tl.load(K + off_k0) - k1 = tl.load(K + off_k1) + k0 = tl.load(K + off_k0) + k1 = tl.load(K + off_k1) - out_k0 = k0 * cos - k1 * sin - out_k1 = k0 * sin + k1 * cos + out_k0 = k0 * cos - k1 * sin + out_k1 = k0 * sin + k1 * cos - tl.store(K + off_k0, out_k0) - tl.store(K + off_k1, out_k1) + tl.store(K + off_k0, out_k0) + tl.store(K + off_k1, out_k1) return @@ -109,7 +115,10 @@ def get_test_configs(): def get_static_key(q, k): - head_num_q, head_num_k, head_dim = q.shape[1], k.shape[1], q.shape[2] + assert q is not None, "q can not be None" + head_num_q = q.shape[1] + head_num_k = k.shape[1] if k is not None else 0 + head_dim = q.shape[2] return { "Q_HEAD_NUM": head_num_q, "K_HEAD_NUM": head_num_k, @@ -126,12 +135,17 @@ def get_static_key(q, k): mutates_args=["q", "k"], ) @torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin, run_config=None): +def rotary_emb_fwd(q, k, cos, sin, inverse=False, run_config=None): + assert q is not None, "q can not be None" + has_k = k is not None and k.shape[1] != 0 total_len = q.shape[0] - head_num_q, head_num_k = q.shape[1], k.shape[1] + head_num_q = q.shape[1] + head_num_k = k.shape[1] if k is not None else 0 head_dim = q.shape[2] assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos 
shape {cos.shape}" - assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" + if k is not None: + assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" + assert k.shape[2] == head_dim, f"k shape {k.shape} q head_dim {head_dim}" assert triton.next_power_of_2(head_dim) == head_dim if not run_config: @@ -157,9 +171,9 @@ def rotary_emb_fwd(q, k, cos, sin, run_config=None): stride_qbs=q.stride(0), stride_qh=q.stride(1), stride_qd=q.stride(2), - stride_kbs=k.stride(0), - stride_kh=k.stride(1), - stride_kd=k.stride(2), + stride_kbs=k.stride(0) if k is not None else 0, + stride_kh=k.stride(1) if k is not None else 0, + stride_kd=k.stride(2) if k is not None else 0, stride_cosbs=cos.stride(0), stride_cosd=cos.stride(1), stride_sinbs=sin.stride(0), @@ -171,6 +185,8 @@ def rotary_emb_fwd(q, k, cos, sin, run_config=None): BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=head_dim, NUM_STAGE=num_stages, + HAS_K=has_k, + INVERSE=inverse, num_warps=num_warps, num_stages=num_stages, ) diff --git a/lightllm/models/deepseek_v4/__init__.py b/lightllm/models/deepseek_v4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek_v4/infer_struct.py b/lightllm/models/deepseek_v4/infer_struct.py new file mode 100644 index 0000000000..ca2ac83b03 --- /dev/null +++ b/lightllm/models/deepseek_v4/infer_struct.py @@ -0,0 +1,58 @@ +import torch +from lightllm.common.basemodel import InferStateInfo +from lightllm.common.req_manager import DeepseekV4ReqManager +from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager + + +class DeepseekV4InferStateInfo(InferStateInfo): + req_manager: DeepseekV4ReqManager + mem_manager: DeepseekV4MemoryManager + + """Per-token interleaved-rope cos/sin for the two rope variants (sliding / compressed), following + the gemma4 two-variant convention (_cos_cached_* -> position_cos_*). 
The full rope tables are + model constants and live on the model / layer infers, not here.""" + + def __init__(self): + super().__init__() + self.position_cos_sliding = None + self.position_sin_sliding = None + self.position_cos_compress = None + self.position_sin_compress = None + # layer-independent sparse-index metadata, built once per forward in init_some_extra_state + # (None until then so copy_for_cuda_graph's tensor-attr loop skips them). + self.dsv4_sparse_req_idx = None + self.dsv4_swa_indices = None + self.dsv4_swa_lengths = None + + def init_some_extra_state(self, model): + super().init_some_extra_state(model) # sets position_ids, b_q_seq_len, b_q_start_loc (prefill) + pos = self.position_ids + self.position_cos_sliding = torch.index_select(model._cos_cached_sliding, 0, pos) # [T, rope_dim//2] + self.position_sin_sliding = torch.index_select(model._sin_cached_sliding, 0, pos) + self.position_cos_compress = torch.index_select(model._cos_cached_compress, 0, pos) + self.position_sin_compress = torch.index_select(model._sin_cached_compress, 0, pos) + # Per-token request id (decode: one token per req; prefill: ragged -> repeat by q-len). + # Layer-independent; the swa kernel + build_metadata's c4/c128 readers all reuse it. + if self.is_prefill: + self.dsv4_sparse_req_idx = torch.repeat_interleave(self.b_req_idx, self.b_q_seq_len.long()) + else: + self.dsv4_sparse_req_idx = self.b_req_idx + # Sliding-window indices: layer-independent (full_to_swa is global, window is const), so build + # once here via one fused kernel instead of recomputing per layer. const [T, window] shape is + # cuda-graph-safe (no max_kv_seq_len dependence) and auto-staged by copy_for_cuda_graph. 
+ from lightllm.models.deepseek_v4.triton_kernel.build_swa_index_dsv4 import build_swa_index + + self.dsv4_swa_indices, self.dsv4_swa_lengths = build_swa_index( + req_idx=self.dsv4_sparse_req_idx, + positions=self.position_ids, + req_to_token_indexs=self.req_manager.req_to_token_indexs, + full_to_swa_indexs=self.mem_manager.full_to_swa_indexs, + window=int(self.mem_manager.sliding_window), + ) + # prefill-cudagraph 桶填充的 HOLD 尾请求的 q 行数。其注意力读 HOLD 槽位(内容被并发写 + # 竞争,每轮不同),输出必须清零,否则 pad 行 hidden 不确定 -> MoE 路由抖动 -> 共享 expert + # 批次组成变化 -> 真实行 GEMM 归约顺序变化(ulp 级),44 层放大后翻转低置信 token。 + self._dsv4_prefill_pad_q_len = 0 + if self.is_prefill and self.b_req_idx.numel() > 0: + if int(self.b_req_idx[-1].item()) == self.req_manager.HOLD_REQUEST_ID: + self._dsv4_prefill_pad_q_len = int((self.b_seq_len[-1] - self.b_ready_cache_len[-1]).item()) diff --git a/lightllm/models/deepseek_v4/layer_infer/__init__.py b/lightllm/models/deepseek_v4/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek_v4/layer_infer/compressor.py b/lightllm/models/deepseek_v4/layer_infer/compressor.py new file mode 100644 index 0000000000..c66fd22de5 --- /dev/null +++ b/lightllm/models/deepseek_v4/layer_infer/compressor.py @@ -0,0 +1,471 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +import triton +import triton.language as tl +from triton.language.extra import libdevice + +from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager +from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import ( + DSV4_C4_STATE_RING, + DSV4_C128_STATE_RING, + DSV4_SWA_PAGE_SIZE, +) + + +@dataclass +class CoreCompressorMetadata: + layer_idx: int + compress_ratio: int + out_slots: torch.Tensor + mem_index: torch.Tensor + state_buffer: torch.Tensor + out_buffer: torch.Tensor + out_page_size: int + position_ids: torch.Tensor + b_req_idx: torch.Tensor + b_seq_len: torch.Tensor + b_ready_cache_len: 
Optional[torch.Tensor] + b_q_start_loc: Optional[torch.Tensor] + req_to_token_indexs: torch.Tensor + full_to_swa_indexs: torch.Tensor + token_to_batch_idx: Optional[torch.Tensor] + kv_score: Optional[torch.Tensor] + is_prefill: bool + + +@triton.jit +def _add_ape_to_kv_score_kernel( + kv_score, + kv_score_stride0, + kv_score_stride1, + ape, + ape_stride0, + positions, + STATE_WIDTH: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + BLOCK: tl.constexpr, +): + token_idx = tl.program_id(0) + offs = tl.arange(0, BLOCK) + mask = offs < STATE_WIDTH + + position = tl.load(positions + token_idx) + ape_row = position % COMPRESS_RATIO + score = tl.load(kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1, mask=mask) + ape_value = tl.load(ape + ape_row * ape_stride0 + offs, mask=mask) + tl.store( + kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1, + score + ape_value, + mask=mask, + ) + return + + +@triton.jit +def _save_partial_states_kernel( + kv_score, + kv_score_stride0, + kv_score_stride1, + positions, + token_to_batch_idx, + b_req_idx, + b_seq_len, + mem_index, + full_to_swa, + state_buffer, + STATE_WIDTH: tl.constexpr, + STATE_LAST_DIM: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + IS_C4: tl.constexpr, + IS_PREFILL: tl.constexpr, + SWA_PAGE_SIZE: tl.constexpr, + STATE_RING: tl.constexpr, + BLOCK: tl.constexpr, +): + token_idx = tl.program_id(0) + batch_idx = tl.load(token_to_batch_idx + token_idx) if IS_PREFILL else token_idx + position = tl.load(positions + token_idx) + seq_len = tl.load(b_seq_len + batch_idx) + + if IS_C4: + same_page_next = (position % SWA_PAGE_SIZE) + STATE_RING < SWA_PAGE_SIZE + if same_page_next and position + STATE_RING < seq_len: + return + else: + if position + COMPRESS_RATIO < seq_len: + return + + full_slot = tl.load(mem_index + token_idx).to(tl.int64) + swa_slot = tl.load(full_to_swa + full_slot).to(tl.int64) + if swa_slot < 0: + return + state_row = (swa_slot // SWA_PAGE_SIZE) * 
STATE_RING + (swa_slot % STATE_RING) + + offs = tl.arange(0, BLOCK) + mask = offs < STATE_WIDTH + kv = tl.load(kv_score + token_idx * kv_score_stride0 + offs * kv_score_stride1, mask=mask) + score = tl.load(kv_score + token_idx * kv_score_stride0 + (STATE_WIDTH + offs) * kv_score_stride1, mask=mask) + state_base = state_buffer + state_row * STATE_LAST_DIM + tl.store(state_base + offs, kv, mask=mask) + tl.store(state_base + STATE_WIDTH + offs, score, mask=mask) + return + + +@triton.jit +def _fused_compress_norm_rope_insert_kernel( + kv_score, + kv_score_stride0, + kv_score_stride1, + state_buffer, + positions, + token_to_batch_idx, + b_req_idx, + b_seq_len, + b_ready_cache_len, + b_q_start_loc, + req_to_token, + req_to_token_stride0, + full_to_swa, + out_slots, + norm_weight, + rms_eps, + cos_table, + cos_stride0, + sin_table, + sin_stride0, + out_buffer, + HEAD_DIM: tl.constexpr, + STATE_WIDTH: tl.constexpr, + STATE_LAST_DIM: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + WINDOW_SIZE: tl.constexpr, + IS_C4: tl.constexpr, + IS_PREFILL: tl.constexpr, + SWA_PAGE_SIZE: tl.constexpr, + STATE_RING: tl.constexpr, + ROPE_HEAD_DIM: tl.constexpr, + FP8_MAX: tl.constexpr, + SCALE_MIN: tl.constexpr, + NOPE_DIM: tl.constexpr, + QUANT_BLOCK: tl.constexpr, + SCALE_BYTES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BYTES_PER_PAGE: tl.constexpr, + BLOCK: tl.constexpr, + OUTPUT_BF16: tl.constexpr, +): + token_idx = tl.program_id(0) + out_slot = tl.load(out_slots + token_idx).to(tl.int64) + if out_slot < 0: + return + + position = tl.load(positions + token_idx) + if (position + 1) % COMPRESS_RATIO != 0: + return + + batch_idx = tl.load(token_to_batch_idx + token_idx) if IS_PREFILL else token_idx + req_idx = tl.load(b_req_idx + batch_idx).to(tl.int64) + seq_len = tl.load(b_seq_len + batch_idx) + if IS_PREFILL: + ready_len = tl.load(b_ready_cache_len + batch_idx) + q_start = tl.load(b_q_start_loc + batch_idx) + else: + ready_len = position + q_start = token_idx + + token_offsets = 
tl.arange(0, WINDOW_SIZE) + start = position - WINDOW_SIZE + 1 + gather_pos = start + token_offsets + valid_pos = (gather_pos >= 0) & (gather_pos < seq_len) + use_current = (gather_pos >= ready_len) & valid_pos if IS_PREFILL else gather_pos == position + current_idx = q_start + (gather_pos - ready_len) if IS_PREFILL else token_idx + token_offsets * 0 + + if IS_C4: + full_slot = tl.load( + req_to_token + req_idx * req_to_token_stride0 + gather_pos, + mask=valid_pos & (~use_current), + other=0, + ).to(tl.int64) + swa_slot = tl.load(full_to_swa + full_slot, mask=valid_pos & (~use_current), other=-1).to(tl.int64) + state_row = (swa_slot // SWA_PAGE_SIZE) * STATE_RING + (swa_slot % STATE_RING) + state_valid = valid_pos & (~use_current) & (swa_slot >= 0) + head_offset = tl.where(token_offsets >= COMPRESS_RATIO, HEAD_DIM, 0) + else: + full_slot = tl.load( + req_to_token + req_idx * req_to_token_stride0 + gather_pos, + mask=valid_pos & (~use_current), + other=0, + ).to(tl.int64) + swa_slot = tl.load(full_to_swa + full_slot, mask=valid_pos & (~use_current), other=-1).to(tl.int64) + state_row = (swa_slot // SWA_PAGE_SIZE) * STATE_RING + (swa_slot % STATE_RING) + state_valid = valid_pos & (~use_current) & (swa_slot >= 0) + head_offset = token_offsets * 0 + + offs = tl.arange(0, BLOCK) + dim_mask = offs < HEAD_DIM + current_mask = use_current[:, None] & dim_mask[None, :] + state_mask = state_valid[:, None] & dim_mask[None, :] + + cur_kv = tl.load( + kv_score + current_idx[:, None] * kv_score_stride0 + (head_offset[:, None] + offs[None, :]) * kv_score_stride1, + mask=current_mask, + other=0.0, + ) + cur_score = tl.load( + kv_score + + current_idx[:, None] * kv_score_stride0 + + (STATE_WIDTH + head_offset[:, None] + offs[None, :]) * kv_score_stride1, + mask=current_mask, + other=float("-inf"), + ) + state_kv = tl.load( + state_buffer + state_row[:, None] * STATE_LAST_DIM + head_offset[:, None] + offs[None, :], + mask=state_mask, + other=0.0, + ) + state_score = tl.load( + 
state_buffer + state_row[:, None] * STATE_LAST_DIM + STATE_WIDTH + head_offset[:, None] + offs[None, :], + mask=state_mask, + other=float("-inf"), + ) + + kv = tl.where(current_mask, cur_kv, state_kv) + score = tl.where(current_mask, cur_score, state_score) + score = tl.softmax(score, dim=0) + compressed_kv = tl.sum(kv * score, axis=0) + + rms_w = tl.load(norm_weight + offs, mask=dim_mask, other=0.0) + variance = tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_DIM + rrms = tl.rsqrt(variance + rms_eps) + normed = compressed_kv * rrms * rms_w + + num_pairs: tl.constexpr = BLOCK // 2 + nope_pairs: tl.constexpr = NOPE_DIM // 2 + pair_2d = tl.reshape(normed, (num_pairs, 2)) + even, odd = tl.split(pair_2d) + pair_idx = tl.arange(0, num_pairs) + rope_pair_local = pair_idx - nope_pairs + is_rope_pair = rope_pair_local >= 0 + cs_idx = tl.maximum(rope_pair_local, 0) + compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO + cos_v = tl.load(cos_table + compressed_pos * cos_stride0 + cs_idx, mask=is_rope_pair, other=1.0) + sin_v = tl.load(sin_table + compressed_pos * sin_stride0 + cs_idx, mask=is_rope_pair, other=0.0) + new_even = even * cos_v - odd * sin_v + new_odd = odd * cos_v + even * sin_v + rotated = tl.interleave(new_even, new_odd) + + if OUTPUT_BF16: + # indexer-K path: emit the post-rope full HEAD_DIM vector as dense bf16 (token-indexed), + # leaving the fp8 single-amax pack to destindex_copy_indexer_k_dsv4 (the c4_indexer_pool + # ABI differs from the latent slab: whole-vector fp8 + one fp32 scale, no bf16 rope tail). 
+ tl.store(out_buffer + token_idx * HEAD_DIM + offs, rotated.to(tl.bfloat16), mask=dim_mask) + return + + page = out_slot // PAGE_SIZE + token_in_page = out_slot % PAGE_SIZE + data_base = page * BYTES_PER_PAGE + token_in_page * (NOPE_DIM + ROPE_HEAD_DIM * 2) + scale_base = page * BYTES_PER_PAGE + PAGE_SIZE * (NOPE_DIM + ROPE_HEAD_DIM * 2) + token_in_page * SCALE_BYTES + + n_quant_blocks: tl.constexpr = BLOCK // QUANT_BLOCK + n_nope_blocks: tl.constexpr = NOPE_DIM // QUANT_BLOCK + quant_input = normed.to(tl.bfloat16).to(tl.float32) + quant_2d = tl.reshape(quant_input, (n_quant_blocks, QUANT_BLOCK)) + abs_2d = tl.abs(quant_2d) + block_absmax = tl.max(abs_2d, axis=1) + scale_exp = tl.ceil(libdevice.log2(tl.maximum(block_absmax / FP8_MAX, SCALE_MIN))).to(tl.int32) + scale = ((scale_exp + 127) << 23).to(tl.float32, bitcast=True) + kv_fp8 = tl.clamp(quant_2d / scale[:, None], -FP8_MAX, FP8_MAX).to(tl.float8e4nv) + kv_u8 = tl.reshape(kv_fp8.to(tl.uint8, bitcast=True), (BLOCK,)) + tl.store(out_buffer + data_base + offs, kv_u8, mask=offs < NOPE_DIM) + + scale_idx = tl.arange(0, SCALE_BYTES) + scale_bytes = tl.where(scale_idx < n_nope_blocks, scale_exp + 127, 0).to(tl.uint8) + tl.store(out_buffer + scale_base + scale_idx, scale_bytes) + + rope_local = offs - NOPE_DIM + rope_mask = (offs >= NOPE_DIM) & dim_mask + rope_ptr = (out_buffer + data_base + NOPE_DIM).to(tl.pointer_type(tl.bfloat16)) + tl.store(rope_ptr + rope_local, rotated.to(tl.bfloat16), mask=rope_mask) + return + + +def prepare_compress_states(*, infer_state, layer_idx: int, compress_ratio: int, is_in_indexer: bool = False): + if compress_ratio == 0: + return None + + mem_manager: DeepseekV4MemoryManager = infer_state.mem_manager + if is_in_indexer: + # c4 Lightning-Indexer key compression: same window/state machinery as the c4 latent + # compressor but with index_head_dim, a separate state pool, and a DENSE bf16 scratch + # out_buffer (the kernel's OUTPUT_BF16 path); the fp8 pack into c4_indexer_pool is done + # 
afterwards by pack_indexer_k_to_cache. + assert compress_ratio == 4, "只有 c4(CSA) 层有 indexer-K" + out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)] + state_buffer = mem_manager.get_c4_indexer_state_buffer(layer_idx) + out_buffer = torch.empty( + (infer_state.mem_index.numel(), mem_manager.indexer_head_dim), + dtype=torch.bfloat16, + device=infer_state.mem_index.device, + ) + out_page_size = 1 # unused under OUTPUT_BF16 (token-indexed dense scratch, not paged) + else: + if compress_ratio == 4: + out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)] + state_buffer = mem_manager.get_c4_state_buffer(layer_idx) + out_pool = mem_manager.c4_pool + elif compress_ratio == 128: + out_slots = mem_manager.full_to_c128_indexs[infer_state.mem_index.long().reshape(-1)] + state_buffer = mem_manager.get_c128_state_buffer(layer_idx) + out_pool = mem_manager.c128_pool + else: + raise AssertionError(f"invalid DeepSeek-V4 compress ratio {compress_ratio}") + out_buffer = mem_manager.get_compressed_kv_buffer(layer_idx) + out_page_size = out_pool.page_size + + token_to_batch_idx = infer_state.b_req_idx + if infer_state.is_prefill: + token_to_batch_idx = getattr(infer_state, "_dsv4_token_to_batch_idx", None) + if token_to_batch_idx is None or token_to_batch_idx.numel() != infer_state.position_ids.numel(): + q_lens = (infer_state.b_seq_len - infer_state.b_ready_cache_len).to(torch.long) + batch_idx = torch.arange(infer_state.b_req_idx.shape[0], device=infer_state.b_req_idx.device) + token_to_batch_idx = torch.repeat_interleave(batch_idx, q_lens).to(torch.int32) + infer_state._dsv4_token_to_batch_idx = token_to_batch_idx + + return CoreCompressorMetadata( + layer_idx=layer_idx, + compress_ratio=compress_ratio, + out_slots=out_slots, + mem_index=infer_state.mem_index, + state_buffer=state_buffer, + out_buffer=out_buffer, + out_page_size=out_page_size, + position_ids=infer_state.position_ids, + b_req_idx=infer_state.b_req_idx, + 
b_seq_len=infer_state.b_seq_len, + b_ready_cache_len=infer_state.b_ready_cache_len, + b_q_start_loc=infer_state.b_q_start_loc, + req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, + full_to_swa_indexs=mem_manager.full_to_swa_indexs, + token_to_batch_idx=token_to_batch_idx, + kv_score=None, + is_prefill=infer_state.is_prefill, + ) + + +def prepare_partial_states( + *, + kv_score: torch.Tensor, + metadata: Optional[CoreCompressorMetadata], + ape: torch.Tensor, + compress_ratio: int, +): + if metadata is None or kv_score.shape[0] == 0: + return + state_width = kv_score.shape[-1] // 2 + _add_ape_to_kv_score_kernel[(kv_score.shape[0],)]( + kv_score, + kv_score.stride(0), + kv_score.stride(1), + ape, + ape.stride(0), + metadata.position_ids, + STATE_WIDTH=state_width, + COMPRESS_RATIO=compress_ratio, + BLOCK=triton.next_power_of_2(state_width), + num_warps=4, + ) + return + + +def fused_compress( + *, + kv_score: torch.Tensor, + metadata: Optional[CoreCompressorMetadata], + norm_weight: torch.Tensor, + ape: torch.Tensor, + eps: float, + head_dim: int, + qk_rope_head_dim: int, + compress_ratio: int, + cos_table: torch.Tensor, + sin_table: torch.Tensor, + output_bf16: bool = False, +): + if metadata is None or kv_score.shape[0] == 0: + return + + state_width = kv_score.shape[-1] // 2 + state_last_dim = metadata.state_buffer.shape[-1] + is_c4 = compress_ratio == 4 + state_ring = DSV4_C4_STATE_RING if is_c4 else DSV4_C128_STATE_RING + block_state = triton.next_power_of_2(state_width) + block_head = triton.next_power_of_2(head_dim) + + _fused_compress_norm_rope_insert_kernel[(kv_score.shape[0],)]( + kv_score, + kv_score.stride(0), + kv_score.stride(1), + metadata.state_buffer, + metadata.position_ids, + metadata.token_to_batch_idx, + metadata.b_req_idx, + metadata.b_seq_len, + metadata.b_ready_cache_len if metadata.b_ready_cache_len is not None else metadata.b_seq_len, + metadata.b_q_start_loc if metadata.b_q_start_loc is not None else metadata.b_seq_len, + 
metadata.req_to_token_indexs, + metadata.req_to_token_indexs.stride(0), + metadata.full_to_swa_indexs, + metadata.out_slots, + norm_weight, + eps, + cos_table, + cos_table.stride(0), + sin_table, + sin_table.stride(0), + metadata.out_buffer, + HEAD_DIM=head_dim, + STATE_WIDTH=state_width, + STATE_LAST_DIM=state_last_dim, + COMPRESS_RATIO=compress_ratio, + WINDOW_SIZE=compress_ratio * (2 if is_c4 else 1), + IS_C4=is_c4, + IS_PREFILL=metadata.is_prefill, + SWA_PAGE_SIZE=DSV4_SWA_PAGE_SIZE, + STATE_RING=state_ring, + ROPE_HEAD_DIM=qk_rope_head_dim, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + SCALE_MIN=1e-4, + NOPE_DIM=head_dim - qk_rope_head_dim, + QUANT_BLOCK=64, + SCALE_BYTES=(head_dim - qk_rope_head_dim) // 64 + 1, + PAGE_SIZE=metadata.out_page_size, + BYTES_PER_PAGE=metadata.out_buffer.shape[-1], + BLOCK=block_head, + OUTPUT_BF16=output_bf16, + num_warps=4, + ) + + _save_partial_states_kernel[(kv_score.shape[0],)]( + kv_score, + kv_score.stride(0), + kv_score.stride(1), + metadata.position_ids, + metadata.token_to_batch_idx, + metadata.b_req_idx, + metadata.b_seq_len, + metadata.mem_index, + metadata.full_to_swa_indexs, + metadata.state_buffer, + STATE_WIDTH=state_width, + STATE_LAST_DIM=state_last_dim, + COMPRESS_RATIO=compress_ratio, + IS_C4=is_c4, + IS_PREFILL=metadata.is_prefill, + SWA_PAGE_SIZE=DSV4_SWA_PAGE_SIZE, + STATE_RING=state_ring, + BLOCK=block_state, + num_warps=4, + ) + return diff --git a/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py new file mode 100644 index 0000000000..080ebabd89 --- /dev/null +++ b/lightllm/models/deepseek_v4/layer_infer/hyper_connection.py @@ -0,0 +1,75 @@ +import torch + +try: + import vllm.model_executor.layers.mhc # noqa: F401 +except Exception as e: + raise RuntimeError("DeepSeek-V4 requires vLLM mHC custom ops; failed to import vllm MHC kernels") from e + + +# vllm DeepseekV4DecoderLayer.hc_post_alpha +HC_POST_ALPHA = 2.0 + + +def 
hc_pre(residual, hc_fn, hc_scale, hc_base, rms_eps, hc_eps, sinkhorn_iters, norm_weight, norm_eps): + """Standalone hc_pre for the first layer. residual:[T, hc, dim] -> + (x[T,dim], residual, post_mix[T,hc,1], res_mix[T,hc,hc]); the sub-layer RMSNorm is fused via norm_weight.""" + post_mix, res_mix, x = torch.ops.vllm.mhc_pre_tilelang( + residual=residual, + fn=hc_fn, + hc_scale=hc_scale, + hc_base=hc_base, + rms_eps=rms_eps, + hc_pre_eps=hc_eps, + hc_sinkhorn_eps=hc_eps, + hc_post_mult_value=HC_POST_ALPHA, + sinkhorn_repeat=sinkhorn_iters, + norm_weight=norm_weight, + norm_eps=norm_eps, + ) + return x, residual, post_mix, res_mix + + +def hc_fused_post_pre( + x, residual, post_mix, res_mix, hc_fn, hc_scale, hc_base, rms_eps, hc_eps, sinkhorn_iters, norm_weight, norm_eps +): + """hc_post of the previous sub-layer fused with hc_pre of the next one (norm fused too). + Returns (x[T,dim], residual[T,hc,dim], post_mix, res_mix).""" + residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre_tilelang( + x=x, + residual=residual, + post_layer_mix=post_mix, + comb_res_mix=res_mix, + fn=hc_fn, + hc_scale=hc_scale, + hc_base=hc_base, + rms_eps=rms_eps, + hc_pre_eps=hc_eps, + hc_sinkhorn_eps=hc_eps, + hc_post_mult_value=HC_POST_ALPHA, + sinkhorn_repeat=sinkhorn_iters, + norm_weight=norm_weight, + norm_eps=norm_eps, + ) + return x, residual, post_mix, res_mix + + +def hc_post(x, residual, post_mix, res_mix): + """Complete the hc_post left pending by the last layer. -> streams [T, hc, dim].""" + return torch.ops.vllm.mhc_post_tilelang(x, residual, post_mix, res_mix) + + +def hc_head(streams, hc_fn, hc_scale, hc_base, hc_mult, dim, rms_eps, hc_eps, alloc_func): + """Final stream collapse before the lm_head. 
streams:[N, hc*dim] -> [N, dim].""" + out = alloc_func((streams.shape[0], dim), dtype=streams.dtype, device=streams.device) + torch.ops.vllm.hc_head_fused_kernel_tilelang( + streams.view(-1, hc_mult, dim).contiguous(), + hc_fn, + hc_scale, + hc_base, + out, + dim, + rms_eps, + hc_eps, + hc_mult, + ) + return out diff --git a/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py new file mode 100644 index 0000000000..8eddfb3b9d --- /dev/null +++ b/lightllm/models/deepseek_v4/layer_infer/post_layer_infer.py @@ -0,0 +1,27 @@ +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from .hyper_connection import hc_head, hc_post +from ..infer_struct import DeepseekV4InferStateInfo + + +class DeepseekV4PostLayerInfer(LlamaPostLayerInfer): + """Collapse the hc_mult residual streams (hc_head) to [T, hidden], then final norm + lm_head.""" + + def token_forward(self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight): + cfg = layer_weight.network_config_ + if isinstance(input_embdings, tuple): + # truncated-layer runs (autotune warmup) end before the last layer's _hc_ffn_out + # collapse; finish the pending hc_post here. 
+ streams = hc_post(*input_embdings) + input_embdings = streams.reshape(streams.shape[0], -1) + collapsed = hc_head( + input_embdings, + layer_weight.hc_head_fn_.weight, + layer_weight.hc_head_scale_.weight, + layer_weight.hc_head_base_.weight, + cfg["hc_mult"], + cfg["hidden_size"], + cfg["rms_norm_eps"], + cfg.get("hc_eps", 1e-6), + self.alloc_tensor, + ) + return super().token_forward(collapsed, infer_state, layer_weight) diff --git a/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py new file mode 100644 index 0000000000..b95f5a14a8 --- /dev/null +++ b/lightllm/models/deepseek_v4/layer_infer/pre_layer_infer.py @@ -0,0 +1,24 @@ +import torch +import torch.distributed as dist +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.distributed.communication_op import all_reduce +from ..infer_struct import DeepseekV4InferStateInfo + + +class DeepseekV4PreLayerInfer(LlamaPreLayerInfer): + """Token embedding, then expand to the hc_mult parallel residual streams [T, hc_mult*hidden].""" + + def __init__(self, network_config): + super().__init__(network_config) + self.hc_mult = network_config["hc_mult"] + return + + def context_forward(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + t, hidden = input_embdings.shape + return input_embdings.unsqueeze(1).expand(t, self.hc_mult, hidden).reshape(t, self.hc_mult * hidden) + + def token_forward(self, input_ids, infer_state: DeepseekV4InferStateInfo, layer_weight): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + t, hidden = input_embdings.shape + return input_embdings.unsqueeze(1).expand(t, self.hc_mult, hidden).reshape(t, self.hc_mult * hidden) diff --git a/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py 
b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..57b171a8d1 --- /dev/null +++ b/lightllm/models/deepseek_v4/layer_infer/transformer_layer_infer.py @@ -0,0 +1,712 @@ +import os +import torch +import torch.distributed as dist +from lightllm.common.basemodel import TransformerLayerInferTpl +from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.distributed.communication_op import all_reduce +from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer +from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import DeepseekV4TransformerLayerWeight +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor +from lightllm.utils.vllm_utils import vllm_ops +from .hyper_connection import hc_pre, hc_fused_post_pre, hc_post +from .compressor import fused_compress as fused_compress_op +from .compressor import prepare_partial_states +from .compressor import prepare_compress_states +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd +from ..infer_struct import DeepseekV4InferStateInfo + + +class DeepseekV4TransformerLayerInfer(Deepseek3_2TransformerLayerInfer): + def __init__(self, layer_num, network_config): + TransformerLayerInferTpl.__init__(self, layer_num, network_config) + self.eps_ = network_config["rms_norm_eps"] + self.embed_dim_ = network_config["hidden_size"] + self.num_heads = network_config["num_attention_heads"] + self.head_dim_ = network_config["head_dim"] + self.qk_rope_head_dim = network_config["qk_rope_head_dim"] + self.qk_nope_head_dim = self.head_dim_ - self.qk_rope_head_dim + self.v_head_dim = self.head_dim_ + self.o_groups = network_config["o_groups"] + self.hc_mult = network_config["hc_mult"] + self.sinkhorn_iters = network_config["hc_sinkhorn_iters"] + self.hc_eps = network_config["hc_eps"] + self.compress_ratio = 
network_config["compress_ratios"][layer_num] + self.is_hash = layer_num < network_config["num_hash_layers"] + self.is_last_layer = layer_num == network_config["n_layer"] - 1 + # complex64 rope table for this layer's variant (sliding / compressed); set by + # DeepseekV4TpPartModel._init_to_get_rotary once the tables are built. The full compress + # cos/sin tables (compressor entry rope uses entry positions, not token positions) are + # wired there too. + self.freqs_cis = None + self.cos_compress_table = None + self.sin_compress_table = None + self.num_experts_per_tok = network_config["num_experts_per_tok"] + self.routed_scaling_factor = network_config["routed_scaling_factor"] + self.swiglu_limit = network_config["swiglu_limit"] + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.tp_q_head_num_ = self.num_heads // self.tp_world_size_ + self.tp_groups = self.o_groups // self.tp_world_size_ + self.enable_ep_moe = get_env_start_args().enable_ep_moe + self.compressor = CompressorInfer( + layer_idx=self.layer_num_, network_config=self.network_config_, tp_world_size=self.tp_world_size_ + ) + self.index_infer = DeepseekV4IndexInfer( + layer_idx=self.layer_num_, network_config=self.network_config_, tp_world_size=self.tp_world_size_ + ) + + # ------------------------------------------------------------------ forward (HC-threaded) + def _hc_attn_in(self, input_embdings, layer_weight: DeepseekV4TransformerLayerWeight): + """Layer input -> attention input (attn_norm fused). 
First layer gets the raw streams + and runs a standalone hc_pre; later layers get (x, residual, post_mix, res_mix) and fuse + the previous layer's ffn hc_post with this layer's attn hc_pre.""" + if torch.is_tensor(input_embdings): + residual = input_embdings.view(-1, self.hc_mult, self.embed_dim_) + return hc_pre( + residual, + layer_weight.hc_attn_fn_.weight, + layer_weight.hc_attn_scale_.weight, + layer_weight.hc_attn_base_.weight, + self.eps_, + self.hc_eps, + self.sinkhorn_iters, + layer_weight.attn_norm_.weight, + self.eps_, + ) + x, residual, post_mix, res_mix = input_embdings + return hc_fused_post_pre( + x, + residual, + post_mix, + res_mix, + layer_weight.hc_attn_fn_.weight, + layer_weight.hc_attn_scale_.weight, + layer_weight.hc_attn_base_.weight, + self.eps_, + self.hc_eps, + self.sinkhorn_iters, + layer_weight.attn_norm_.weight, + self.eps_, + ) + + def _hc_ffn_in(self, x, residual, post_mix, res_mix, layer_weight: DeepseekV4TransformerLayerWeight): + """Attention output -> ffn input (ffn_norm fused): fused attn hc_post + ffn hc_pre.""" + return hc_fused_post_pre( + x, + residual, + post_mix, + res_mix, + layer_weight.hc_ffn_fn_.weight, + layer_weight.hc_ffn_scale_.weight, + layer_weight.hc_ffn_base_.weight, + self.eps_, + self.hc_eps, + self.sinkhorn_iters, + layer_weight.ffn_norm_.weight, + self.eps_, + ) + + def _hc_ffn_out(self, x, residual, post_mix, res_mix): + """Mid layers leave the ffn hc_post pending for the next layer's fused post+pre; the last + layer completes it and hands the flat streams [T, hc_mult*hidden] back to the model loop.""" + if not self.is_last_layer: + return x, residual, post_mix, res_mix + streams = hc_post(x, residual, post_mix, res_mix) + return streams.reshape(streams.shape[0], -1) + + def context_forward( + self, input_embdings, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight + ): + x, residual, post_mix, res_mix = self._hc_attn_in(input_embdings, layer_weight) + x = 
def _get_qkv(
    self,
    input: torch.Tensor,
    infer_state: DeepseekV4InferStateInfo,
    layer_weight: DeepseekV4TransformerLayerWeight,
):
    """Project the hidden states to q and write this step's packed kv into the cache.

    Returns:
        ``(q, qa, input)``: ``q`` is the buffer filled by ``fused_q_norm_rope``, viewed as
        ``[T, tp_q_head_num_, head_dim_]``; ``qa`` is the ``q_norm_``-normalized ``wq_a_``
        output; ``input`` is the hidden state returned by ``_tpsp_allgather``. No kv tensor
        is returned -- it is written to the cache through ``infer_state.mem_manager``.
    """
    from lightllm.third_party.sglang_jit.dsv4 import fused_q_norm_rope

    input = self._tpsp_allgather(input=input, infer_state=infer_state)
    T = input.shape[0]
    qa = layer_weight.q_norm_(layer_weight.wq_a_.mm(input), eps=self.eps_)
    q_in = layer_weight.wq_b_.mm(qa).view(T, self.tp_q_head_num_, self.head_dim_)
    # per-(token, head) weightless self-RMSNorm + interleaved rope on the last rope_dim dims,
    # fused in one sglang dsv4 jit kernel (fp32 norm/rotation, bf16 in between -- same as eager).
    q = self.alloc_tensor(q_in.shape, dtype=q_in.dtype, device=q_in.device)
    fused_q_norm_rope(q_in, q, self.eps_, self.freqs_cis, infer_state.position_ids)
    # kv: rmsnorm + rope + fp8 pack + scatter into the swa pool, all done by a single sglang
    # jit kernel (same as sglang _compute_kv_to_cache), replacing the eager norm/rope/cat +
    # _post_cache_kv. The bf16 kv intermediate has no other consumer: on the flashmla path
    # attention reads the cache, and the compressor / indexer take x.
    infer_state.mem_manager.pack_mla_kv_to_cache_fused_norm_rope(
        layer_index=self.layer_num_,
        mem_index=infer_state.mem_index,
        kv=layer_weight.wkv_.mm(input),
        kv_weight=layer_weight.kv_norm_.weight,
        eps=self.eps_,
        freqs_cis=self.freqs_cis,
        positions=infer_state.position_ids,
    )
    return q, qa, input
def _context_attention_kernel(
    self, q, q_lora, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight
):
    """Run prefill attention for this layer: compress, write indexer keys, build indices, attend.

    The statement order is significant (see the inline comments): the compressor and the
    indexer-key write must both finish before ``build_metadata`` runs, and ``build_metadata``
    must run before ``prefill_att``. Returns the tensor produced by ``prefill_att``, with the
    trailing ``_dsv4_prefill_pad_q_len`` rows zeroed when that attribute is set and non-zero.
    """
    self.compressor.prepare_states(x, infer_state, layer_weight)
    self.compressor.fused_compress(infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table)
    # Write this step's c4 Lightning-Indexer keys (no-op off c4) BEFORE build_metadata so the
    # scorer (gather + deep_gemm.fp8_mqa_logits) reads fresh+accumulated entries from the indexer pool.
    self.index_infer.write_indexer_k(x, infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table)
    # Build the FINAL flash_mla index tensors here (model side), so att_control is a thin
    # transport of ready-to-forward tensors -- not indexer raw material. Must stay after
    # fused_compress (c4 reads the indexer-K pool it writes) and before prefill_att (keeps the
    # c4 scorer/topk at the same cuda-graph capture position).
    meta = self.index_infer.build_metadata(x, q_lora, infer_state, layer_weight)
    att_control = AttControl(
        nsa_prefill=True,
        nsa_prefill_dict={
            "flashmla_kvcache": True,
            "layer_index": self.layer_num_,
            "compress_ratio": self.compress_ratio,
            "head_dim_v": self.v_head_dim,
            "softmax_scale": self.softmax_scale,
            "attn_sink": layer_weight.attn_sink_.weight,
            **meta,
        },
    )
    out = infer_state.prefill_att_state.prefill_att(
        q=q,
        k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_),
        v=None,
        att_control=att_control,
    )
    # The attribute is optional on infer_state; a missing or zero value means no padded rows.
    pad_q_len = getattr(infer_state, "_dsv4_prefill_pad_q_len", 0)
    if pad_q_len:
        # Pad rows read the HOLD slot (see infer_struct._dsv4_prefill_pad_q_len); zero them
        # so the output stays deterministic.
        out[-pad_q_len:] = 0
    return out
self.sin_compress_table) + self.index_infer.write_indexer_k(x, infer_state, layer_weight, self.cos_compress_table, self.sin_compress_table) + meta = self.index_infer.build_metadata(x, q_lora, infer_state, layer_weight) + att_control = AttControl( + nsa_decode=True, + nsa_decode_dict={ + "flashmla_kvcache": True, + "layer_index": self.layer_num_, + "compress_ratio": self.compress_ratio, + "head_dim_v": self.v_head_dim, + "softmax_scale": self.softmax_scale, + "attn_sink": layer_weight.attn_sink_.weight, + **meta, + }, + ) + return infer_state.decode_att_state.decode_att( + q=q, + k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_), + v=None, + att_control=att_control, + ) + + # ------------------------------------------------------------------ moe + def _routed_experts(self, x, weights, indices, layer_weight: DeepseekV4TransformerLayerWeight): + return layer_weight.experts_.experts_with_preselected( + input_tensor=x, + topk_weights=weights, + topk_ids=indices, + clamp_limit=float(self.swiglu_limit), + ) + + def _ffn(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight: DeepseekV4TransformerLayerWeight): + x = x.view(-1, self.embed_dim_) + if not self.enable_ep_moe: + x = self._tpsp_allgather(input=x, infer_state=infer_state) + + logits = layer_weight.gate_weight_.mm(x.float()).contiguous() + weights, indices = self._select_experts(logits, infer_state, layer_weight) + # shared expert 必须先于 routed 计算: fp8 路径 (FuseMoeTriton) 的 fused_experts + # 是 inplace 的,_routed_experts 返回后 x 已被覆盖为 routed 输出。 + # 复用 Llama 的 _ffn_tp: fused gate_up matmul + silu_and_mul triton kernel,无 swiglu clamp, + # 对齐参考 DeepseekV4MLP(=LlamaMLP)。swiglu_limit clamp 只属于 routed 专家 (见 _routed_experts)。 + shared = self._ffn_tp(input=x, infer_state=infer_state, layer_weight=layer_weight) + routed = self._routed_experts(x, weights, indices, layer_weight) + if self.enable_ep_moe: + if self.tp_world_size_ > 1: + all_reduce( + shared, + op=dist.ReduceOp.SUM, + 
class CompressorInfer:
    """Window-softmax compressor shared by the attention path and the c4 indexer.

    With ``is_in_indexer=False`` it compresses the c4/c128 latent KV into the paged fp8 slab
    (attention extra_k). With ``is_in_indexer=True`` the same machinery (mirroring sglang's
    ``Compressor(is_in_indexer=...)``) runs with the indexer weights/dims/state pool to
    produce the per-c4-entry Lightning-Indexer keys, emitted as dense bf16 (OUTPUT_BF16) and
    then fp8-packed into c4_indexer_pool by the caller. Indexer mode is c4-only.
    """

    def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int, is_in_indexer: bool = False):
        super().__init__()
        self.layer_idx_ = layer_idx
        self.network_config_ = network_config
        self.tp_world_size_ = tp_world_size
        self.is_in_indexer = is_in_indexer
        # Static per-layer configuration, read once from the model config.
        self.compress_ratio = network_config["compress_ratios"][layer_idx]
        self.head_dim = network_config["head_dim"]
        self.index_head_dim = network_config["index_head_dim"]
        self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
        self.eps = network_config["rms_norm_eps"]
        # Filled by prepare_states; consumed by fused_compress.
        self._metadata = None

    def prepare_states(
        self,
        x: torch.Tensor,
        infer_state: DeepseekV4InferStateInfo,
        layer_weight: DeepseekV4TransformerLayerWeight,
    ):
        """Build this step's compress metadata and its kv/score projection.

        Stores the metadata on ``self._metadata`` and returns it; when
        ``prepare_compress_states`` yields ``None`` nothing else is computed.
        """
        meta = prepare_compress_states(
            infer_state=infer_state,
            layer_idx=self.layer_idx_,
            compress_ratio=self.compress_ratio,
            is_in_indexer=self.is_in_indexer,
        )
        self._metadata = meta
        if meta is None:
            return meta

        if not self.is_in_indexer:
            meta.kv_score = layer_weight.compressor_wkv_gate_.mm(x).float()
            ape = layer_weight.compressor_ape_.weight
        else:
            # indexer wkv/wgate are two separate replicated weights; cat -> [T, 2*coff*idx_hd]
            # (same [kv | score] layout the fused compressor_wkv_gate_ produces for attention).
            kv_part = layer_weight.idx_cmp_wkv_.mm(x)
            gate_part = layer_weight.idx_cmp_wgate_.mm(x)
            meta.kv_score = torch.cat([kv_part, gate_part], dim=-1).float()
            ape = layer_weight.idx_cmp_ape_.weight

        prepare_partial_states(
            kv_score=meta.kv_score,
            metadata=meta,
            ape=ape,
            compress_ratio=self.compress_ratio,
        )
        return self._metadata

    def fused_compress(
        self,
        infer_state: DeepseekV4InferStateInfo,
        layer_weight: DeepseekV4TransformerLayerWeight,
        cos_table: torch.Tensor,
        sin_table: torch.Tensor,
    ):
        """Run the fused compress op on the state prepared by ``prepare_states``.

        Returns ``None`` on layers with ``compress_ratio == 0``.

        Raises:
            RuntimeError: if no metadata is available on a compressing layer.
        """
        if self.compress_ratio == 0:
            return None
        if self._metadata is None:
            raise RuntimeError("DeepSeek-V4 compressor.prepare_states must run before fused_compress")
        metadata = self._metadata

        indexer_mode = self.is_in_indexer
        norm_holder = layer_weight.idx_cmp_norm_ if indexer_mode else layer_weight.compressor_norm_
        norm_weight = norm_holder.weight
        ape_holder = layer_weight.idx_cmp_ape_ if indexer_mode else layer_weight.compressor_ape_
        ape = ape_holder.weight
        head_dim = self.index_head_dim if indexer_mode else self.head_dim

        return fused_compress_op(
            kv_score=metadata.kv_score,
            metadata=metadata,
            norm_weight=norm_weight,
            ape=ape,
            eps=self.eps,
            head_dim=head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            compress_ratio=self.compress_ratio,
            cos_table=cos_table,
            sin_table=sin_table,
            output_bf16=self.is_in_indexer,
        )
Mirrors deepseek3_2's NsaInfer + boundary (the model owns ALL index construction; the attention backend only forwards final + tensors to flash_mla.flash_mla_with_kvcache) AND its c4 implementation: hadamard'd fp8 q/K, a + ragged gather of the compressed c4 keys, deep_gemm.fp8_mqa_logits, then topk -- adapted for the + replicated indexer (no gather-q/all_reduce), the c4-compressed entry space, and topk-512 (no + inheritance only because of those data-shape differences). swa metadata is precomputed in + init_some_extra_state; this class owns the c4/c128 entry gather (build_compress_index) AND the c4 + Lightning-Indexer scoring (gather + deep_gemm.fp8_mqa_logits + topk). Holds only static per-layer + config; all per-request data flows in via args. Invoke from _context/_token_attention_kernel + (after compressor.fused_compress, before *_att) so the c4 scorer/topk keep the same cuda-graph + capture position they had when this lived in the backend. The indexer is replicated (no TP collective).""" + + def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int): + self.layer_idx_ = layer_idx + self.compress_ratio = network_config["compress_ratios"][layer_idx] + self.index_topk = network_config["index_topk"] + self.index_head_dim = network_config["index_head_dim"] + self.qk_rope_head_dim = network_config["qk_rope_head_dim"] + self.index_n_heads = network_config["index_n_heads"] + self.tp_world_size_ = tp_world_size + self.indexer_score_scale = self.index_head_dim ** -0.5 + self.indexer_weight_scale = self.indexer_score_scale * self.index_n_heads ** -0.5 + # c4 layers own a second compressor (is_in_indexer) that writes the Lightning-Indexer key + # pool every step; _c4_indices gathers it back + scores via deep_gemm.fp8_mqa_logits. 
+ self.indexer_compressor = ( + CompressorInfer(layer_idx, network_config, tp_world_size, is_in_indexer=True) + if self.compress_ratio == 4 + else None + ) + + def write_indexer_k(self, x, infer_state: DeepseekV4InferStateInfo, layer_weight, cos_table, sin_table): + """c4-only: compress this step's tokens into per-c4-entry indexer keys and pack them into + c4_indexer_pool. MUST run before build_metadata so the scorer (gather + deep_gemm.fp8_mqa_logits) + reads the finished entries; runs every step (incl. in the decode graph) so keys accumulate for + later long-context scoring. No-op on c128 / dense layers.""" + if self.compress_ratio != 4: + return + self.indexer_compressor.prepare_states(x, infer_state, layer_weight) + self.indexer_compressor.fused_compress(infer_state, layer_weight, cos_table, sin_table) + scratch = self.indexer_compressor._metadata.out_buffer # [T, index_head_dim] bf16 (group-end rows valid) + # Rotate K (post norm+rope) by the SAME 1/sqrt(d) Hadamard the q kernel applies, so + # (Hq)·(Hk)=q·k (H orthogonal) and the fp8 quant of K stays accurate. + from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform + + scratch = hadamard_transform(scratch, scale=self.index_head_dim ** -0.5) + mem_manager = infer_state.mem_manager + positions = infer_state.position_ids + out_slots = mem_manager.full_to_c4_indexs[infer_state.mem_index.long().reshape(-1)] + # only group-end tokens finish a c4 entry; mask the rest to -1 so the packer skips them + # (mid-group tokens share the group's c4 slot -> avoids racing a finished slot). 
def build_metadata(self, x, q_lora, infer_state: DeepseekV4InferStateInfo, layer_weight):
    """Assemble the final flash_mla index tensors for this layer's compress variant.

    The swa indices/lengths and the per-token request index are read from ``infer_state``;
    only the extra (compressed) indices depend on the layer: c4 layers score with the
    indexer, c128 layers gather every entry, and all other layers leave both extra fields
    as ``None``.
    """
    result = {
        "swa_indices": infer_state.dsv4_swa_indices.unsqueeze(1),
        "swa_lengths": infer_state.dsv4_swa_lengths,
        "extra_indices": None,
        "extra_lengths": None,
    }
    token_req_idx = infer_state.dsv4_sparse_req_idx
    token_positions = infer_state.position_ids

    ratio = self.compress_ratio
    if ratio == 4:
        q_fp8, head_weights = self._indexer_q_weight(x, q_lora, infer_state, layer_weight)
        extra = self._c4_indices(infer_state, q_fp8, head_weights, token_positions)
    elif ratio == 128:
        extra = self._c128_indices(infer_state, token_req_idx, token_positions)
    else:
        return result

    result["extra_indices"], result["extra_lengths"] = extra
    return result
def _c4_indices(self, infer_state: DeepseekV4InferStateInfo, idx_q_fp8, weights, positions):
    """c4 scorer via ds3.2-style gather + deep_gemm.fp8_mqa_logits. Gather each request's causal c4
    keys into a padded-per-request ragged fp8 buffer (k row r*c4_cap+e), score every query token
    over its absolute [ks, ke) range, then masked topk-512 -> c4 slots. Fixed shapes (c4_cap pinned
    per graph bucket) keep the decode cuda graph capturable.

    Three paths, tried in this order:
      1. ``max_entries <= index_topk``: no scoring, every causal entry is returned through
         ``build_compress_index``.
      2. ``LIGHTLLM_DSV4_PAGED_INDEXER == "1"``: delegate to ``_c4_indices_paged``; a ``None``
         result falls through to path 3.
      3. Ragged gather + ``deep_gemm.fp8_mqa_logits`` + ``topk``.

    Returns:
        ``(slots, lengths)`` where ``slots`` has a singleton dim inserted at position 1
        (``unsqueeze(1)``). On path 3 entries beyond a token's causal range are set to ``-1``
        and ``lengths`` is ``min(valid_len, index_topk)`` clamped to at least 1.
    """
    mem_manager = infer_state.mem_manager
    index_topk = self.index_topk
    # Number of c4 entries for the longest sequence, rounded up to a multiple of 64.
    max_entries = max(1, int(infer_state.max_kv_seq_len) // 4)
    c4_cap = ((max_entries + 63) // 64) * 64

    # entry space fits the budget -> every causal entry is selected; no scoring needed. The
    # captured decode graph (graph_max_len -> max_entries > topk) always takes the scorer branch
    # below, so this only shortcuts tiny eager contexts.
    if max_entries <= index_topk:
        from ..triton_kernel.build_compress_index_dsv4 import build_compress_index

        slots, lengths = build_compress_index(
            infer_state.dsv4_sparse_req_idx,
            positions,
            infer_state.req_manager.req_to_token_indexs,
            mem_manager.full_to_c4_indexs,
            ratio=4,
            cap=c4_cap,
        )
        return slots.unsqueeze(1), lengths

    b_req_idx = infer_state.b_req_idx
    batch = b_req_idx.shape[0]
    device = positions.device
    c4_len = torch.div(infer_state.b_seq_len, 4, rounding_mode="floor").to(torch.int32)  # entries/req

    # Opt-in paged scorer; it may decline (return None), in which case the ragged path runs.
    if os.getenv("LIGHTLLM_DSV4_PAGED_INDEXER", "0") == "1":
        out = self._c4_indices_paged(
            infer_state=infer_state,
            idx_q_fp8=idx_q_fp8,
            weights=weights,
            positions=positions,
            c4_len=c4_len,
            c4_cap=c4_cap,
        )
        if out is not None:
            return out

    import deep_gemm
    from ..triton_kernel.gather_c4_indexer_k_dsv4 import gather_c4_indexer_k_ragged

    k_fp8, k_scale, ragged_slots = gather_c4_indexer_k_ragged(
        mem_manager,
        self.layer_idx_,
        b_req_idx,
        c4_len,
        c4_cap,
        infer_state.req_manager.req_to_token_indexs,
    )
    # batch position of each query token -> absolute [ks, ke) into the padded buffer.
    if infer_state.is_prefill:
        # Prefill: each request contributes b_q_seq_len query tokens.
        token_batch_pos = torch.repeat_interleave(
            torch.arange(batch, device=device, dtype=torch.int32), infer_state.b_q_seq_len
        )
    else:
        # Decode: one query token per request.
        token_batch_pos = torch.arange(batch, device=device, dtype=torch.int32)
    valid_len = ((positions + 1) // 4).to(torch.int32)  # causal candidate count per query
    ks = token_batch_pos * c4_cap
    ke = ks + valid_len
    logits = deep_gemm.fp8_mqa_logits(
        idx_q_fp8, (k_fp8, k_scale), weights, ks, ke, clean_logits=False, max_seqlen_k=c4_cap
    )  # [T, c4_cap] f32, left-aligned: logits[t, j] = q_t . k[ks[t]+j]
    # Mask columns at or beyond each token's causal count so topk prefers real entries.
    col = torch.arange(c4_cap, device=device)
    logits = logits.masked_fill(col.unsqueeze(0) >= valid_len.unsqueeze(1), float("-inf"))
    top = logits.topk(index_topk, dim=-1).indices.to(torch.int32)  # relative positions in [0, valid_len)
    abs_idx = top + ks.unsqueeze(1)  # absolute compact row
    top_slots = ragged_slots[abs_idx.long()]  # compact row -> c4 pool slot
    invalid = top >= valid_len.unsqueeze(1)  # topk over -inf padding when valid_len < index_topk
    top_slots = torch.where(invalid, torch.full_like(top_slots, -1), top_slots)
    topk_lengths = torch.clamp(torch.minimum(valid_len, torch.full_like(valid_len, index_topk)), min=1)
    return top_slots.unsqueeze(1), topk_lengths.contiguous()
infer_state.req_manager.req_to_token_indexs, + infer_state.req_manager.HOLD_REQUEST_ID, + validate=validate_now, + ) + if validate_now and int(valid_flag.item()) == 0: + if os.getenv("LIGHTLLM_DSV4_PAGED_INDEXER_STRICT", "0") == "1": + raise RuntimeError("DeepSeek-V4 paged indexer requires page-aligned c4 slots") + return None + + if infer_state.is_prefill: + token_batch_pos = torch.repeat_interleave( + torch.arange(batch, device=device, dtype=torch.int32), infer_state.b_q_seq_len + ) + row_page_table = page_table[token_batch_pos.long()].contiguous() + else: + row_page_table = page_table + + valid_len = ((positions + 1) // 4).to(torch.int32) + ctx_lens = torch.clamp(valid_len, min=1).reshape(-1, 1).contiguous() + kv_cache = mem_manager.c4_indexer_pool.get_layer_buffer(mem_manager.layer_to_c4_idx[self.layer_idx_]).view( + mem_manager.c4_indexer_pool.num_pages, + mem_manager.c4_indexer_pool.page_size, + 1, + self.index_head_dim + 4, + ) + metadata = deep_gemm.get_paged_mqa_logits_metadata( + ctx_lens, + mem_manager.c4_indexer_pool.page_size, + deep_gemm.get_num_sms(), + ) + logits = deep_gemm.fp8_paged_mqa_logits( + idx_q_fp8.unsqueeze(1), + kv_cache, + weights, + ctx_lens, + row_page_table, + metadata, + c4_cap, + False, + ) + top_slots = torch.empty((idx_q_fp8.shape[0], index_topk), dtype=torch.int32, device=device) + topk_transform_512( + logits, + valid_len, + row_page_table, + top_slots, + mem_manager.c4_indexer_pool.page_size, + ) + topk_lengths = torch.clamp(torch.minimum(valid_len, torch.full_like(valid_len, index_topk)), min=1) + return top_slots.unsqueeze(1), topk_lengths.contiguous() diff --git a/lightllm/models/deepseek_v4/layer_weights/__init__.py b/lightllm/models/deepseek_v4/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_v4/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..54f29ce574 
class DeepseekV4PreAndPostLayerWeight(PreAndPostLayerWeight):
    """Embedding, lm_head, final norm and hyper-connection head weights for DeepSeek-V4."""

    def __init__(self, data_type, network_config):
        super().__init__(data_type, network_config)

        hidden_size = network_config["hidden_size"]
        vocab_size = network_config["vocab_size"]
        stream_count = network_config["hc_mult"]

        # embeddings / lm_head / final norm (bf16, vocab tensor-parallel). V4 has no `model.` prefix
        # and does not tie embeddings (tie_word_embeddings=false).
        vocab_kwargs = dict(dim=hidden_size, vocab_size=vocab_size, data_type=self.data_type_)
        self.wte_weight_ = EmbeddingWeight(weight_name="embed.weight", **vocab_kwargs)
        self.lm_head_weight_ = LMHeadWeight(weight_name="head.weight", **vocab_kwargs)
        self.final_norm_weight_ = RMSNormWeight(dim=hidden_size, weight_name="norm.weight", data_type=self.data_type_)

        # final hyper-connection head (collapses the hc_mult residual streams before the lm_head);
        # all three parameters are kept in fp32 and created in fn -> base -> scale order.
        hc_head_specs = (
            ("hc_head_fn_", "hc_head_fn", (stream_count, stream_count * hidden_size)),
            ("hc_head_base_", "hc_head_base", (stream_count,)),
            ("hc_head_scale_", "hc_head_scale", (1,)),
        )
        for attr_name, weight_name, shape in hc_head_specs:
            param = ParameterWeight(weight_name=weight_name, data_type=torch.float32, weight_shape=shape)
            setattr(self, attr_name, param)
lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + ROWBMMWeight, + RMSNormWeight, + ParameterWeight, + TpAttSinkWeight, + FusedMoeWeight, +) +from ..triton_kernel.quant_convert import dequant_fp8_block_to_bf16 + + +class DeepseekV4TransformerLayerWeight(TransformerLayerWeight): + """Per-layer weights for DeepSeek-V4-Flash. + + DS4 does not share DS2/DS3.2's ``model.layers.*.self_attn/mlp`` layout. Its attention is + HC + CSA, and routed experts are checkpointed as MXFP4 (fp4 release) or + FP8 block-128 (fp8 release, same layout as the dense fp8 weights). + """ + + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _parse_config(self): + cfg = self.network_config_ + self.hidden = cfg["hidden_size"] + self.n_heads = cfg["num_attention_heads"] + self.head_dim = cfg["head_dim"] + self.rope_dim = cfg["qk_rope_head_dim"] + self.q_lora_rank = cfg["q_lora_rank"] + self.o_lora_rank = cfg["o_lora_rank"] + self.o_groups = cfg["o_groups"] + self.index_n_heads = cfg["index_n_heads"] + self.index_head_dim = cfg["index_head_dim"] + self.n_routed_experts = cfg["n_routed_experts"] + self.moe_inter = cfg["moe_intermediate_size"] + self.num_hash_layers = cfg["num_hash_layers"] + self.vocab_size = cfg["vocab_size"] + self.hc_mult = cfg["hc_mult"] + self.mix_hc = (2 + self.hc_mult) * self.hc_mult + self.compress_ratio = cfg["compress_ratios"][self.layer_num_] + self.has_compressor = self.compress_ratio != 0 + self.has_indexer = self.compress_ratio == 4 + self.is_hash = self.layer_num_ < self.num_hash_layers + assert self.n_heads % self.tp_world_size_ == 0 + assert self.o_groups % self.tp_world_size_ == 0 + self.prefix = f"layers.{self.layer_num_}" + + def _init_weight(self): + self._init_qkvo() + if self.has_compressor: + self._init_compressor() + if self.has_indexer: + self._init_indexer() + self._init_moe() + self._init_norm() + 
self._init_hyper_connection() + + # ------------------------------------------------------------------ attention + def _init_qkvo(self): + p = f"{self.prefix}.attn" + # q low-rank (a replicated, b column-parallel over heads), kv single head (replicated) + self.wq_a_ = ROWMMWeight( + in_dim=self.hidden, + out_dims=[self.q_lora_rank], + weight_names=f"{p}.wq_a.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("wq_a"), + tp_rank=0, + tp_world_size=1, + ) + self.wq_b_ = ROWMMWeight( + in_dim=self.q_lora_rank, + out_dims=[self.n_heads * self.head_dim], + weight_names=f"{p}.wq_b.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("wq_b"), + ) + self.wkv_ = ROWMMWeight( + in_dim=self.hidden, + out_dims=[self.head_dim], + weight_names=f"{p}.wkv.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("wkv"), + tp_rank=0, + tp_world_size=1, + ) + self.q_norm_ = RMSNormWeight(dim=self.q_lora_rank, weight_name=f"{p}.q_norm.weight", data_type=self.data_type_) + self.kv_norm_ = RMSNormWeight(dim=self.head_dim, weight_name=f"{p}.kv_norm.weight", data_type=self.data_type_) + self.attn_sink_ = TpAttSinkWeight( + all_q_head_num=self.n_heads, weight_name=f"{p}.attn_sink", data_type=torch.float32 + ) + # grouped low-rank output projection: wo_a is a per-group batched matmul [groups, in, o_lora], + # wo_b is row-parallel [groups*o_lora -> hidden]. wo_a is reshaped in load_hf_weights. 
+ per_group_in = self.n_heads * self.head_dim // self.o_groups + self.wo_a_ = ROWBMMWeight( + dim0=self.o_groups, + dim1=per_group_in, + dim2=self.o_lora_rank, + weight_names=f"{p}.wo_a.weight", + data_type=self.data_type_, + quant_method=None, + ) + self.wo_b_ = COLMMWeight( + in_dim=self.o_groups * self.o_lora_rank, + out_dims=[self.hidden], + weight_names=f"{p}.wo_b.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("wo_b"), + ) + + # ------------------------------------------------------------------ compressor / indexer + def _init_compressor(self): + prefix = f"{self.prefix}.attn.compressor" + head_dim = self.head_dim + ratio = self.compress_ratio + + coff = 2 if ratio == 4 else 1 + # wkv/wgate are bf16 (no scale) and replicated (single KV head). + self.compressor_wkv_gate_ = ROWMMWeight( + in_dim=self.hidden, + out_dims=[coff * head_dim, coff * head_dim], + weight_names=[f"{prefix}.wkv.weight", f"{prefix}.wgate.weight"], + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.compressor_norm_ = RMSNormWeight( + dim=head_dim, weight_name=f"{prefix}.norm.weight", data_type=self.data_type_ + ) + self.compressor_ape_ = ParameterWeight( + weight_name=f"{prefix}.ape", data_type=torch.float32, weight_shape=(ratio, coff * head_dim) + ) + + def _init_indexer(self): + p = f"{self.prefix}.attn.indexer" + # The Lightning-Indexer is REPLICATED across TP ranks (like sglang/vllm), not head-sharded: + # q_lora and the attn input are already full on every rank, so each rank scores all + # index_n_heads locally and the c4 top-k is identical everywhere -- no gather/all_reduce. + # wq_b is FP8 in the checkpoint -> de-quantized to bf16 at load. 
+ self.idx_wq_b_ = ROWMMWeight( + in_dim=self.q_lora_rank, + out_dims=[self.index_n_heads * self.index_head_dim], + weight_names=f"{p}.wq_b.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("idx_wq_b"), + tp_rank=0, + tp_world_size=1, + ) + self.idx_weights_proj_ = ROWMMWeight( + in_dim=self.hidden, + out_dims=[self.index_n_heads], + weight_names=f"{p}.weights_proj.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + coff = 2 # indexer compressor always uses ratio 4 (overlap) + self.idx_cmp_wkv_ = ROWMMWeight( + in_dim=self.hidden, + out_dims=[coff * self.index_head_dim], + weight_names=f"{p}.compressor.wkv.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.idx_cmp_wgate_ = ROWMMWeight( + in_dim=self.hidden, + out_dims=[coff * self.index_head_dim], + weight_names=f"{p}.compressor.wgate.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.idx_cmp_norm_ = RMSNormWeight( + dim=self.index_head_dim, weight_name=f"{p}.compressor.norm.weight", data_type=self.data_type_ + ) + self.idx_cmp_ape_ = ParameterWeight( + weight_name=f"{p}.compressor.ape", data_type=torch.float32, weight_shape=(4, coff * self.index_head_dim) + ) + + # ------------------------------------------------------------------ moe + def _init_moe(self): + p = f"{self.prefix}.ffn" + # router gate (replicated). Stored as fp32: the topk_hash_softplus_sqrt router wants fp32 logits, + # so keep the gate matmul in fp32 — but store the (constant) weight as fp32 once here instead of + # re-casting it to fp32 on every forward in _ffn. 
+ self.gate_weight_ = ROWMMWeight( + in_dim=self.hidden, + out_dims=[self.n_routed_experts], + weight_names=f"{p}.gate.weight", + data_type=torch.float32, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + if self.is_hash: + self.gate_tid2eid_ = ParameterWeight( + weight_name=f"{p}.gate.tid2eid", + data_type=torch.int64, + weight_shape=(self.vocab_size, self.network_config_["num_experts_per_tok"]), + ) + else: + self.gate_bias_ = ParameterWeight( + weight_name=f"{p}.gate.bias", data_type=torch.float32, weight_shape=(self.n_routed_experts,) + ) + # shared expert (dense, bf16 after de-quant): w1=gate, w3=up fused (row), w2=down (col). + # Named gate_up_proj/down_proj so the inherited Llama `_ffn_tp` (fused gate_up matmul + + # silu_and_mul triton kernel, no swiglu clamp) drives it directly. Order [w1, w3] = [gate, up] + # matches silu_and_mul_fwd's blocked layout (first half gate, second half up). + sp = f"{p}.shared_experts" + self.gate_up_proj = ROWMMWeight( + in_dim=self.hidden, + out_dims=[self.moe_inter, self.moe_inter], + weight_names=[f"{sp}.w1.weight", f"{sp}.w3.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_gate"), + ) + self.down_proj = COLMMWeight( + in_dim=self.moe_inter, + out_dims=[self.hidden], + weight_names=f"{sp}.w2.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("shared_down"), + ) + self.experts_ = FusedMoeWeight( + gate_proj_name="w1", + down_proj_name="w2", + up_proj_name="w3", + e_score_correction_bias_name="", + weight_prefix=f"{p}.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.hidden, + moe_intermediate_size=self.moe_inter, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + ) + + def _init_norm(self): + self.attn_norm_ = RMSNormWeight( + dim=self.hidden, weight_name=f"{self.prefix}.attn_norm.weight", 
data_type=self.data_type_ + ) + self.ffn_norm_ = RMSNormWeight( + dim=self.hidden, weight_name=f"{self.prefix}.ffn_norm.weight", data_type=self.data_type_ + ) + + def _init_hyper_connection(self): + p = self.prefix + self.hc_attn_fn_ = ParameterWeight( + weight_name=f"{p}.hc_attn_fn", + data_type=torch.float32, + weight_shape=(self.mix_hc, self.hc_mult * self.hidden), + ) + self.hc_attn_base_ = ParameterWeight( + weight_name=f"{p}.hc_attn_base", data_type=torch.float32, weight_shape=(self.mix_hc,) + ) + self.hc_attn_scale_ = ParameterWeight( + weight_name=f"{p}.hc_attn_scale", data_type=torch.float32, weight_shape=(3,) + ) + self.hc_ffn_fn_ = ParameterWeight( + weight_name=f"{p}.hc_ffn_fn", + data_type=torch.float32, + weight_shape=(self.mix_hc, self.hc_mult * self.hidden), + ) + self.hc_ffn_base_ = ParameterWeight( + weight_name=f"{p}.hc_ffn_base", data_type=torch.float32, weight_shape=(self.mix_hc,) + ) + self.hc_ffn_scale_ = ParameterWeight( + weight_name=f"{p}.hc_ffn_scale", data_type=torch.float32, weight_shape=(3,) + ) + + # ------------------------------------------------------------------ loading + def load_hf_weights(self, weights): + self._dequant_in_place(weights) + return super().load_hf_weights(weights) + + def _fp8_scale_renames(self): + """Map weight name -> the scale name its quant method loads (e.g. `weight_scale_inv` + for DeepGEMM). Read from each MM weight's own `weight_scale_names`, so the rename + target always matches what that weight will look up; no-quant weights have None + entries and are skipped.""" + renames = {} + for attr in self.__dict__.values(): + weight_names = getattr(attr, "weight_names", ()) + scale_names = getattr(attr, "weight_scale_names", ()) + for weight_name, scale_name in zip(weight_names, scale_names): + if scale_name is not None: + renames[weight_name] = scale_name + return renames + + def _dequant_in_place(self, weights): + p = self.prefix + "." 
+ scale_renames = self._fp8_scale_renames() + # Convert every `.scale` belonging to this layer. Weights are loaded incrementally + # per safetensors shard, so the paired weight may live in another shard: + # - routed expert `.scale` follows the fused_moe quant method's weight_scale_suffix: + # MXFP4 consumes `.scale` as-is, FP8 DeepGEMM expects `.weight_scale_inv` (rename only); + # - FP8 matmul scales only need renaming for DeepGEMM, no weight required; + # - FP8 pairs on no-quant paths (wo_a's ROWBMMWeight) are expanded to bf16, + # the only case that truly requires weight and scale in the same shard. + expert_scale_suffix = self.experts_.quant_method.weight_scale_suffix + for scale_k in [k for k in list(weights.keys()) if k.startswith(p) and k.endswith(".scale")]: + if scale_k.startswith(f"{p}ffn.experts."): + if expert_scale_suffix is not None and expert_scale_suffix != "scale": + weights[scale_k[: -len("scale")] + expert_scale_suffix] = weights[scale_k].to(torch.float32) + del weights[scale_k] + continue + k = scale_k[: -len(".scale")] + ".weight" + target = scale_renames.get(k) + if target is not None: # FP8 e4m3, block-128 scale, run by DeepGEMM directly + weights[target] = weights[scale_k].to(torch.float32) + del weights[scale_k] + else: + weights[k] = dequant_fp8_block_to_bf16(weights[k], weights[scale_k]).to(self.data_type_) + del weights[scale_k] + # grouped-O: reshape [groups*o_lora, in] -> [groups, in, o_lora] for the batched matmul + woa = f"{self.prefix}.attn.wo_a.weight" + if woa in weights and weights[woa].dim() == 2: + w = weights[woa] + per_group_in = self.n_heads * self.head_dim // self.o_groups + weights[woa] = w.view(self.o_groups, self.o_lora_rank, per_group_in).transpose(1, 2).contiguous() + return diff --git a/lightllm/models/deepseek_v4/model.py b/lightllm/models/deepseek_v4/model.py new file mode 100644 index 0000000000..887e72f433 --- /dev/null +++ b/lightllm/models/deepseek_v4/model.py @@ -0,0 +1,279 @@ +import copy +import 
importlib.util +import os + +import torch +from lightllm.models.registry import ModelRegistry +from lightllm.models.llama.model import LlamaTpPartModel +from lightllm.common.req_manager import DeepseekV4ReqManager +from lightllm.common.kv_cache_mem_manager import DeepseekV4MemoryManager +from lightllm.models.deepseek_v4.layer_weights.pre_and_post_layer_weight import ( + DeepseekV4PreAndPostLayerWeight, +) +from lightllm.models.deepseek_v4.layer_weights.transformer_layer_weight import ( + DeepseekV4TransformerLayerWeight, +) +from lightllm.models.deepseek_v4.layer_infer.pre_layer_infer import ( + DeepseekV4PreLayerInfer, +) +from lightllm.models.deepseek_v4.layer_infer.post_layer_infer import ( + DeepseekV4PostLayerInfer, +) +from lightllm.models.deepseek_v4.layer_infer.transformer_layer_infer import ( + DeepseekV4TransformerLayerInfer, +) +from lightllm.common.basemodel.attention import get_nsa_prefill_att_backend_class, get_nsa_decode_att_backend_class +from lightllm.models.deepseek_v4.infer_struct import DeepseekV4InferStateInfo +from lightllm.models.llama.yarn_rotary_utils import ( + find_correction_range, + linear_ramp_mask, +) +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num, get_env_start_args +from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager + +logger = init_logger(__name__) +DSV4_DECODE_CUDAGRAPH_MAX_LEN = 8192 + + +@ModelRegistry("deepseek_v4") +class DeepseekV4TpPartModel(LlamaTpPartModel): + req_manager: DeepseekV4ReqManager + mem_manager: DeepseekV4MemoryManager + + pre_and_post_weight_class = DeepseekV4PreAndPostLayerWeight + transformer_weight_class = DeepseekV4TransformerLayerWeight + + pre_layer_infer_class = DeepseekV4PreLayerInfer + post_layer_infer_class = DeepseekV4PostLayerInfer + transformer_layer_infer_class = DeepseekV4TransformerLayerInfer + + infer_state_class = DeepseekV4InferStateInfo + + def _verify_params(self): + assert self.load_way == "HF", "only 
support HF format weights" + assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + assert self.config["o_groups"] % self.tp_world_size_ == 0 + assert self.config["index_n_heads"] % self.tp_world_size_ == 0 + return + + def _init_req_manager(self): + create_max_seq_len = 0 + if self.batch_max_tokens is not None: + create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens) + if self.max_seq_length is not None: + create_max_seq_len = max(create_max_seq_len, self.max_seq_length) + + self._dsv4_req_manager_seq_len = create_max_seq_len + layer_num = self.config["n_layer"] + get_added_mtp_kv_layer_num() + self._dsv4_compress_rates = self._get_compress_rates(layer_num) + self.req_manager = DeepseekV4ReqManager( + self.max_req_num, + create_max_seq_len, + compress_rates=self._dsv4_compress_rates, + head_dim=self.config["head_dim"], + indexer_head_dim=self.config["index_head_dim"], + sliding_window=self.config["sliding_window"], + ) + return + + def _get_compress_rates(self, layer_num): + rates = list(self.config["compress_ratios"]) + return rates[:layer_num] + + def _init_mem_manager(self): + layer_num = self.config["n_layer"] + get_added_mtp_kv_layer_num() + compress_rates = getattr(self, "_dsv4_compress_rates", self._get_compress_rates(layer_num)) + sliding_window = int(self.config["sliding_window"]) + # 活跃窗口之外的 swa 余量: 在途 prefill chunk 的瞬时占用(出窗槽位到下一次 prep 才回收) + # + radix cache 持有的窗口尾部(每条缓存序列约一个 window)。 + swa_extra_token_num = int(self.batch_max_tokens or 0) + self.max_req_num * sliding_window + self.mem_manager = DeepseekV4MemoryManager( + self.max_total_token_num, + dtype=self.data_type, + head_num=1, + head_dim=self.config["head_dim"], + layer_num=layer_num, + compress_rates=compress_rates, + indexer_head_dim=self.config["index_head_dim"], + max_request_num=self.max_req_num, + sliding_window=sliding_window, + swa_extra_token_num=swa_extra_token_num, + mem_fraction=self.mem_fraction, + ) + assert isinstance(self.req_manager, 
DeepseekV4ReqManager) + self.req_manager.bind_mem_manager(self.mem_manager) + return + + def _init_cudagraph(self): + if not self.disable_cudagraph and self.graph_max_len_in_batch > DSV4_DECODE_CUDAGRAPH_MAX_LEN: + logger.info( + "DeepSeek-V4 caps decode cudagraph max_len_in_batch from %s to %s for the current " + "graph-safe sparse-attention path; longer decode batches run eager.", + self.graph_max_len_in_batch, + DSV4_DECODE_CUDAGRAPH_MAX_LEN, + ) + self.graph_max_len_in_batch = DSV4_DECODE_CUDAGRAPH_MAX_LEN + return super()._init_cudagraph() + + def _init_att_backend(self): + args = get_env_start_args() + if args.llm_kv_type == "None": + args.llm_kv_type = "fp8kv_dsa" + # TODO: 支持其他 kv type + if args.llm_kv_type != "fp8kv_dsa": + raise RuntimeError("DeepSeek-V4 requires llm_kv_type=fp8kv_dsa for packed FlashMLA sparse attention") + self.prefill_att_backend = get_nsa_prefill_att_backend_class(index=0)(model=self) + self.decode_att_backend = get_nsa_decode_att_backend_class(index=0)(model=self) + return + + def _init_custom(self): + self._init_to_get_rotary() + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) + return + + def _init_to_get_rotary(self): + # Interleaved (GPT-J) rope. Build complex64 freqs_cis tables (_freqs_cis_*) following the + # gemma4 two-variant convention; the fused sglang q kernel consumes them directly, while + # _cos_cached_*/_sin_cached_* are .real/.imag views of the same storage for the kv rope, + # inverse rope and compressor paths (deepseek2's interleaved triton rotary_emb_fwd). + # Sliding-window layers use base rope_theta (no YaRN); + # compressed (CSA/HCA) layers use compress_rope_theta with configured rope_scaling. + # Kept fp32 for accuracy (the apply upcasts anyway). 
+ cfg = self.config + rs = cfg.get("rope_scaling", {}) or {} + dim = cfg["qk_rope_head_dim"] + # The rope tables MUST span every absolute position any request can produce (the served + # max_req_total_len / max_position_embeddings, up to 1M). Capping them shorter makes + # init_some_extra_state's index_select(cos/sin, position_ids) read OOB past the table at + # contexts beyond the cap (device-side assert / crash). ~268MB total at 1M, fp32x32 x4 views. + max_seq = max(int(self.max_seq_length), int(cfg.get("max_position_embeddings", 8192))) + freq_exponents = torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim + positions = torch.arange(max_seq, dtype=torch.float32, device="cuda") + + sliding_freqs = 1.0 / (cfg["rope_theta"] ** freq_exponents) + f = torch.outer(positions, sliding_freqs) # [max_seq, dim//2] + self._freqs_cis_sliding = torch.complex(f.cos(), f.sin()) + + compress_freqs = 1.0 / (cfg["compress_rope_theta"] ** freq_exponents) + rope_type = rs.get("rope_type", rs.get("type", "default")) + orig_max = rs.get("original_max_position_embeddings", 0) + if rope_type == "yarn" and orig_max > 0: + beta_fast = rs.get("beta_fast", 32) + beta_slow = rs.get("beta_slow", 1) + factor = rs.get("factor", 1) + if factor is None: + factor = cfg.get("max_position_embeddings", max_seq) / orig_max + low, high = find_correction_range(beta_fast, beta_slow, dim, cfg["compress_rope_theta"], orig_max) + smooth = 1 - linear_ramp_mask(low, high, dim // 2).cuda() + compress_freqs = compress_freqs / factor * (1 - smooth) + compress_freqs * smooth + f = torch.outer(positions, compress_freqs) # [max_seq, dim//2] + self._freqs_cis_compress = torch.complex(f.cos(), f.sin()) + self._cos_cached_sliding = self._freqs_cis_sliding.real + self._sin_cached_sliding = self._freqs_cis_sliding.imag + self._cos_cached_compress = self._freqs_cis_compress.real + self._sin_cached_compress = self._freqs_cis_compress.imag + # Each layer uses exactly one rope variant; wire its table once here 
(layers are already + # built: _init_infer_layer runs before _init_custom) instead of relaying via infer_state. + # The compressor needs the full compress tables (entry rope positions != token positions). + for layer in self.layers_infer: + layer.freqs_cis = self._freqs_cis_compress if layer.compress_ratio else self._freqs_cis_sliding + layer.cos_compress_table = self._cos_cached_compress + layer.sin_compress_table = self._sin_cached_compress + return + + +class DeepSeekV4Tokenizer: + """Tokenizer wrapper for DeepSeek-V4's Python prompt encoding.""" + + # DeepSeek-V4 has a per-request thinking mode (...) toggled via + # chat_template_kwargs={"thinking": true}. It has no Jinja chat_template string, + # so advertise thinking support explicitly for tokenizer_supports_force_thinking(). + supports_thinking = True + + def __init__(self, tokenizer, model_dir): + self.tokenizer = tokenizer + self.model_dir = model_dir + self._encoding_module = None + self._added_vocab = None + + def __getattr__(self, name): + return getattr(self.tokenizer, name) + + def get_added_vocab(self): + if self._added_vocab is None: + self._added_vocab = self.tokenizer.get_added_vocab() + return self._added_vocab + + def _get_encoding_module(self): + if self._encoding_module is not None: + return self._encoding_module + + encoding_path = os.path.join(self.model_dir, "encoding", "encoding_dsv4.py") + if not os.path.exists(encoding_path): + raise FileNotFoundError(f"DeepSeek-V4 encoding file not found: {encoding_path}") + + spec = importlib.util.spec_from_file_location("lightllm_deepseek_v4_encoding_dsv4", encoding_path) + if spec is None or spec.loader is None: + raise ImportError(f"failed to load DeepSeek-V4 encoding module from {encoding_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self._encoding_module = module + return module + + def apply_chat_template( + self, + conversation=None, + messages=None, + tools=None, + tokenize=False, + 
add_generation_prompt=True, + thinking=None, + enable_thinking=None, + **kwargs, + ): + msgs = conversation if conversation is not None else messages + if msgs is None: + raise ValueError("Either 'conversation' or 'messages' must be provided") + + msgs = copy.deepcopy(msgs) + + if tools: + wrapped_tools = [] + for tool in tools: + if "function" in tool: + wrapped_tools.append(tool) + else: + wrapped_tools.append({"type": "function", "function": tool}) + + injected = False + for msg in msgs: + if msg.get("role") == "system": + existing = msg.get("tools") or [] + msg["tools"] = existing + wrapped_tools + injected = True + break + + if not injected: + msgs.insert(0, {"role": "system", "content": "", "tools": wrapped_tools}) + + if thinking is None: + thinking = bool(enable_thinking) if enable_thinking is not None else False + thinking_mode = "thinking" if thinking else "chat" + encoding = self._get_encoding_module() + prompt = encoding.encode_messages( + msgs, + thinking_mode=thinking_mode, + drop_thinking=kwargs.get("drop_thinking", True), + add_default_bos_token=kwargs.get("add_default_bos_token", True), + reasoning_effort=kwargs.get("reasoning_effort"), + ) + + if tokenize: + return self.tokenizer.encode(prompt, add_special_tokens=False) + return prompt diff --git a/lightllm/models/deepseek_v4/triton_kernel/__init__.py b/lightllm/models/deepseek_v4/triton_kernel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py new file mode 100644 index 0000000000..b09192498d --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/build_compress_index_dsv4.py @@ -0,0 +1,78 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _build_compress_index_kernel( + req_idx_ptr, + pos_ptr, + req_to_token_ptr, + req_to_token_stride0, + full_to_c_ptr, + index_ptr, + length_ptr, + cap, + 
RATIO: tl.constexpr, + BLOCK_E: tl.constexpr, +): + t = tl.program_id(0) + eb = tl.program_id(1) + req = tl.load(req_idx_ptr + t).to(tl.int64) + pos = tl.load(pos_ptr + t).to(tl.int64) + raw_len = (pos + 1) // RATIO + + e = eb * BLOCK_E + tl.arange(0, BLOCK_E) + e_mask = e < cap + valid = (e < raw_len) & e_mask + # group-end token of compressed entry e: position e*RATIO + (RATIO-1). + end_pos = e * RATIO + (RATIO - 1) + safe_pos = tl.where(valid, end_pos, 0) + full_slot = tl.load(req_to_token_ptr + req * req_to_token_stride0 + safe_pos, mask=valid, other=0).to(tl.int64) + c_slot = tl.load(full_to_c_ptr + full_slot, mask=valid, other=-1).to(tl.int32) + tl.store(index_ptr + t * cap + e, c_slot, mask=e_mask) + + if eb == 0: + tl.store(length_ptr + t, tl.maximum(raw_len, 1).to(tl.int32)) + + +def build_compress_index( + req_idx: torch.Tensor, + positions: torch.Tensor, + req_to_token_indexs: torch.Tensor, + full_to_c_indexs: torch.Tensor, + ratio: int, + cap: int, +): + """Fused two-level group-end gather for the c4/c128 compressed-entry index tables. + + For token t (at request `req_idx[t]`, absolute `positions[t]`) and compressed entry e: + slot[t, e] = full_to_c[ req_to_token[req, e*ratio + (ratio-1)] ] (the group-end token's full slot) + with slot = -1 where e >= (pos+1)//ratio (beyond the causal compressed length) or where the + full->c map is unset. Returns (index [T, cap] int32, length [T] int32 = clamp((pos+1)//ratio, 1)). + + Replaces the eager _gather_compress_slots/_c128/c4-causal torch chain. `cap` must be a multiple of + 64 (FlashMLA topk alignment); the tiled grid (T, ceil(cap/BLOCK_E)) scales to 1M-context caps. + cuda-graph-safe: cap is fixed per graph bucket, shapes static. 
+ """ + T = positions.shape[0] + index = torch.empty((T, cap), dtype=torch.int32, device=positions.device) + length = torch.empty((T,), dtype=torch.int32, device=positions.device) + if T == 0: + return index, length + BLOCK_E = 256 + grid = (T, triton.cdiv(cap, BLOCK_E)) + _build_compress_index_kernel[grid]( + req_idx, + positions, + req_to_token_indexs, + req_to_token_indexs.stride(0), + full_to_c_indexs, + index, + length, + cap, + RATIO=ratio, + BLOCK_E=BLOCK_E, + num_warps=4, + ) + return index, length diff --git a/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py new file mode 100644 index 0000000000..e5b7b80bb4 --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/build_swa_index_dsv4.py @@ -0,0 +1,75 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _build_swa_index_kernel( + req_idx_ptr, + pos_ptr, + req_to_token_ptr, + req_to_token_stride0, + full_to_swa_ptr, + swa_index_ptr, + swa_length_ptr, + WINDOW: tl.constexpr, + BLOCK_W: tl.constexpr, +): + token_idx = tl.program_id(0) + req = tl.load(req_idx_ptr + token_idx).to(tl.int64) + pos = tl.load(pos_ptr + token_idx).to(tl.int64) + + w = tl.arange(0, BLOCK_W) + w_mask = w < WINDOW + # most-recent-first window, identical to the eager _swa_indices (offset = position - arange). 
+ offset = pos - w + valid = (offset >= 0) & w_mask + safe_offset = tl.where(valid, offset, 0) + full_slot = tl.load(req_to_token_ptr + req * req_to_token_stride0 + safe_offset, mask=valid, other=0).to(tl.int64) + swa_slot = tl.load(full_to_swa_ptr + full_slot, mask=valid, other=-1) + out = tl.where(valid, swa_slot, -1).to(tl.int32) + tl.store(swa_index_ptr + token_idx * WINDOW + w, out, mask=w_mask) + + length = tl.minimum(tl.maximum(pos + 1, 1), WINDOW).to(tl.int32) + tl.store(swa_length_ptr + token_idx, length) + + +def build_swa_index( + req_idx: torch.Tensor, + positions: torch.Tensor, + req_to_token_indexs: torch.Tensor, + full_to_swa_indexs: torch.Tensor, + window: int, +): + """Per-token sliding-window FlashMLA index table, built ONCE per forward (layer-independent: + full_to_swa is a single global map and the window is a model constant, so every layer's swa + indices are identical). Replaces DeepseekV4IndexInfer._swa_indices: for token t at + (req_idx, position) gather the last `window` tokens' full slots via req_to_token, then map + full -> swa; out-of-range positions store -1. + + Returns (swa_index [T, window] int32, swa_length [T] int32). `window` is 128 (a multiple of the + FlashMLA 64 alignment) so no extra pad is needed; the reader adds the s_q axis via unsqueeze(1). + Const output shape (no max_kv_seq_len dependence) makes this cuda-graph-safe to stage from + init_some_extra_state via copy_for_cuda_graph. + """ + # window must stay 64-aligned: the output is the FlashMLA `indices` tensor directly (no separate + # _pad_last_dim), and the extra-cache fork requires the topk dim to be a multiple of 64. 
+ assert window % 64 == 0, f"DeepSeek-V4 sliding_window must be a multiple of 64 for FlashMLA, got {window}" + T = positions.shape[0] + swa_index = torch.empty((T, window), dtype=torch.int32, device=positions.device) + swa_length = torch.empty((T,), dtype=torch.int32, device=positions.device) + if T == 0: + return swa_index, swa_length + _build_swa_index_kernel[(T,)]( + req_idx, + positions, + req_to_token_indexs, + req_to_token_indexs.stride(0), + full_to_swa_indexs, + swa_index, + swa_length, + WINDOW=window, + BLOCK_W=triton.next_power_of_2(window), + num_warps=4, + ) + return swa_index, swa_length diff --git a/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py new file mode 100644 index 0000000000..3510b92c30 --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_indexer_k_dsv4.py @@ -0,0 +1,92 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_indexer_k_dsv4( + K, + Dest_loc, + O_fp8, + O_f32, + stride_k_bs, + stride_k_d, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, + SCALE_MIN: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BYTES_PER_PAGE: tl.constexpr, +): + cur_index = tl.program_id(0) + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + # negative dest (unmapped slot) is a no-op, not an OOB write into a neighboring page. 
+ if dest_index < 0: + return + + page = dest_index // PAGE_SIZE + token_in_page = dest_index % PAGE_SIZE + + offs_d = tl.arange(0, HEAD_DIM) + vals = tl.load(K + cur_index * stride_k_bs + offs_d * stride_k_d).to(tl.float32) + amax = tl.max(tl.abs(vals), axis=0) + # per-token plain fp32 scale (not ue8m0), matching DeepseekV4MemoryManager._pack_indexer_k + scale = tl.maximum(amax / FP8_MAX, SCALE_MIN) + k_fp8 = tl.clamp(vals / scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv) + + data_base = page * BYTES_PER_PAGE + token_in_page * HEAD_DIM + tl.store(O_fp8 + data_base + offs_d, k_fp8) + scale_idx = (page * BYTES_PER_PAGE + PAGE_SIZE * HEAD_DIM) // 4 + token_in_page + tl.store(O_f32 + scale_idx, scale) + return + + +@torch.no_grad() +def destindex_copy_indexer_k_dsv4( + K: torch.Tensor, + DestLoc: torch.Tensor, + O_buffer: torch.Tensor, + page_size: int, +): + """Packed indexer-K page-slab writer (DeepSeek-V4 c4/CSA layers). + + K: [T, 128] bf16 unquantized indexer keys. + DestLoc: [T] int — c4-pool-local token slots; must already be allocated by the caller. + Negative slots (unmapped) are skipped. + O_buffer: [num_pages, bytes_per_page] uint8 — one layer's slab from the c4 indexer + PackedPagePool (128B fp8 data region + 4B fp32 scale tail per token). + + Bit-compatible with DeepseekV4MemoryManager._pack_indexer_k + PackedPagePool.write. 
+ """ + seq_len = DestLoc.shape[0] + if seq_len == 0: + return + head_dim, scale_bytes = 128, 4 + + K = K.reshape(-1, head_dim) + assert K.shape[0] == seq_len, f"Expected K shape[0]={seq_len}, got {K.shape[0]}" + assert K.dtype == torch.bfloat16, f"Expected bf16 indexer K, got {K.dtype}" + bytes_per_page = O_buffer.shape[-1] + assert O_buffer.dtype == torch.uint8 and O_buffer.is_contiguous() + assert bytes_per_page % 4 == 0 + assert bytes_per_page >= page_size * (head_dim + scale_bytes) + + flat = O_buffer.view(-1) + _fwd_kernel_destindex_copy_indexer_k_dsv4[(seq_len,)]( + K, + DestLoc, + flat.view(torch.float8_e4m3fn), + flat.view(torch.float32), + K.stride(0), + K.stride(1), + FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + SCALE_MIN=1e-4, + HEAD_DIM=head_dim, + PAGE_SIZE=page_size, + BYTES_PER_PAGE=bytes_per_page, + num_warps=1, + num_stages=1, + ) + return diff --git a/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py new file mode 100644 index 0000000000..a3ec6ed8cf --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/destindex_copy_kv_flashmla_dsv4.py @@ -0,0 +1,121 @@ +import torch + +import triton +import triton.language as tl +from triton.language.extra import libdevice + + +@triton.jit +def _fwd_kernel_destindex_copy_kv_flashmla_dsv4( + KV, + Dest_loc, + O_fp8, + O_bf16, + O_u8, + stride_kv_bs, + stride_kv_d, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, + SCALE_MIN: tl.constexpr, + NOPE_DIM: tl.constexpr, + ROPE_DIM: tl.constexpr, + GROUP_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, + SCALE_BYTES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BYTES_PER_PAGE: tl.constexpr, +): + cur_index = tl.program_id(0) + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + # negative dest (unmapped slot, e.g. 
full_to_c* rows that never closed a group) is a no-op, + # not an OOB write into a neighboring page. + if dest_index < 0: + return + + page = dest_index // PAGE_SIZE + token_in_page = dest_index % PAGE_SIZE + data_base = page * BYTES_PER_PAGE + token_in_page * (NOPE_DIM + ROPE_DIM * 2) + scale_base = page * BYTES_PER_PAGE + PAGE_SIZE * (NOPE_DIM + ROPE_DIM * 2) + token_in_page * SCALE_BYTES + + # nope: per-group ue8m0 quant. SCALE_BYTES(=NUM_GROUPS+1) lanes cover the exponent bytes + # plus the trailing zero pad byte in one store. libdevice.log2 (not tl.log2, which is the + # approx instruction) and the bit-packed 2**e keep this bit-exact with the torch oracle + # DeepseekV4MemoryManager._pack_mla_kv. + offs_g = tl.arange(0, SCALE_BYTES) + offs_e = tl.arange(0, GROUP_SIZE) + group_mask = offs_g < NUM_GROUPS + kv_ptrs = KV + cur_index * stride_kv_bs + (offs_g[:, None] * GROUP_SIZE + offs_e[None, :]) * stride_kv_d + vals = tl.load(kv_ptrs, mask=group_mask[:, None], other=0.0).to(tl.float32) + amax = tl.max(tl.abs(vals), axis=1) + scale_exp = tl.ceil(libdevice.log2(tl.maximum(amax / FP8_MAX, SCALE_MIN))).to(tl.int32) + scale = ((scale_exp + 127) << 23).to(tl.float32, bitcast=True) + kv_fp8 = tl.clamp(vals / scale[:, None], min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv) + tl.store(O_fp8 + data_base + offs_g[:, None] * GROUP_SIZE + offs_e[None, :], kv_fp8, mask=group_mask[:, None]) + scale_bytes = tl.where(group_mask, scale_exp + 127, 0).to(tl.uint8) + tl.store(O_u8 + scale_base + offs_g, scale_bytes) + + # rope: bf16 passthrough into the data region right after the nope bytes + offs_r = tl.arange(0, ROPE_DIM) + rope = tl.load(KV + cur_index * stride_kv_bs + (NOPE_DIM + offs_r) * stride_kv_d) + tl.store(O_bf16 + (data_base + NOPE_DIM) // 2 + offs_r, rope) + return + + +@torch.no_grad() +def destindex_copy_kv_flashmla_dsv4( + KV: torch.Tensor, + DestLoc: torch.Tensor, + O_buffer: torch.Tensor, + page_size: int, +): + """fp8_ds_mla packed page-slab writer (DeepSeek-V4 ABI, 
all latent pools). + + KV: [T, 512] bf16 — 448 normed-latent dims + 64 rope'd dims per token. + DestLoc: [T] int — pool-local token slots (page = slot // page_size); the pool HOLD slot is + a valid in-bounds row, negative slots (unmapped) are skipped. Slots must already be + resolved/allocated by the caller. + O_buffer: [num_pages, bytes_per_page] uint8 — one layer's slab from PackedPagePool + (swa page=128 / c4 page=64 / c128 page=2 all share this kernel). + + Per token: 448B fp8(e4m3) in 7x64 ue8m0 groups + 128B bf16 rope in the page data region; + 7 exponent bytes (e+127) + 1 zero pad at the page scale tail. Bit-compatible with + DeepseekV4MemoryManager._pack_mla_kv + PackedPagePool.write. + """ + seq_len = DestLoc.shape[0] + if seq_len == 0: + return + nope_dim, rope_dim, group_size = 448, 64, 64 + head_dim = nope_dim + rope_dim + scale_bytes = nope_dim // group_size + 1 + + KV = KV.reshape(-1, head_dim) + assert KV.shape[0] == seq_len, f"Expected KV shape[0]={seq_len}, got {KV.shape[0]}" + assert KV.dtype == torch.bfloat16, f"Expected bf16 KV (rope bytes are stored as-is), got {KV.dtype}" + bytes_per_page = O_buffer.shape[-1] + assert O_buffer.dtype == torch.uint8 and O_buffer.is_contiguous() + assert bytes_per_page % 2 == 0 + assert bytes_per_page >= page_size * (nope_dim + rope_dim * 2 + scale_bytes) + + flat = O_buffer.view(-1) + _fwd_kernel_destindex_copy_kv_flashmla_dsv4[(seq_len,)]( + KV, + DestLoc, + flat.view(torch.float8_e4m3fn), + flat.view(torch.bfloat16), + flat, + KV.stride(0), + KV.stride(1), + FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + SCALE_MIN=1e-4, + NOPE_DIM=nope_dim, + ROPE_DIM=rope_dim, + GROUP_SIZE=group_size, + NUM_GROUPS=nope_dim // group_size, + SCALE_BYTES=scale_bytes, + PAGE_SIZE=page_size, + BYTES_PER_PAGE=bytes_per_page, + num_warps=4, + num_stages=1, + ) + return diff --git a/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py 
b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py new file mode 100644 index 0000000000..a7a0a4be85 --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/gather_c4_indexer_k_dsv4.py @@ -0,0 +1,200 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _gather_c4_indexer_k_kernel( + req_idx_ptr, # [batch] int — req_manager slot per batch position + c4_len_ptr, # [batch] int — number of causal c4 entries per request (= seq_len // ratio) + req_to_token_ptr, + req_to_token_stride0, + full_to_c4_ptr, + SlabFp8_ptr, # c4 indexer pool, viewed as fp8 (flat) + SlabF32_ptr, # same pool, viewed as f32 (flat) + Kout_fp8_ptr, # [batch*c4_cap, HEAD_DIM] fp8 + Kout_scale_ptr, # [batch*c4_cap] f32 + Slots_out_ptr, # [batch*c4_cap] int32 (compact->c4-slot map; -1 for padding) + c4_cap, + RATIO: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BYTES_PER_PAGE: tl.constexpr, + SCALE_OFFSET: tl.constexpr, # page_size * head_dim (byte offset of the scale tail) +): + # entry index on grid-X (limit ~2^31), batch on grid-Y (<= running_max_req_size): c4_cap reaches + # 65536 at 256K context, which would blow the 65535 grid-Y cap if entries were the grid-Y axis. + e = tl.program_id(0) + r = tl.program_id(1) + # int64: out_pos*HEAD_DIM can exceed int32 at high batch + long context (read side is already int64). + out_pos = r.to(tl.int64) * c4_cap + e + c4_len = tl.load(c4_len_ptr + r) + if e >= c4_len: + # padding entry: mark slot invalid; K is never read (ke bounds the scorer range). + tl.store(Slots_out_ptr + out_pos, -1) + return + + # group-end token of compressed entry e lives at position e*RATIO + (RATIO-1). 
+ req = tl.load(req_idx_ptr + r).to(tl.int64) + end_tok = e * RATIO + (RATIO - 1) + full_slot = tl.load(req_to_token_ptr + req * req_to_token_stride0 + end_tok).to(tl.int64) + c4_slot = tl.load(full_to_c4_ptr + full_slot).to(tl.int64) + valid = c4_slot >= 0 + + # inline PackedPagePool byte addressing (matches destindex_copy_indexer_k_dsv4 / gather_indexer_k): + # fp8 K at page*bytes_per_page + tok*head_dim; fp32 scale at (page*bytes_per_page + scale_off)//4 + tok. + page = c4_slot // PAGE_SIZE + tok = c4_slot % PAGE_SIZE + data_base = page * BYTES_PER_PAGE + tok * HEAD_DIM + scale_base = (page * BYTES_PER_PAGE + SCALE_OFFSET) // 4 + tok + + offs_d = tl.arange(0, HEAD_DIM) + k_fp8 = tl.load(SlabFp8_ptr + data_base + offs_d, mask=valid, other=0.0) + k_scale = tl.load(SlabF32_ptr + scale_base, mask=valid, other=0.0) + tl.store(Kout_fp8_ptr + out_pos * HEAD_DIM + offs_d, k_fp8) + tl.store(Kout_scale_ptr + out_pos, k_scale) + tl.store(Slots_out_ptr + out_pos, tl.where(valid, c4_slot, -1).to(tl.int32)) + + +@torch.no_grad() +def gather_c4_indexer_k_ragged( + mem_manager, + layer_index: int, + b_req_idx: torch.Tensor, + c4_len: torch.Tensor, + c4_cap: int, + req_to_token_indexs: torch.Tensor, +): + """Gather each request's causal c4 indexer keys into a padded-per-request ragged buffer for the + deep_gemm fp8_mqa_logits scorer (mirrors deepseek3_2's extract_indexer_ks, but reads our + PackedPagePool by c4 slot instead of a token-indexed [N,1,132] buffer). + + For batch position r and compressed entry e in [0, c4_len[r]): + c4_slot = full_to_c4[req_to_token[b_req_idx[r], e*ratio + (ratio-1)]] + The raw fp8 key + f32 scale at that slot land at row r*c4_cap + e of the output (so query token t + of request r reads keys [r*c4_cap, r*c4_cap + (pos+1)//ratio) -- absolute ks/ke offsets the caller + builds). Returns (k_fp8 [batch*c4_cap, HEAD_DIM] fp8, k_scale [batch*c4_cap] f32, slots + [batch*c4_cap] int32 = compact-row -> c4 pool slot, -1 for padding). 
Fixed shapes -> cuda-graph + safe (c4_cap is pinned per graph bucket); the padding region is never read by the scorer. + """ + pool = mem_manager.c4_indexer_pool + head_dim = mem_manager.indexer_head_dim + buf = pool.get_layer_buffer(mem_manager.layer_to_c4_idx[layer_index]).view(-1) + slab_fp8 = buf.view(torch.float8_e4m3fn) + slab_f32 = buf.view(torch.float32) + batch = b_req_idx.shape[0] + n = batch * c4_cap + k_fp8 = torch.empty((n, head_dim), dtype=torch.float8_e4m3fn, device=buf.device) + k_scale = torch.empty((n,), dtype=torch.float32, device=buf.device) + slots = torch.empty((n,), dtype=torch.int32, device=buf.device) + _gather_c4_indexer_k_kernel[(c4_cap, batch)]( + b_req_idx, + c4_len, + req_to_token_indexs, + req_to_token_indexs.stride(0), + mem_manager.full_to_c4_indexs, + slab_fp8, + slab_f32, + k_fp8, + k_scale, + slots, + c4_cap, + RATIO=4, + HEAD_DIM=head_dim, + PAGE_SIZE=pool.page_size, + BYTES_PER_PAGE=pool.bytes_per_page, + SCALE_OFFSET=pool.scale_offset_in_page, + num_warps=1, + ) + return k_fp8, k_scale, slots + + +@triton.jit +def _build_c4_indexer_page_table_kernel( + req_idx_ptr, # [batch] int + c4_len_ptr, # [batch] int + req_to_token_ptr, + req_to_token_stride0, + full_to_c4_ptr, + page_table_ptr, # [batch, page_cap] int32 + valid_flag_ptr, # [1] int32, initialized to 1; set to 0 on layout mismatch + page_cap, + hold_req_id, + RATIO: tl.constexpr, + PAGE_SIZE: tl.constexpr, + VALIDATE: tl.constexpr, +): + p = tl.program_id(0) + r = tl.program_id(1) + req = tl.load(req_idx_ptr + r).to(tl.int64) + c4_len = tl.load(c4_len_ptr + r).to(tl.int64) + page_start = p * PAGE_SIZE + active = (req != hold_req_id) & (page_start < c4_len) + + full_pos0 = page_start * RATIO + (RATIO - 1) + full_slot0 = tl.load( + req_to_token_ptr + req * req_to_token_stride0 + full_pos0, + mask=active, + other=0, + ).to(tl.int64) + c4_slot0 = tl.load(full_to_c4_ptr + full_slot0, mask=active, other=0).to(tl.int64) + phys_page = c4_slot0 // PAGE_SIZE + 
tl.store(page_table_ptr + r * page_cap + p, tl.where(active, phys_page, 0).to(tl.int32)) + + if VALIDATE: + offs = tl.arange(0, PAGE_SIZE) + e = page_start + offs + valid = active & (e < c4_len) + full_pos = e * RATIO + (RATIO - 1) + full_slot = tl.load( + req_to_token_ptr + req * req_to_token_stride0 + full_pos, + mask=valid, + other=0, + ).to(tl.int64) + c4_slot = tl.load(full_to_c4_ptr + full_slot, mask=valid, other=-1).to(tl.int64) + expected = phys_page * PAGE_SIZE + offs + ok = tl.where(valid, (c4_slot == expected) & (c4_slot >= 0), True) + if tl.min(ok.to(tl.int32), axis=0) == 0: + tl.store(valid_flag_ptr, 0) + + +@torch.no_grad() +def build_c4_indexer_page_table( + mem_manager, + b_req_idx: torch.Tensor, + c4_len: torch.Tensor, + c4_cap: int, + req_to_token_indexs: torch.Tensor, + hold_req_id: int, + validate: bool = False, +): + """Build the logical-c4-page -> physical-c4-page table expected by DeepGEMM paged logits. + + This is safe only when each logical c4 page maps to a physical page with matching offsets: + c4_slot(entry p*64 + o) == page_table[p] * 64 + o + The optional validation flag checks that invariant and lets the caller fall back to the + gather path while we keep the current token-slot allocator. 
+ """ + pool = mem_manager.c4_indexer_pool + page_size = pool.page_size + assert c4_cap % page_size == 0 + batch = b_req_idx.shape[0] + page_cap = c4_cap // page_size + page_table = torch.empty((batch, page_cap), dtype=torch.int32, device=b_req_idx.device) + valid_flag = torch.ones((1,), dtype=torch.int32, device=b_req_idx.device) + _build_c4_indexer_page_table_kernel[(page_cap, batch)]( + b_req_idx, + c4_len, + req_to_token_indexs, + req_to_token_indexs.stride(0), + mem_manager.full_to_c4_indexs, + page_table, + valid_flag, + page_cap, + int(hold_req_id), + RATIO=4, + PAGE_SIZE=page_size, + VALIDATE=validate, + num_warps=1, + ) + return page_table, valid_flag diff --git a/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py b/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py new file mode 100644 index 0000000000..47d87d4932 --- /dev/null +++ b/lightllm/models/deepseek_v4/triton_kernel/quant_convert.py @@ -0,0 +1,16 @@ +import torch + + +def e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: + """float8_e8m0fnu encodes 2**(byte-127); torch decodes it correctly on .to(float32).""" + return scale.to(torch.float32) + + +def dequant_fp8_block_to_bf16(weight_e4m3: torch.Tensor, scale_e8m0: torch.Tensor, block_size: int = 128): + """De-quantize an FP8 e4m3 weight [out, in] with block-[bs,bs] ue8m0 scale to bf16.""" + from lightllm.models.deepseek2.triton_kernel.weight_dequant import weight_dequant + + w = weight_e4m3.cuda().contiguous() + s = e8m0_to_fp32(scale_e8m0).cuda().contiguous() + # weight_dequant runs with torch default dtype for the output; force bf16 result. 
+ return weight_dequant(w, s, block_size) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 1bdf8f3427..e745492173 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -161,6 +161,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "qwen", "deepseekv31", "deepseekv32", + "deepseekv4", "glm47", "kimi_k2", "qwen3_coder", @@ -620,8 +621,11 @@ def make_argument_parser() -> argparse.ArgumentParser: type=str, default=None, choices=["fp8", "fp4"], - help="""Expert quantization dtype for EP MoE. Supported values are - fp8 and fp4. Note that fp4 is only supported on SM100 GPUs.""", + help="""Requested dtype for MoE expert weights, fp8 or fp4. Resolves the fused_moe + quant method: fp8 -> deepgemm-fp8w8a8-b128; fp4 -> deepgemm-fp4fp8-b32 (online + quantization) on SM100 GPUs, or marlin-mxfp4w4a16-b32 (Marlin W4A16, TP only) on other GPUs. + Defaults to `expert_dtype` in config.json if present. Per-layer override: + --quant_cfg mix_bits with name `fused_moe`.""", ) parser.add_argument( "--vit_quant_type", diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 3cf431d650..c13f562af9 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -10,7 +10,11 @@ from .metrics.manager import start_metric_manager from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name +from lightllm.utils.envs_utils import ( + set_env_start_args, + set_unique_server_name, + get_unique_server_name, +) from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive from .detokenization.manager import start_detokenization_process from .router.manager import start_router_process @@ -23,6 +27,8 @@ has_vision_module, is_linear_att_mixed_model, auto_set_max_req_total_len, + get_model_type, + get_config_json, ) from lightllm.utils.dist_check_utils import 
auto_configure_allreduce_flags_from_args @@ -108,6 +114,19 @@ def normal_or_p_d_start(args): else: args.enable_multimodal = True + model_type = get_model_type(args.model_dir) + if model_type == "deepseek_v4": + if args.run_mode != "normal": + raise NotImplementedError("DeepSeek-V4 currently supports only run_mode=normal in LightLLM.") + if args.enable_cpu_cache or args.enable_disk_cache: + raise NotImplementedError("DeepSeek-V4 CPU/disk KV cache is not supported yet.") + if args.mtp_mode is not None or args.mtp_draft_model_dir is not None or args.mtp_step != 0: + raise NotImplementedError("DeepSeek-V4 MTP/speculative decoding is not supported yet.") + if args.enable_ep_moe: + raise NotImplementedError("DeepSeek-V4 EP MoE is not supported yet; use TP for now.") + if "prompt_cache_kv_buffer" in get_config_json(args.model_dir): + raise NotImplementedError("DeepSeek-V4 prompt_cache_kv_buffer is not supported yet.") + if args.enable_cpu_cache: # 生成一个用于创建cpu kv cache的共享内存id。 args.cpu_kv_cache_shm_id = uuid.uuid1().int % 123456789 @@ -333,7 +352,14 @@ def normal_or_p_d_start(args): from lightllm.utils.config_utils import get_dtype args.data_type = get_dtype(args.model_dir) - assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] + assert args.data_type in [ + "fp16", + "float16", + "bf16", + "bfloat16", + "fp32", + "float32", + ] already_uesd_ports = [args.port] if args.nccl_port is not None: @@ -425,7 +451,6 @@ def normal_or_p_d_start(args): ) if not args.disable_vision: - if not args.visual_use_proxy_mode: from .visualserver.manager import start_visual_process @@ -609,7 +634,14 @@ def visual_only_start(args): from lightllm.utils.config_utils import get_dtype args.data_type = get_dtype(args.model_dir) - assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] + assert args.data_type in [ + "fp16", + "float16", + "bf16", + "bfloat16", + "fp32", + "float32", + ] logger.info(f"alloced ports: {can_use_ports}") diff 
--git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 54d22a0d0d..0565a8f0cf 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -53,6 +53,13 @@ def tokenizer_supports_force_thinking() -> bool: assert tokenizer is not None + # Tokenizers that encode prompts in Python (e.g. DeepSeek-V4) have no Jinja + # chat_template string to inspect, so advertise thinking support via an + # explicit attribute instead. + if getattr(tokenizer, "supports_thinking", False): + logger.info("tokenizer_supports_force_thinking : True (explicit attribute)") + return True + try: ans = "thinking" in tokenizer.chat_template or "enable_thinking" in tokenizer.chat_template logger.debug(f"chat_template: {tokenizer.chat_template}") diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index dfcb2f8d9e..63c9f6ac8f 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -40,6 +40,7 @@ "[TOOL_CALLS]", "<|tool▁calls▁begin|>", "<|DSML|function_calls>", + "<|DSML|tool_calls>", ] @@ -1480,11 +1481,14 @@ class DeepSeekV32Detector(BaseFormatDetector): Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3.2 """ - def __init__(self): + def __init__(self, block_name: str = "function_calls"): super().__init__() self.dsml_token = "|DSML|" - self.bot_token = f"<{self.dsml_token}function_calls>" - self.eot_token = f"" + # DeepSeek V3.2 wraps tool calls in a `function_calls` block; V4 uses + # `tool_calls`. Only the outer block name differs — the invoke/parameter + # grammar is identical — so subclasses just override block_name. 
+ self.bot_token = f"<{self.dsml_token}{block_name}>" + self.eot_token = f"" self.invoke_start_prefix = f"<{self.dsml_token}invoke" self.invoke_end_token = f"" self.param_end_token = f"" @@ -1962,6 +1966,32 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip() +class DeepSeekV4Detector(DeepSeekV32Detector): + """ + Detector for DeepSeek V4 model function call format using DSML. + + Identical grammar to V3.2 (``<|DSML|invoke name="...">`` blocks with + ``<|DSML|parameter name="k" string="true|false">v`` + tags), except the outer block is named ``tool_calls`` instead of + ``function_calls`` — matching the model's own encoding (encoding_dsv4.py: + ``tool_calls_block_name = "tool_calls"``) and system prompt. + + Format Structure: + ``` + <|DSML|tool_calls> + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">Hangzhou + + + ``` + + Reference: https://huggingface.co/deepseek-ai/DeepSeek-V4 + """ + + def __init__(self): + super().__init__(block_name="tool_calls") + + class FunctionCallParser: """ Parser for function/tool calls in model outputs. 
@@ -1975,6 +2005,7 @@ class FunctionCallParser: "deepseekv3": DeepSeekV3Detector, "deepseekv31": DeepSeekV31Detector, "deepseekv32": DeepSeekV32Detector, + "deepseekv4": DeepSeekV4Detector, "glm47": Glm47Detector, "kimi_k2": KimiK2Detector, "llama3": Llama32Detector, diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 21e26c5854..ff09c018be 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -2,7 +2,7 @@ import torch import numpy as np import collections -from typing import Tuple, Dict, Set, List, Optional, Union +from typing import Any, Tuple, Dict, Set, List, Optional, Union from sortedcontainers import SortedSet from .shared_arr import SharedArray @@ -25,6 +25,7 @@ def __init__(self): self.parent: TreeNode = None self.token_id_key: torch.Tensor = None self.token_mem_index_value: torch.Tensor = None # 用于记录存储的 token_index 为每个元素在 token mem 中的index位置 + self.token_extra_value: Any = None self.ref_counter = 0 self.time_id = time_gen.generate_time_id() # 用于标识时间周期 @@ -34,14 +35,17 @@ def __init__(self): def get_compare_key(self): return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) - def split_node(self, prefix_len): + def split_node(self, prefix_len, child_key_fn=None, extra_value_ops=None): split_parent_node = TreeNode() split_parent_node.parent = self.parent - split_parent_node.parent.children[self.token_id_key[0].item()] = split_parent_node + split_parent_node.parent.children[child_key_fn(self.token_id_key)] = split_parent_node split_parent_node.token_id_key = self.token_id_key[0:prefix_len] split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len] + if self.token_extra_value is not None and extra_value_ops is not None: + split_parent_node.token_extra_value = extra_value_ops.slice(self.token_extra_value, 0, prefix_len) + self.token_extra_value = 
extra_value_ops.slice(self.token_extra_value, prefix_len, len(self.token_id_key)) split_parent_node.children = {} - split_parent_node.children[self.token_id_key[prefix_len].item()] = self + split_parent_node.children[child_key_fn(self.token_id_key[prefix_len:])] = self split_parent_node.ref_counter = self.ref_counter new_len = len(split_parent_node.token_mem_index_value) @@ -56,11 +60,12 @@ def split_node(self, prefix_len): self.node_prefix_total_len = self.parent.node_prefix_total_len + new_len return split_parent_node - def add_and_return_new_child(self, token_id_key, token_mem_index_value): + def add_and_return_new_child(self, token_id_key, token_mem_index_value, token_extra_value=None, child_key=None): child = TreeNode() child.token_id_key = token_id_key child.token_mem_index_value = token_mem_index_value - first_token_key = child.token_id_key[0].item() + child.token_extra_value = token_extra_value + first_token_key = child.token_id_key[0].item() if child_key is None else child_key assert first_token_key not in self.children.keys() self.children[first_token_key] = child child.parent = self @@ -71,9 +76,17 @@ def add_and_return_new_child(self, token_id_key, token_mem_index_value): return child def remove_child(self, child_node: "TreeNode"): - del self.children[child_node.token_id_key[0].item()] - child_node.parent = None - return + child_key = child_node.token_id_key[0].item() + if child_key in self.children: + del self.children[child_key] + child_node.parent = None + return + for key, value in list(self.children.items()): + if value is child_node: + del self.children[key] + child_node.parent = None + return + raise KeyError("child node not found") def update_time(self): self.time_id = time_gen.generate_time_id() @@ -103,12 +116,22 @@ class RadixCache: unique_name 主要用于解决单机,多实列部署时的shm冲突 """ - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + def __init__( + self, + unique_name, + total_token_num, + rank_in_node, + 
mem_manager=None, + page_size: int = 1, + extra_value_ops=None, + ): from lightllm.common.kv_cache_mem_manager import MemoryManager self.mem_manager: MemoryManager = mem_manager self._key_dtype = torch.int64 self._value_dtype = torch.int64 + self.page_size = max(1, int(page_size)) + self.extra_value_ops = extra_value_ops self.root_node = TreeNode() self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) @@ -125,30 +148,66 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None) ) self.tree_total_tokens_num.arr[0] = 0 - def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: + def _align_len(self, length: int) -> int: + if self.page_size <= 1: + return int(length) + return int(length) // self.page_size * self.page_size + + def align_len(self, length: int) -> int: + return self._align_len(length) + + def _child_key(self, key: torch.Tensor): + if self.page_size <= 1: + return key[0].item() + return tuple(key[: self.page_size].tolist()) + + def _match_len(self, key: torch.Tensor, node_key: torch.Tensor) -> int: + prefix_len = match(key, node_key) + return self._align_len(prefix_len) + + def _slice_extra(self, extra_value, start: int, end: int): + if extra_value is None: + return None + assert self.extra_value_ops is not None + return self.extra_value_ops.slice(extra_value, start, end) + + def _concat_extra(self, values: list): + values = [v for v in values if v is not None] + if len(values) == 0: + return None + assert self.extra_value_ops is not None + return self.extra_value_ops.concat(values) + + def insert(self, key, value=None, extra_value=None) -> Tuple[int, Optional[TreeNode]]: if value is None: value = key + align_len = self._align_len(len(key)) + key = key[:align_len] + value = value[:align_len] + if extra_value is not None: + extra_value = self._slice_extra(extra_value, 0, align_len) + assert len(key) == len(value) # and len(key) >= 1 if len(key) == 0: return 0, None - return 
self._insert_helper(self.root_node, key, value) + return self._insert_helper(self.root_node, key, value, extra_value) - def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]: + def _insert_helper(self, node: TreeNode, key, value, extra_value) -> Tuple[int, Optional[TreeNode]]: handle_stack = collections.deque() update_list = collections.deque() - handle_stack.append((node, key, value)) + handle_stack.append((node, key, value, extra_value)) ans_prefix_len = 0 ans_node = None while len(handle_stack) != 0: - node, key, value = handle_stack.popleft() - ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value) - if len(ans_tuple) == 4: - (_prefix_len, new_node, new_key, new_value) = ans_tuple + node, key, value, extra_value = handle_stack.popleft() + ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value, extra_value=extra_value) + if len(ans_tuple) == 5: + (_prefix_len, new_node, new_key, new_value, new_extra_value) = ans_tuple ans_prefix_len += _prefix_len - handle_stack.append((new_node, new_key, new_value)) + handle_stack.append((new_node, new_key, new_value, new_extra_value)) else: _prefix_len, ans_node = ans_tuple ans_prefix_len += _prefix_len @@ -166,15 +225,15 @@ def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[Tree return ans_prefix_len, ans_node def _insert_helper_no_recursion( - self, node: TreeNode, key: torch.Tensor, value: torch.Tensor - ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor]]: + self, node: TreeNode, key: torch.Tensor, value: torch.Tensor, extra_value=None + ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor, Any]]: if node.is_leaf(): self.evict_tree_set.discard(node) - first_key_id = key[0].item() + first_key_id = self._child_key(key) if first_key_id in node.children.keys(): child: TreeNode = node.children[first_key_id] - prefix_len = match(key, 
child.token_id_key) + prefix_len = self._match_len(key, child.token_id_key) if prefix_len == len(key): if prefix_len == len(child.token_id_key): if child.is_leaf(): @@ -184,10 +243,14 @@ def _insert_helper_no_recursion( self.evict_tree_set.add(child) return prefix_len, child elif prefix_len < len(child.token_id_key): + if prefix_len == 0: + return 0, node if child.is_leaf(): self.evict_tree_set.discard(child) - split_parent_node = child.split_node(prefix_len) + split_parent_node = child.split_node( + prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops + ) if split_parent_node.is_leaf(): self.evict_tree_set.add(split_parent_node) @@ -199,13 +262,23 @@ def _insert_helper_no_recursion( assert False, "can not run to here" elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if prefix_len == 0: + return 0, node if child.is_leaf(): self.evict_tree_set.discard(child) + new_extra_value = self._slice_extra(extra_value, prefix_len, len(key)) key = key[prefix_len:] value = value[prefix_len:] - split_parent_node = child.split_node(prefix_len) - new_node = split_parent_node.add_and_return_new_child(key, value) + split_parent_node = child.split_node( + prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops + ) + new_node = split_parent_node.add_and_return_new_child( + key, + value, + token_extra_value=new_extra_value, + child_key=self._child_key(key), + ) # update total token num self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) @@ -218,12 +291,23 @@ def _insert_helper_no_recursion( self.evict_tree_set.add(child) return prefix_len, new_node elif prefix_len < len(key) and prefix_len == len(child.token_id_key): - return (prefix_len, child, key[prefix_len:], value[prefix_len:]) + return ( + prefix_len, + child, + key[prefix_len:], + value[prefix_len:], + self._slice_extra(extra_value, prefix_len, len(key)), + ) else: assert False, "can not run to here" else: - new_node = 
node.add_and_return_new_child(key, value) + new_node = node.add_and_return_new_child( + key, + value, + token_extra_value=extra_value, + child_key=first_key_id, + ) # update total token num self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) if new_node.is_leaf(): @@ -231,7 +315,12 @@ def _insert_helper_no_recursion( return 0, new_node def match_prefix(self, key, update_refs=False): - assert len(key) != 0 + key = key[: self._align_len(len(key))] + if len(key) == 0: + return None, 0, None + key = self._trim_key_by_extra_value_validity(key) + if len(key) == 0: + return None, 0, None ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) if tree_node != self.root_node: @@ -245,6 +334,30 @@ def match_prefix(self, key, update_refs=False): self.dec_node_ref_counter(self.root_node) return None, 0, None + def _trim_key_by_extra_value_validity(self, key: torch.Tensor) -> torch.Tensor: + """命中有效性裁剪(extra_value_ops 提供 valid_match_length 时启用,如 DeepSeek-V4 的 + swa 按页 bitmap): 先做一次只读探测遍历得到自然命中与沿路 extra_value,按其有效边界截短 + key,随后的正常遍历(加引用/分裂)只走截短后的前缀 —— 引用计数与最终返回值在同一次遍历 + 内保持一致,不存在事后裁剪导致的漏减/多减。 + + 探测遍历可能分裂部分命中的节点(与正常遍历同语义,树不变式不受影响)。裁剪只会缩短命中, + 没有任何失败路径。""" + if self.extra_value_ops is None: + return key + valid_match_length = getattr(self.extra_value_ops, "valid_match_length", None) + if valid_match_length is None: + return key + probe_values = [] + probe_node = self._match_prefix_helper(self.root_node, key, probe_values, update_refs=False) + if probe_node == self.root_node or len(probe_values) == 0: + return key + natural_len = sum(len(v) for v in probe_values) + extra_value = self.get_extra_value_by_node(probe_node) + valid_len = int(valid_match_length(extra_value, natural_len)) + if valid_len < natural_len: + return key[:valid_len] + return key + def _match_prefix_helper( self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False ) -> TreeNode: @@ -290,20 +403,24 @@ def 
_match_prefix_helper_no_recursion( if len(key) == 0: return node - first_key_id = key[0].item() + first_key_id = self._child_key(key) if first_key_id not in node.children.keys(): return node else: child = node.children[first_key_id] - prefix_len = match(key, child.token_id_key) + prefix_len = self._match_len(key, child.token_id_key) if prefix_len == len(child.token_id_key): ans_value_list.append(child.token_mem_index_value) return (child, key[prefix_len:]) elif prefix_len < len(child.token_id_key): + if prefix_len == 0: + return node if child.is_leaf(): self.evict_tree_set.discard(child) - split_parent_node = child.split_node(prefix_len) + split_parent_node = child.split_node( + prefix_len, child_key_fn=self._child_key, extra_value_ops=self.extra_value_ops + ) ans_value_list.append(split_parent_node.token_mem_index_value) if update_refs: @@ -334,6 +451,8 @@ def evict(self, need_remove_tokens, evict_callback): ), "error evict tree node state" num_evicted += len(node.token_mem_index_value) evict_callback(node.token_mem_index_value) + if self.extra_value_ops is not None and node.token_extra_value is not None: + self.extra_value_ops.free(node.token_extra_value) # update total token num self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) parent_node: TreeNode = node.parent @@ -369,11 +488,12 @@ def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: child_node.token_mem_index_value = torch.cat( [parent_node.token_mem_index_value, child_node.token_mem_index_value] ) + child_node.token_extra_value = self._concat_extra([parent_node.token_extra_value, child_node.token_extra_value]) child_node.node_value_len = len(child_node.token_mem_index_value) child_node.time_id = max(parent_node.time_id, child_node.time_id) grandparent_node = parent_node.parent - key_in_grandparent = parent_node.token_id_key[0].item() + key_in_grandparent = self._child_key(parent_node.token_id_key) grandparent_node.children[key_in_grandparent] = child_node child_node.parent = 
grandparent_node @@ -469,6 +589,19 @@ def get_mem_index_value_by_node(self, node: TreeNode) -> Optional[torch.Tensor]: ans_list.reverse() return torch.concat(ans_list, dim=0) + def get_extra_value_by_node(self, node: TreeNode): + if node is None or self.extra_value_ops is None: + return None + + ans_list = [] + while node is not None: + if node.token_extra_value is not None: + ans_list.append(node.token_extra_value) + node = node.parent + + ans_list.reverse() + return self._concat_extra(ans_list) + def get_refed_tokens_num(self): return self.refed_tokens_num.arr[0] @@ -489,6 +622,36 @@ def _print_helper(self, node: TreeNode, indent): self._print_helper(child, indent=indent + 2) return + def reclaim_unreferenced_swa_pages(self, need_pages: int) -> None: + """DeepSeek-V4 swa 压力阀: 页 allocator 触底时,沿 LRU 序(evict_tree_set)只对 + ref_count==0 的节点链回收其 swa 页(full 槽与压缩条目保留——节点仍可服务更长前缀的 + 中段命中),并清载荷 bitmap 位使后续命中按缩短语义裁剪。所有权判定直接复用 radix + 引用计数: 节点被任何活跃请求借用即 ref>0,其页不可达。不够时由 allocator 的 assert + 兜底(最后防线)。""" + if self.mem_manager is None or self.extra_value_ops is None: + return + invalidate = getattr(self.extra_value_ops, "invalidate_swa_pages", None) + if invalidate is None: + return + allocator = self.mem_manager.swa_page_allocator + target = allocator.can_use_mem_size + int(need_pages) + for leaf in list(self.evict_tree_set): + if allocator.can_use_mem_size >= target: + break + node = leaf + # 叶子起步沿父链回收: 引用计数向上累加(add_node_ref_counter 走父链), + # 因此 ref==0 的祖先必无任何活跃借用方。重复访问无害(evict_swa/-1 跳过)。 + # 每回收一个节点就复查目标,避免多回收(无谓削减命中可用性)。 + while node is not None and node is not self.root_node and node.ref_counter == 0: + if len(node.token_mem_index_value) > 0: + self.mem_manager.evict_swa(node.token_mem_index_value) + if node.token_extra_value is not None: + invalidate(node.token_extra_value) + if allocator.can_use_mem_size >= target: + return + node = node.parent + return + def free_radix_cache_to_get_enough_token(self, need_token_num): assert self.mem_manager is not None if 
need_token_num > self.mem_manager.allocator.can_use_mem_size: diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 5c2d0d45fb..2bf2314185 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -8,7 +8,7 @@ from sortedcontainers import SortedDict from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Callable, Any, Union -from lightllm.common.req_manager import ReqManager, ReqManagerForMamba +from lightllm.common.req_manager import DeepseekV4ReqManager, ReqManager, ReqManagerForMamba from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode @@ -50,7 +50,9 @@ def register( vocab_size: int, ): self.args = get_env_start_args() - from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend + from lightllm.server.router.model_infer.mode_backend.base_backend import ( + ModeBackend, + ) self.backend: ModeBackend = backend self.req_manager = req_manager @@ -122,11 +124,20 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs def free_a_req_mem(self, free_token_index: List, req: "InferReq"): + is_dsv4_req_manager = hasattr(self.req_manager, "build_prompt_cache_payload") if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) + if is_dsv4_req_manager: + # 槽位随 full 槽经 mem_manager.free 级联回收。pause 路径不释放 req_idx, + # 必须在此复位出窗水位线 + 清 c128 在途状态(恢复命中走 extend,不会再有 + # restore/zero 时机;c4 状态随 swa 页生灭,无需处理)。 + self.req_manager.init_compress_state(req.req_idx) else: if not self.is_linear_att_mixed_model: - self._full_att_free_req(free_token_index=free_token_index, req=req) + if is_dsv4_req_manager: + 
self._dsv4_full_att_free_req(free_token_index=free_token_index, req=req) + else: + self._full_att_free_req(free_token_index=free_token_index, req=req) else: self._linear_att_free_req(free_token_index=free_token_index, req=req) assert len(req.linear_att_len_to_big_page_id) == 0 @@ -134,6 +145,11 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): req.shm_req.shm_cur_kv_len = req.cur_kv_len return + def _append_free_token_index(self, free_token_index: List, tensor: torch.Tensor): + if tensor.numel() > 0: + free_token_index.append(tensor) + return + def _full_att_free_req(self, free_token_index: List, req: "InferReq"): input_token_ids = req.get_input_token_ids() key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") @@ -149,6 +165,68 @@ def _full_att_free_req(self, free_token_index: List, req: "InferReq"): req.shared_kv_node = None return + def _dsv4_full_att_free_req(self, free_token_index: List, req: "InferReq"): + if req.cur_kv_len == 0: + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0:0]) + return + + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len + inserted_len = old_prefix_len + duplicate_prefix_len = old_prefix_len + + # 载荷只剩按页 bitmap(compressor 状态随 swa 页生灭/边界自然归零,不进载荷), + # 任意 128 对齐前缀皆可插入——含生成段(floor(cur_kv_len) 边界,回收保留尾页保证其驻留)。 + cache_len = self.radix_cache.align_len(req.cur_kv_len) + self.req_manager: DeepseekV4ReqManager + if cache_len > old_prefix_len: + payload = self.req_manager.build_prompt_cache_payload(req.req_idx, cache_len) + value = self.req_manager.req_to_token_indexs[req.req_idx][:cache_len].detach().cpu() + # 按页有效性 bitmap 用插入时刻的映射写定(此后只会被阀清 0,不会复活)。水位线 + # 纯 CPU 推导,避免 router 关键路径上的 GPU gather 同步(每插入一次要等全部在途 + # decode kernel)。插入门: 截掉结尾的 invalid 页 —— 它们生来不可命中,还会永久 + # 挡住后续更长前缀复用同一段 token(全量重插会因前缀已存在而保留旧 bitmap)。 + page_size = self.req_manager.get_prompt_cache_page_size() + bitmap = 
self.req_manager.swa_page_valid_from_watermark(req.req_idx, cache_len) + n_pages = int(bitmap.numel()) + while n_pages > 0 and not bool(bitmap[n_pages - 1]): + n_pages -= 1 + gated_len = n_pages * page_size + if gated_len < cache_len: + logger.info( + f"DeepSeek-V4 prompt cache insert gate: trailing swa pages already evicted, " + f"shrink insert {cache_len} -> {gated_len}" + ) + cache_len = gated_len + payload.cache_len = cache_len + payload.swa_page_valid = bitmap[:n_pages].clone() + + if cache_len > old_prefix_len: + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0:cache_len], dtype=torch.int64, device="cpu") + duplicate_prefix_len, cache_node = self.radix_cache.insert(key, value[:cache_len], extra_value=payload) + inserted_len = 0 if cache_node is None else cache_node.node_prefix_total_len + if inserted_len != cache_len: + inserted_len = old_prefix_len + duplicate_prefix_len = old_prefix_len + + dense_row = self.req_manager.req_to_token_indexs[req.req_idx] + self._append_free_token_index(free_token_index, dense_row[old_prefix_len:duplicate_prefix_len]) + self._append_free_token_index(free_token_index, dense_row[inserted_len : req.cur_kv_len]) + if len(free_token_index) == 0: + free_token_index.append(dense_row[0:0]) + # 释放的 full 槽经 mem_manager.free 级联回收 swa/c4/c128(映射键控,无需收集槽位)。 + + # pause 路径不会走 req_manager.free/init: 复位出窗水位线(残留水位线会破坏下一次 + # prefill 的共享前缀保护)并清 c128 在途状态(恢复命中走 extend 续算,若残留暂停前的 + # 半窗聚合会算错;c128 状态在 128 对齐命中边界本应为零)。 + self.req_manager.init_compress_state(req.req_idx) + + if req.shared_kv_node is not None: + assert req.shared_kv_node.node_prefix_total_len <= max(inserted_len, old_prefix_len) + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + req.shared_kv_node = None + return + def _linear_att_free_req(self, free_token_index: List, req: "InferReq"): assert g_infer_context.is_linear_att_mixed_model is True args = get_env_start_args() @@ -325,7 +403,12 @@ def pause_reqs(self, pause_reqs: List["InferReq"], 
is_master_in_dp: bool): self.req_manager.free_token(free_token_index) return self - def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int): + def recover_paused_reqs( + self, + paused_reqs: List["InferReq"], + is_master_in_dp: bool, + can_alloc_token_num: int, + ): if paused_reqs: for req in paused_reqs: @@ -382,7 +465,9 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L ) big_page_buffer_ids = big_page_buffer_ids.cuda(non_blocking=True) - from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer + from lightllm.common.basemodel.triton_kernel.linear_att_copy import ( + copy_linear_att_state_to_kv_buffer, + ) copy_linear_att_state_to_kv_buffer( b_req_idx=b_req_idx, @@ -412,9 +497,10 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, src_buffer_idx, ...] dst_buffer_idx = req.tail_linear_att_small_page_buffer_id - dst_conv_state, dst_ssm_state = self.radix_cache.linear_att_small_page_buffers.get_state_cache( - buffer_idx=dst_buffer_idx - ) + ( + dst_conv_state, + dst_ssm_state, + ) = self.radix_cache.linear_att_small_page_buffers.get_state_cache(buffer_idx=dst_buffer_idx) # TODO 对于非连续对象调用 copy_ 效率并不高 dst_conv_state.copy_(gpu_conv_state, non_blocking=True) dst_ssm_state.copy_(gpu_ssm_state, non_blocking=True) @@ -591,6 +677,8 @@ def _init_all_state(self): self.cur_output_len = 0 g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self) + if hasattr(g_infer_context.req_manager, "init_compress_state"): + g_infer_context.req_manager.init_compress_state(req_idx=self.req_idx) self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list() # token healing mode 才被使用的管理对象 @@ -626,6 +714,9 @@ def _match_radix_cache(self): ready_cache_len = share_node.node_prefix_total_len # 从 cpu 到 gpu 是流内阻塞操作 
g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor + # DeepSeek-V4 命中无需任何恢复: 槽位由 full_to_* 映射键控(radix 持有 full 槽即有效, + # 命中长度已在 match_prefix 内按 bitmap 裁剪),c4 compressor 状态随 swa 页常驻 + # (零拷贝续算),c128 状态在 128 对齐边界自然归零(init_compress_state 已清)。 self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 @@ -639,7 +730,10 @@ def _linear_match_radix_cache(self): enable_prompt_cache = (not self.sampling_param.disable_prompt_cache) and g_infer_context.radix_cache is not None linear_hash_list = self.shm_req.linear_att_token_hash_list.get_all() linear_att_hash_page_size = self.args.linear_att_hash_page_size - match_tokens = min(len(linear_hash_list) * linear_att_hash_page_size, self.get_cur_total_len() - 1) + match_tokens = min( + len(linear_hash_list) * linear_att_hash_page_size, + self.get_cur_total_len() - 1, + ) match_tokens = max(0, match_tokens) match_tokens = (match_tokens // linear_att_hash_page_size) * linear_att_hash_page_size match_block_num = match_tokens // linear_att_hash_page_size @@ -705,7 +799,8 @@ def _linear_match_radix_cache(self): # 将 对应的 value_tensors 中的 kv 数据 拷贝到 tail_mems 中对应的数据去 radix_cache.mem_manager.operator.copy_mem_to_mem( - value_tensor[cur_big_page_tokens:shared_kv_len], tail_mems + value_tensor[cur_big_page_tokens:shared_kv_len], + tail_mems, ) self.shared_kv_node = share_node # 只是为了保证 copy_small_page_buffer_to_linear_att_state 正确调用 @@ -736,7 +831,8 @@ def _linear_match_radix_cache(self): assert self.tail_linear_att_small_page_buffer_id is None # 恢复linear att 状态 g_infer_context.req_manager.copy_big_page_buffer_to_linear_att_state( - big_page_buffer_idx=share_node.big_page_buffer_idx, req=self + big_page_buffer_idx=share_node.big_page_buffer_idx, + req=self, ) self.shm_req.shm_cur_kv_len = self.cur_kv_len @@ -792,6 +888,7 @@ def get_input_token_ids(self): def get_chuncked_input_token_ids(self): chunked_start = 
self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.args.chunked_prefill_size) + chunked_end = self._align_chuncked_end_for_prompt_cache(chunked_start, chunked_end) return self.shm_req.shm_prompt_ids.arr[0:chunked_end] def get_chuncked_input_token_ids_for_linear_att(self): @@ -812,6 +909,23 @@ def get_chuncked_input_token_ids_for_linear_att(self): def get_chuncked_input_token_len(self): chunked_start = self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.args.chunked_prefill_size) + return self._align_chuncked_end_for_prompt_cache(chunked_start, chunked_end) + + def _align_chuncked_end_for_prompt_cache(self, chunked_start: int, chunked_end: int): + radix_cache = g_infer_context.radix_cache + page_size = getattr(radix_cache, "page_size", 1) if radix_cache is not None else 1 + if page_size <= 1 or self.sampling_param.disable_prompt_cache: + return chunked_end + prompt_end = int(self.shm_req.input_len) + chunked_start = int(chunked_start) + chunked_end = int(chunked_end) + if chunked_end >= prompt_end: + return chunked_end + + assert self.args.chunked_prefill_size % page_size == 0, ( + f"chunked_prefill_size={self.args.chunked_prefill_size} must be divisible by " + f"prompt-cache page_size={page_size}" + ) return chunked_end def get_chuncked_input_token_len_for_linear_att(self): diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index a65dfb1bbb..1f19b9462a 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -153,6 +153,8 @@ def init_model(self, kvargs): self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) self.is_linear_att_mixed_model = isinstance(self.model.req_manager, ReqManagerForMamba) + if hasattr(self.model.req_manager, "build_prompt_cache_payload"): + self.support_overlap = False if 
self.is_linear_att_mixed_model: self.linear_att_cache_manager = LinearAttCacheManager( @@ -164,6 +166,7 @@ def init_model(self, kvargs): if not self.use_dynamic_prompt_cache: self.radix_cache = None + setattr(self.args, "dynamic_prompt_cache_page_size", 1) else: if self.is_linear_att_mixed_model: self.radix_cache = LinearAttPagedRadixCache( @@ -175,13 +178,25 @@ def init_model(self, kvargs): kv_cache_mem_manager=self.model.mem_manager, linear_att_small_page_buffers=self.linear_att_cache_manager, ) + setattr(self.args, "dynamic_prompt_cache_page_size", 1) else: + radix_page_size = 1 + radix_extra_value_ops = None + if hasattr(self.model.req_manager, "get_prompt_cache_value_ops"): + radix_page_size = self.model.req_manager.get_prompt_cache_page_size() + radix_extra_value_ops = self.model.req_manager.get_prompt_cache_value_ops() + setattr(self.args, "dynamic_prompt_cache_page_size", radix_page_size) self.radix_cache = RadixCache( unique_name=get_unique_server_name(), total_token_num=self.model.mem_manager.size, rank_in_node=self.rank_in_node, mem_manager=self.model.mem_manager, + page_size=radix_page_size, + extra_value_ops=radix_extra_value_ops, ) + if radix_extra_value_ops is not None and hasattr(self.model.mem_manager, "set_swa_pressure_valve"): + # swa 页 allocator 触底时让 radix 对 ref==0 节点回收 swa 页(DeepSeek-V4)。 + self.model.mem_manager.set_swa_pressure_valve(self.radix_cache.reclaim_unreferenced_swa_pages) if "prompt_cache_kv_buffer" in model_cfg: assert self.use_dynamic_prompt_cache diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index e1a4e421d1..6aea7cd672 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -90,6 +90,11 @@ def get_tokenizer( ) logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.") return DeepSeekV32Tokenizer(hf_tokenizer) + if model_type == "deepseek_v4": + from ..models.deepseek_v4.model import DeepSeekV4Tokenizer + + logger.info("Using DeepSeek-V4 tokenizer 
mode with Python-based chat template encoding.") + return DeepSeekV4Tokenizer(tokenizer, tokenizer_name) if model_cfg["architectures"][0] == "TarsierForConditionalGeneration": from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor diff --git a/lightllm/third_party/__init__.py b/lightllm/third_party/__init__.py new file mode 100644 index 0000000000..2adb50db25 --- /dev/null +++ b/lightllm/third_party/__init__.py @@ -0,0 +1 @@ +"""Third-party source subsets vendored for LightLLM runtime support.""" diff --git a/lightllm/third_party/sglang_jit/LICENSE b/lightllm/third_party/sglang_jit/LICENSE new file mode 100755 index 0000000000..9c422689c8 --- /dev/null +++ b/lightllm/third_party/sglang_jit/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2023-2024 SGLang Team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/lightllm/third_party/sglang_jit/README.md b/lightllm/third_party/sglang_jit/README.md new file mode 100644 index 0000000000..4f68c9cfd8 --- /dev/null +++ b/lightllm/third_party/sglang_jit/README.md @@ -0,0 +1,13 @@ +# Vendored SGLang JIT Subset + +This directory contains the minimal SGLang JIT source subset needed by the +DeepSeek-V4 LightLLM implementation. + +Source: https://github.com/sgl-project/sglang +Commit: 8cea0473ea5299bc04885f8f6ba71269415a39b5 +License: Apache License 2.0, copied in `LICENSE`. + +Local changes: +- The Python imports were moved from `sglang.jit_kernel.*` to + `lightllm.third_party.sglang_jit.*`. +- The package exports only the DSv4 functions used by LightLLM. 
diff --git a/lightllm/third_party/sglang_jit/__init__.py b/lightllm/third_party/sglang_jit/__init__.py new file mode 100644 index 0000000000..164d545b4e --- /dev/null +++ b/lightllm/third_party/sglang_jit/__init__.py @@ -0,0 +1 @@ +"""Vendored SGLang JIT kernels used by DeepSeek-V4.""" diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh new file mode 100644 index 0000000000..3a89e8114c --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128.cuh @@ -0,0 +1,522 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using Plan128 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int32_t kTileElements = 2; +/// \brief Each warp will handle this many elements (split along 128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kBlockSize = device::kWarpThreads * kNumWarps; + +/// \brief Need to reduce register usage to increase occupancy +#define C128_KERNEL __global__ __launch_bounds__(kBlockSize, 2) + +struct Compress128DecodeParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + 
+struct Compress128PrefillParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]`*/ + const int32_t* __restrict__ load_indices; + /** \brief The following part is plan info. */ + const Plan128* __restrict__ compress_plan; + const Plan128* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +SGL_DEVICE void c128_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos, + const uint32_t lane_id) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 2; + const auto gmem = tile::Memory{lane_id, kWarpThreads}; + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv | [1] = score | + Storage kv_score[2]; +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c128_forward( + const InFloat* kv_score_buf, + const 
InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t window_len, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + const auto element_size = head_dim * 2; + const auto score_offset = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory{lane_id, kWarpThreads}; + StorageIn kv[kElementsPerWarp]; + StorageIn score[kElementsPerWarp]; + StorageIn bias[kElementsPerWarp]; + const int32_t warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + const int32_t j = i + warp_offset; + bias[i] = gmem_in.load(score_bias + j * head_dim); + } + +#pragma unroll + for (int32_t i = 0; i < kElementsPerWarp; ++i) { + const int32_t j = i + warp_offset; + const InFloat* src; + __builtin_assume(j < 128); + if (j < window_len) { + src = kv_score_buf + j * element_size; + } else { + /// NOTE: k in [-127, 0]. We'll load from the ragged `kv_score_src` + const int32_t k = j - 127; + src = kv_score_src + k * element_size; + } + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; + +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, 
fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kBlockSize; + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. 
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = cast(global_product); + } +} + +template +C128_KERNEL void flash_c128_decode(const __grid_constant__ Compress128DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, batch_size // decode info + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + const uint32_t global_bid = blockIdx.x / kNumSplit; // batch id + const uint32_t global_sid = blockIdx.x % kNumSplit; // split id + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = 
kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: the write must be visible to the subsequent c128_forward, + /// so only the last warp can write to HBM + /// In addition, `position` = `seq_len - 1`. To avoid underflow, we use `seq_len + 127` + if (warp_id == kNumWarps - 1) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 127) % 128, lane_id); + } + if (seq_len % 128 == 0) { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, /*window_len=*/128, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +// compress kernel +template +C128_KERNEL void flash_c128_prefill(const __grid_constant__ Compress128PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, load_indices, compress_plan, write_plan, num_compress, num_write // prefill plan + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + uint32_t global_id; + if constexpr (kWrite) { + // for write kernel, we use global warp_id to dispatch work + global_id = (blockIdx.x * blockDim.x + threadIdx.x) / kWarpThreads; + } else { + // for compress kernel, we use block id to dispatch work + global_id = blockIdx.x; // block id + } + const uint32_t global_pid = global_id / kNumSplit; // plan id + const uint32_t global_sid = global_id % kNumSplit; // split id + + /// 
NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const auto indices_ptr = kWrite ? indices : load_indices; + + const int64_t split_offset = global_sid * kTileDim; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + const int32_t index = indices_ptr[global_bid]; + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + PDLWaitPrimary(); + + // only responsible for the compress part + if constexpr (kWrite) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 128, lane_id); + } else { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, window_len, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress128Kernel { + static constexpr auto decode_kernel = flash_c128_decode; + template + static constexpr auto prefill_kernel = flash_c128_prefill; + static constexpr auto prefill_c_kernel = prefill_kernel; + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWriteBlockSize = 128; + static constexpr uint32_t kWarpsPerWriteBlock = kWriteBlockSize / device::kWarpThreads; + + 
static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional /* UNUSED */) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device) + .verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress128DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + }; + + const uint32_t num_blocks = batch_size * kNumSplit; + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + 
const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + // might be needed for prefill write + const auto load_indices = extra.value_or(indices); + TensorMatcher({B}) // [read_positions] + .with_dtype() + .with_device(device_) + .verify(load_indices); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto params = Compress128PrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .load_indices = static_cast(load_indices.data_ptr()), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = 
static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + + constexpr auto kBlockSize_C = kBlockSize; + constexpr auto kBlockSize_W = kWriteBlockSize; + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kBlockSize_C, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerWriteBlock)) { + LaunchKernel(num_w_blocks, kBlockSize_W, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh new file mode 100644 index 0000000000..b497470606 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online.cuh @@ -0,0 +1,726 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace device::compress { + +/// \brief Plan entry for online compress 128 prefill. +/// Each entry describes a contiguous segment of tokens that lies inside a +/// single 128-chunk. Multiple segments can map to the same batch id when the +/// extend tokens span chunk boundaries. +/// +/// **Layout compatibility:** the field order/types match `PrefillPlan` so that +/// downstream kernels (e.g. `fused_norm_rope` in `CompressExtend` mode) can +/// consume the compress_plan tensor as-if it were a `PrefillPlan` tensor -- +/// they only read `ragged_id` and `position`, both of which carry identical +/// semantics here (the LAST token of the segment in q-ragged and global +/// coordinates respectively). 
+/// +/// Note that `window_len` here means "number of real tokens in this segment" +/// (1..128), which differs from `PrefillPlan::window_len`. Downstream kernels +/// that share the tensor MUST NOT read it under that name. +struct alignas(16) OnlinePrefillPlan { + /// \brief Ragged-q position of the LAST token in this segment. + /// Equal to `segment_start_ragged + window_len - 1`. + uint32_t ragged_id; + /// \brief Index into the `indices` / `load_indices` arrays. + uint32_t batch_id; + /// \brief Global position of the LAST token in this segment. + /// For compress plans, `position % 128 == 127` (chunk-closing); for write + /// plans, `position % 128 < 127`. + uint32_t position; + /// \brief Number of real tokens in this segment (1..128). + /// The first segment token sits at `position - window_len + 1` (global) and + /// at `ragged_id - window_len + 1` (ragged). + uint32_t window_len; +}; + +static_assert(alignof(OnlinePrefillPlan) == alignof(PrefillPlan)); +static_assert(sizeof(OnlinePrefillPlan) == sizeof(PrefillPlan)); + +} // namespace device::compress + +namespace host::compress { + +using device::compress::OnlinePrefillPlan; +using OnlinePrefillPlanTensorDtype = uint8_t; +inline constexpr int64_t kOnlinePrefillPlanDim = 16; + +static_assert(alignof(OnlinePrefillPlan) == sizeof(OnlinePrefillPlan)); +static_assert(sizeof(OnlinePrefillPlan) == kOnlinePrefillPlanDim * sizeof(OnlinePrefillPlanTensorDtype)); + +} // namespace host::compress + +namespace { + +using OnlinePlan = device::compress::OnlinePrefillPlan; +using IndiceT = int32_t; + +/// \brief Need to reduce register usage to increase occupancy +struct Compress128OnlineDecodeParams { + /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]` \n */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: 
`[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +/// \brief Need to reduce register usage to increase occupancy +struct Compress128OnlinePrefillParams { + /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]` \n */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[num_q_tokens, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[num_q_tokens, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ load_indices; + /// \brief Plan for segments that close a chunk (write to `kv_compressed_output`). + /// Shape: `[num_compress, 16]` (uint8). + const OnlinePlan* __restrict__ compress_plan; + /// \brief Plan for the trailing partial segment of each batch (write back to + /// `kv_score_buffer`). Shape: `[num_write, 16]` (uint8). 
+ const OnlinePlan* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +// 4 elements per thread, kHeadDim / 4 threads per block +template +__global__ void flash_c128_online_decode(const __grid_constant__ Compress128OnlineDecodeParams params) { + using namespace device; + constexpr uint32_t kVecSize = 4; + constexpr uint32_t kBlockSize = kHeadDim / kVecSize; + using Vec = AlignedVector; + const auto gmem = tile::Memory::cta(kBlockSize); + const auto batch_id = blockIdx.x; + const auto index = params.indices[batch_id]; + const auto seq_len = params.seq_lens[batch_id]; + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kHeadDim * 3); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_src = kv_score_input + batch_id * (kHeadDim * 2); + + /// NOTE: kv_score_buffer layout is [max, sum, kv] (slot 0 / 1 / 2). Reads, + /// writes, and the prefill kernel must all agree on this order. + const auto max_score_vec = gmem.load(kv_buf, 0); + const auto sum_score_vec = gmem.load(kv_buf, 1); + const auto old_kv_vec = gmem.load(kv_buf, 2); + + /// NOTE: kv_score_input layout is | kv | score | (head_dim each), matching + /// the offline c128 kernel and the online prefill kernel. + const auto new_kv_vec = gmem.load(kv_src, 0); + const auto new_score_raw_vec = gmem.load(kv_src, 1); + + /// NOTE: the new token sits at global position `seq_len - 1`, so its + /// position inside the 128-chunk is `(seq_len - 1) % 128`. The previous + /// `seq_len % 128` was off by one (`bias[127]` vs `bias[0]`, etc.). + const auto pos_in_chunk = (seq_len - 1) % 128; + const auto bias_vec = gmem.load(params.score_bias, pos_in_chunk); + + Vec out_kv_vec; + Vec out_max_vec; + Vec out_sum_vec; + if (pos_in_chunk != 0) { + // Mid-chunk: combine prior partial state with the new token via online softmax. 
+#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const auto old_max = max_score_vec[i]; + const auto old_kv = old_kv_vec[i]; + const auto new_score = new_score_raw_vec[i] + bias_vec[i]; + const auto new_kv = new_kv_vec[i]; + const auto new_max = fmax(old_max, new_score); + const auto old_sum = sum_score_vec[i] * expf(old_max - new_max); + const auto new_exp = expf(new_score - new_max); + const auto new_sum = old_sum + new_exp; + out_kv_vec[i] = (old_kv * old_sum + new_kv * new_exp) / new_sum; + out_max_vec[i] = new_max; + out_sum_vec[i] = new_sum; + } + } else { + // First token of a new 128-chunk: initialize state with this token alone. +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + out_kv_vec[i] = new_kv_vec[i]; + out_max_vec[i] = new_score_raw_vec[i] + bias_vec[i]; + out_sum_vec[i] = 1.0f; // exp(score - max) with max == score + } + } + + if (pos_in_chunk == 127) { + // Chunk just closed: emit the compressed kv. No need to update the buffer + // -- the next chunk's first token will overwrite it. + const auto kv_out = static_cast(params.kv_compressed_output) + batch_id * kHeadDim; + gmem.store(kv_out, out_kv_vec); + } else { + // Otherwise persist the running [max, sum, kv] state for the next step. 
+ gmem.store(kv_buf, out_max_vec, 0); + gmem.store(kv_buf, out_sum_vec, 1); + gmem.store(kv_buf, out_kv_vec, 2); + } +} + +constexpr int32_t kTileElements = 2; // split (along head-dim) +/// \brief Each warp will handle this many elements (split along softmax-128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kPrefillBlockSize = device::kWarpThreads * kNumWarps; +using PrefillStorage = device::AlignedVector; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +SGL_DEVICE void c128_prefill_forward( + const PrefillStorage (&kv)[kElementsPerWarp], + const PrefillStorage (&score)[kElementsPerWarp], + float* kv_out, + float* max_out, + float* sum_out, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; + +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = score[j][i]; + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 
0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kPrefillBlockSize; + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kPrefillBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. 
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = global_product; + if constexpr (kNeedData) { + max_out[local_elem_id] = global_val_max; + sum_out[local_elem_id] = global_exp_sum; + } + } + if constexpr (kNeedData) __syncthreads(); +} + +/// \brief Sentinel score for padded positions in a 128-segment. +/// Must be finite so that `score - max` never produces NaN even when an +/// entire warp has only padded positions. +constexpr float kPadScore = -FLT_MAX; + +/// \brief Online compress 128 prefill. Two passes share this body: +/// - `kWrite=false` (compress pass): handles segments that close a chunk. +/// May load prior partial state from the buffer, but never writes to it, +/// so concurrent blocks can read the same slot without racing. +/// - `kWrite=true` (write pass): handles the trailing partial segment of each +/// batch. Each batch contributes at most one such plan, so concurrent blocks +/// touch disjoint buffer slots. +/// +/// The two passes MUST run as separate kernel launches (in stream order) so +/// that all reads in pass 1 finish before any writes in pass 2 start. 
+template +__global__ __launch_bounds__(kPrefillBlockSize, 2) // + void flash_c128_online_prefill(const __grid_constant__ Compress128OnlinePrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + /// NOTE: the compiler folds the if-else at compile time. + const auto num_plans = kWrite ? params.num_write : params.num_compress; + const auto plan_ptr = kWrite ? params.write_plan : params.compress_plan; + const uint32_t global_id = blockIdx.x; + const uint32_t global_pid = global_id / kNumSplit; // plan id + const uint32_t global_sid = global_id % kNumSplit; // split id + if (global_pid >= num_plans) return; + const auto [ragged_id, batch_id, position, window_len] = plan_ptr[global_pid]; + if (ragged_id == 0xFFFFFFFFu) [[unlikely]] + return; + + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + const int32_t split_offset = global_sid * kTileDim; // int32 is enough + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_compressed_output = static_cast(params.kv_compressed_output); + const auto score_bias_base = static_cast(params.score_bias); + + constexpr int64_t kElementSize = kHeadDim * 2; // | kv | score | + const uint32_t chunk_offset = (position % 128u) + 1u - window_len; + const uint32_t window_end = chunk_offset + window_len; // exclusive, in [1, 128] + const int32_t segment_start = ragged_id - (position % 128u); // can be negative, but safe + const int32_t load_index = chunk_offset != 0 ? params.load_indices[batch_id] : -1; + const int32_t store_index = kWrite ? params.indices[batch_id] : -1; + + PDLWaitPrimary(); + + // 2 * 8 = 16 register per elem. 
in theory we should consume 48 register here + PrefillStorage kv[kElementsPerWarp]; + PrefillStorage score[kElementsPerWarp]; + PrefillStorage bias[kElementsPerWarp]; + const auto warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + if (j >= chunk_offset && j < window_end) { + const auto kv_src_ptr = kv_score_input + (segment_start + j) * kElementSize + split_offset; + const auto score_src_ptr = kv_src_ptr + kHeadDim; + const auto bias_src_ptr = score_bias_base + j * kHeadDim + split_offset; + kv[i].load(kv_src_ptr, lane_id); + score[i].load(score_src_ptr, lane_id); + bias[i].load(bias_src_ptr, lane_id); + } + } + +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + const bool is_valid = (j >= chunk_offset && j < window_end); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + score[i][ii] = is_valid ? score[i][ii] + bias[i][ii] : kPadScore; + /// NOTE: must zero out kv on padded slots -- `c128_prefill_forward` + /// computes `kv * exp_score` where `exp_score = expf(-FLT_MAX - max) ??? 0`, + /// and IEEE-754 makes `NaN * 0 = NaN` / `+-inf * 0 = NaN`. An + /// uninitialized register can hold a NaN/inf bit pattern, so without + /// this reset a single padded warp can poison the whole softmax. + kv[i][ii] = is_valid ? kv[i][ii] : 0.0f; + } + } + + __shared__ alignas(16) float seg_kv[kTileDim]; + __shared__ alignas(16) float seg_max[kTileDim]; + __shared__ alignas(16) float seg_sum[kTileDim]; + + c128_prefill_forward(kv, score, seg_kv, seg_max, seg_sum, warp_id, lane_id); + + PDLTriggerSecondary(); + + if (warp_id == 0) { + PrefillStorage out_kv_vec, out_max_vec, out_sum_vec; + out_kv_vec.load(seg_kv, lane_id); + out_max_vec.load(seg_max, lane_id); + out_sum_vec.load(seg_sum, lane_id); + if (chunk_offset != 0) { + /// NOTE: load (max, sum, kv) of the in-progress chunk for this index. 
+ /// `load_indices` may differ from `indices` when the prior partial state + /// lives on a different slot than the slot we ultimately write to. + const auto buf_load = kv_score_buffer + load_index * (kHeadDim * 3) + split_offset; + PrefillStorage buf_max_vec, buf_sum_vec, buf_kv_vec; + buf_max_vec.load(buf_load + 0 * kHeadDim, lane_id); + buf_sum_vec.load(buf_load + 1 * kHeadDim, lane_id); + buf_kv_vec.load(buf_load + 2 * kHeadDim, lane_id); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + const float m1 = buf_max_vec[ii]; + const float s1 = buf_sum_vec[ii]; + const float k1 = buf_kv_vec[ii]; + const float m2 = out_max_vec[ii]; + const float s2 = out_sum_vec[ii]; + const float k2 = out_kv_vec[ii]; + const float new_max = fmaxf(m1, m2); + const float new_s1 = s1 * expf(m1 - new_max); + const float new_s2 = s2 * expf(m2 - new_max); + const float new_sum = new_s1 + new_s2; + const float new_kv = (k1 * new_s1 + k2 * new_s2) / new_sum; + out_max_vec[ii] = new_max; + out_sum_vec[ii] = new_sum; + out_kv_vec[ii] = new_kv; + } + } + + if constexpr (kWrite) { + const auto buf_store = kv_score_buffer + store_index * (kHeadDim * 3) + split_offset; + reinterpret_cast(buf_store + 0 * kHeadDim)[lane_id] = out_max_vec; + reinterpret_cast(buf_store + 1 * kHeadDim)[lane_id] = out_sum_vec; + reinterpret_cast(buf_store + 2 * kHeadDim)[lane_id] = out_kv_vec; + } else { + const auto out_ptr = kv_compressed_output + ragged_id * kHeadDim + split_offset; + reinterpret_cast(out_ptr)[lane_id] = out_kv_vec; + } + } +} + +template +struct FlashCompress128OnlineKernel { + static constexpr auto decode_kernel = flash_c128_online_decode; + template + static constexpr auto prefill_kernel = flash_c128_online_prefill; + static constexpr auto prefill_c_kernel = prefill_kernel; + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + 
static constexpr uint32_t kDecodeBlockSize = kHeadDim / 4; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional /* UNUSED */) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv) + .with_dtype() + .with_device(device) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device) + .verify(ape); + TensorMatcher({B}).with_dtype().with_device(device).verify(indices); + TensorMatcher({B}).with_dtype().with_device(device).verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress128OnlineDecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + }; + LaunchKernel(batch_size, kDecodeBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + 
const tvm::ffi::Optional extra) { + using namespace host; + using host::compress::kOnlinePrefillPlanDim; + using host::compress::OnlinePrefillPlanTensorDtype; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv) ??? 2D + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, kOnlinePrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, kOnlinePrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + /// NOTE: `extra` is `load_indices`. When the previous partial state lives + /// on a slot different from the destination slot (e.g. paged buffers), the + /// caller must supply this; otherwise it defaults to `indices`. 
+ const auto load_indices = extra.value_or(indices); + TensorMatcher({B}).with_dtype().with_device(device_).verify(load_indices); + + const auto device = device_.unwrap(); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto params = Compress128OnlinePrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .load_indices = static_cast(load_indices.data_ptr()), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + + /// NOTE: pass 1 reads the buffer (for the first segment of each batch + /// that started mid-chunk) and writes only to `kv_compressed_output`. + /// Pass 2 then writes the trailing partial state of each batch back to + /// the buffer. Stream serialization between the two launches enforces + /// read-before-write on shared buffer slots. + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = num_w * kNumSplit) { + LaunchKernel(num_w_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace + +namespace host::compress { + +using OnlinePlanResult = tvm::ffi::Tuple; + +struct OnlinePrefillCompressParams { + OnlinePrefillPlan* __restrict__ compress_plan; + OnlinePrefillPlan* __restrict__ write_plan; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ extend_lens; + uint32_t batch_size; + uint32_t num_tokens; +}; + +/// \brief Build the compress + write plans for online compress 128 prefill. 
+/// +/// Each batch's `[prefix_len, prefix_len + extend_len)` range is split at +/// 128-aligned boundaries. Every resulting segment falls into one of: +/// - **compress**: closes a 128-chunk (`chunk_offset + window_len == 128`). +/// These plans only read the buffer (when starting mid-chunk) and write the +/// compressed kv to `kv_compressed_output`. +/// - **write**: trailing partial of the batch (`chunk_offset + window_len < 128`). +/// May read the buffer and always writes the new partial state back to it. +/// Each batch produces at most one such plan. +/// +/// The two plans MUST be dispatched as separate kernel launches in stream +/// order so that pass-1 reads of a buffer slot complete before any pass-2 +/// write of the same slot. +inline OnlinePlanResult plan_online_prefill_host(const OnlinePrefillCompressParams& params, const bool use_cuda_graph) { + const auto& [compress_plan, write_plan, seq_lens, extend_lens, batch_size, num_tokens] = params; + + uint32_t counter = 0; + uint32_t compress_count = 0; + uint32_t write_count = 0; + for (const auto i : irange(batch_size)) { + const uint32_t seq_len = static_cast(seq_lens[i]); + const uint32_t extend_len = static_cast(extend_lens[i]); + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + const uint32_t prefix_len = seq_len - extend_len; + const uint32_t end_pos = prefix_len + extend_len; + /// NOTE: split the extend range into per-128-chunk segments. Each segment + /// stays inside one chunk, so the kernel can decide load/store from + /// `chunk_offset` and `window_len` alone. + uint32_t pos = prefix_len; + while (pos < end_pos) { + const uint32_t chunk_start = (pos / 128u) * 128u; + const uint32_t seg_end = std::min(end_pos, chunk_start + 128u); // exclusive + const uint32_t seg_len = seg_end - pos; + const uint32_t chunk_off = pos - chunk_start; + /// NOTE: store last-token coordinates so that downstream consumers + /// (e.g. 
`fused_norm_rope`) can read `ragged_id` and `position` with the + /// same semantics as `PrefillPlan`. The segment start is recoverable as + /// `ragged_id - window_len + 1` and `position - window_len + 1`. + const uint32_t last_pos = seg_end - 1; + const uint32_t last_ragged = counter + (last_pos - prefix_len); + const auto plan = OnlinePrefillPlan{ + .ragged_id = last_ragged, + .batch_id = i, + .position = last_pos, + .window_len = seg_len, + }; + if (chunk_off + seg_len == 128u) { + // full chunk, must be complete, maybe read the buffer, no write + RuntimeCheck(compress_count < num_tokens); + compress_plan[compress_count++] = plan; + } else { + // last chunk, must be incomplete, maybe read the buffer, must write + RuntimeCheck(write_count < num_tokens); + write_plan[write_count++] = plan; + } + pos = seg_end; + } + counter += extend_len; + } + RuntimeCheck(counter == num_tokens, "input size ", counter, " != num_q_tokens ", num_tokens); + if (!use_cuda_graph) return OnlinePlanResult{compress_count, write_count}; + /// NOTE: pad both plans with sentinel entries so cuda-graph runs always see + /// the same number of blocks. The kernel skips plans whose `ragged_id` is -1. + constexpr auto kInvalid = static_cast(-1); + constexpr auto kInvalidPlan = OnlinePrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + for (const auto i : irange(compress_count, num_tokens)) { + compress_plan[i] = kInvalidPlan; + } + for (const auto i : irange(write_count, num_tokens)) { + write_plan[i] = kInvalidPlan; + } + return OnlinePlanResult{num_tokens, num_tokens}; +} + +inline OnlinePlanResult plan_online_prefill( + const tvm::ffi::TensorView extend_lens, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const bool use_cuda_graph) { + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_tokens"}; + auto device = SymbolicDevice{}; + /// NOTE: only host (CPU/cuda-host) planning is implemented for now. 
The + device.set_options(); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(extend_lens) + .verify(seq_lens); + TensorMatcher({M, kOnlinePrefillPlanDim}) // + .with_dtype() + .with_device(device) + .verify(compress_plan) + .verify(write_plan); + const auto params = OnlinePrefillCompressParams{ + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extend_lens = static_cast(extend_lens.data_ptr()), + .batch_size = static_cast(N.unwrap()), + .num_tokens = static_cast(M.unwrap()), + }; + return plan_online_prefill_host(params, use_cuda_graph); +} + +} // namespace host::compress + +namespace { + +[[maybe_unused]] +constexpr auto& plan_compress_online_prefill = host::compress::plan_online_prefill; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh new file mode 100644 index 0000000000..71e600dc39 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_online_v2.cuh @@ -0,0 +1,875 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +using PlanD = device::compress::DecodePlan; +using PlanC = device::compress::CompressPlan; + +// --------------------------------------------------------------------------- +// Decode kernel: 1 token / batch. Each block handles one batch. +// 4 elements per thread -> kBlockSize = head_dim / 4. 
+// --------------------------------------------------------------------------- + +struct Compress128OnlineDecodeParams { + void* __restrict__ kv_score_buffer; // [num_slots, 1, head_dim * 3] + const void* __restrict__ kv_score_input; // [batch_size, head_dim * 2] + void* __restrict__ kv_compressed_output; // [batch_size, head_dim] + const void* __restrict__ score_bias; // [128, head_dim] + const PlanD* __restrict__ plan_d; + uint32_t batch_size; +}; + +template +__global__ void flash_c128_online_decode_v2(const __grid_constant__ Compress128OnlineDecodeParams params) { + using namespace device; + constexpr uint32_t kVecSize = 4; + constexpr uint32_t kBlockSize = kHeadDim / kVecSize; + using Vec = AlignedVector; + const auto gmem = tile::Memory::cta(kBlockSize); + const auto batch_id = blockIdx.x; + if (batch_id >= params.batch_size) return; + + // Wait for the plan-finalize kernel to publish `plan.read_page_0 / write_loc` + // before reading the plan. The plan kernel runs on the same stream and does + // NOT issue a PDL trigger, so launching this kernel with PDL means our + // pre-wait global reads can race with the plan kernel's writes. + PDLWaitPrimary(); + + const auto plan = params.plan_d[batch_id]; + const auto pos_in_chunk = (plan.seq_len - 1) % 128; + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_load_buf = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3); + const auto kv_store_buf = kv_score_buffer + plan.write_loc * (kHeadDim * 3); + const auto kv_src = kv_score_input + batch_id * (kHeadDim * 2); + + // Buffer layout: [max | sum | kv] (slot 0 / 1 / 2 of the head_dim*3 row). 
+ const auto new_kv_vec = gmem.load(kv_src, 0); + const auto new_score_raw_vec = gmem.load(kv_src, 1); + const auto bias_vec = gmem.load(params.score_bias, pos_in_chunk); + + Vec out_kv_vec; + Vec out_max_vec; + Vec out_sum_vec; + if (pos_in_chunk != 0) { + // Mid-chunk: combine prior partial state with the new token. + const auto max_score_vec = gmem.load(kv_load_buf, 0); + const auto sum_score_vec = gmem.load(kv_load_buf, 1); + const auto old_kv_vec = gmem.load(kv_load_buf, 2); +#pragma unroll + for (uint32_t i = 0; i < kVecSize; ++i) { + const auto old_max = max_score_vec[i]; + const auto old_kv = old_kv_vec[i]; + const auto new_score = new_score_raw_vec[i] + bias_vec[i]; + const auto new_kv = new_kv_vec[i]; + const auto new_max = fmaxf(old_max, new_score); + const auto old_sum = sum_score_vec[i] * expf(old_max - new_max); + const auto new_exp = expf(new_score - new_max); + const auto new_sum = old_sum + new_exp; + out_kv_vec[i] = (old_kv * old_sum + new_kv * new_exp) / new_sum; + out_max_vec[i] = new_max; + out_sum_vec[i] = new_sum; + } + } else { + // First token of a new chunk: state == this token alone. +#pragma unroll + for (uint32_t i = 0; i < kVecSize; ++i) { + out_kv_vec[i] = new_kv_vec[i]; + out_max_vec[i] = new_score_raw_vec[i] + bias_vec[i]; + out_sum_vec[i] = 1.0f; + } + } + + if (pos_in_chunk == 127) { + // Chunk just closed: emit compressed kv, no buffer update. + const auto kv_out = static_cast(params.kv_compressed_output) + batch_id * kHeadDim; + gmem.store(kv_out, out_kv_vec); + } else { + gmem.store(kv_store_buf, out_max_vec, 0); + gmem.store(kv_store_buf, out_sum_vec, 1); + gmem.store(kv_store_buf, out_kv_vec, 2); + } +} + +// --------------------------------------------------------------------------- +// Prefill kernel: 1 segment / block. Two passes (compress + write) share the +// kernel template, parameterized by `kWrite`. +// 16 warps per block; each warp handles 8 of the 128 chunk positions. 
+// --------------------------------------------------------------------------- + +constexpr int32_t kTileElements = 2; // split along head-dim +constexpr int32_t kElementsPerWarp = 8; // split along the 128-chunk +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kPrefillBlockSize = device::kWarpThreads * kNumWarps; +using PrefillStorage = device::AlignedVector; + +struct Compress128OnlinePrefillParams { + void* __restrict__ kv_score_buffer; // [num_slots, 1, head_dim * 3] + const void* __restrict__ kv_score_input; // [num_q_tokens, head_dim * 2] + void* __restrict__ kv_compressed_output; // [num_compress, head_dim] + const void* __restrict__ score_bias; // [128, head_dim] + const PlanC* __restrict__ plan_c; // close-chunk segments + const PlanC* __restrict__ plan_w; // trailing partial segments + uint32_t num_compress; + uint32_t num_write; +}; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // +1 to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +/// \brief Sentinel score for padded positions in a 128-segment. +constexpr float kPadScore = -FLT_MAX; + +[[maybe_unused]] +SGL_DEVICE void c128_prefill_segment_softmax( + const PrefillStorage (&kv)[kElementsPerWarp], + const PrefillStorage (&score)[kElementsPerWarp], + float* seg_kv, + float* seg_max, + float* seg_sum, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + // Per-warp running state (max, sum, kv) for kTileElements head-dim slots. 
+ using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = score[j][i]; + } + float max_value = score_fp32[0]; +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + max_value = fmaxf(max_value, score_fp32[j]); + } + float sum_exp_value = 0.0f; + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + const auto exp_score = expf(score_fp32[j] - max_value); + sum_product += kv[j][i] * exp_score; + sum_exp_value += exp_score; + } + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // Aligned writes (no bank conflict thanks to `+1` padding). + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + // Cross-warp reduction. Same recipe as c128_online.cuh: each block-thread + // pair reduces a (tile_id, lane_id) slot using a kNumWarps-wide warp shuffle. 
+ constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kPrefillBlockSize; + static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + const uint32_t j = i * kPrefillBlockSize + warp_id * kWarpThreads + lane_id; + const uint32_t local_warp_id = j % kNumWarps; + const uint32_t local_elem_id = j / kNumWarps; + const uint32_t local_tile_id = local_elem_id % kTileElements; + const uint32_t local_lane_id = local_elem_id / kTileElements; + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + seg_kv[local_elem_id] = global_product; + seg_max[local_elem_id] = global_val_max; + seg_sum[local_elem_id] = global_exp_sum; + } + __syncthreads(); +} + +/// \brief Online compress 128 prefill v2. +/// +/// `kWrite=false` (compress pass): handles segments that close a 128-chunk. +/// Reads optional prior state from `read_page_0` (-1 = none), emits compressed +/// kv to `kv_compressed_output[plan_id]` (compact). +/// `kWrite=true` (write pass) : handles trailing partial segments. +/// Reads optional prior state from `read_page_0` (-1 = none), writes new +/// running state back to `read_page_0`. 
+template +__global__ __launch_bounds__(kPrefillBlockSize, 2) // + void flash_c128_online_prefill_v2(const __grid_constant__ Compress128OnlinePrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static_assert(kHeadDim % kTileDim == 0); + + // Compile-time fold to the right plan list. + const auto num_plans = kWrite ? params.num_write : params.num_compress; + const auto plan_ptr = kWrite ? params.plan_w : params.plan_c; + const uint32_t global_id = blockIdx.x; + const uint32_t global_pid = global_id / kNumSplit; + const uint32_t global_sid = global_id % kNumSplit; + if (global_pid >= num_plans) return; + + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + const int32_t split_offset = global_sid * kTileDim; + + // The previous kernel (plan-finalize stage 1) does NOT issue a PDL trigger, + // so PDLWaitPrimary effectively waits for stage 1 to complete. Read the plan + // AFTER the wait so the freshly-written `read_page_0` (= state-pool slot) is + // visible. Reading it before the wait is a real race -- with PDL enabled the + // kernel can begin executing before stage 1's stores propagate, and we'd see + // the stage-0 batch_id placeholder in `read_page_0` instead of the slot. + PDLWaitPrimary(); + + const auto plan = plan_ptr[global_pid]; + if (plan.is_invalid()) [[unlikely]] + return; + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_compressed_output = static_cast(params.kv_compressed_output); + const auto score_bias_base = static_cast(params.score_bias); + + constexpr int64_t kElementSize = kHeadDim * 2; // | kv | score | + + // The plan stores last-token coordinates; segment start is recoverable as + // ragged_id - window_len + 1. 
+ const uint32_t window_len = plan.buffer_len; + const uint32_t position = plan.seq_len - 1; + const uint32_t pos_in_chunk_end = (position % 128u) + 1u; // exclusive, in [1, 128] + const uint32_t chunk_offset = pos_in_chunk_end - window_len; // in [0, 127] + const int32_t segment_start_ragged = static_cast(plan.ragged_id) - static_cast(position % 128u); + + // --- Stage 1: load kv / score / bias for this warp's 8 chunk positions. + PrefillStorage kv[kElementsPerWarp]; + PrefillStorage score[kElementsPerWarp]; + PrefillStorage bias[kElementsPerWarp]; + const uint32_t warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + if (j >= chunk_offset && j < pos_in_chunk_end) { + const auto kv_src_ptr = kv_score_input + (segment_start_ragged + j) * kElementSize + split_offset; + const auto score_src_ptr = kv_src_ptr + kHeadDim; + const auto bias_src_ptr = score_bias_base + j * kHeadDim + split_offset; + kv[i].load(kv_src_ptr, lane_id); + score[i].load(score_src_ptr, lane_id); + bias[i].load(bias_src_ptr, lane_id); + } + } + + // --- Stage 2: pad invalid positions. score = -FLT_MAX, kv = 0 (so that + // kv * exp(score-max) is 0 for padded slots without producing NaN/inf). +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + const bool is_valid = (j >= chunk_offset && j < pos_in_chunk_end); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + score[i][ii] = is_valid ? score[i][ii] + bias[i][ii] : kPadScore; + kv[i][ii] = is_valid ? kv[i][ii] : 0.0f; + } + } + + // --- Stage 3: warp-tile online softmax over the 128-position chunk. 
+ __shared__ alignas(16) float seg_kv[kTileDim]; + __shared__ alignas(16) float seg_max[kTileDim]; + __shared__ alignas(16) float seg_sum[kTileDim]; + c128_prefill_segment_softmax(kv, score, seg_kv, seg_max, seg_sum, warp_id, lane_id); + + PDLTriggerSecondary(); + + // --- Stage 4: warp 0 folds with prior partial state (if any) and writes. + if (warp_id == 0) { + PrefillStorage out_kv_vec, out_max_vec, out_sum_vec; + out_kv_vec.load(seg_kv, lane_id); + out_max_vec.load(seg_max, lane_id); + out_sum_vec.load(seg_sum, lane_id); + + if (chunk_offset != 0 && plan.read_page_0 >= 0) { + // Combine with prior partial state for this slot. + const auto buf_load = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3) + split_offset; + PrefillStorage buf_max_vec, buf_sum_vec, buf_kv_vec; + buf_max_vec.load(buf_load + 0 * kHeadDim, lane_id); + buf_sum_vec.load(buf_load + 1 * kHeadDim, lane_id); + buf_kv_vec.load(buf_load + 2 * kHeadDim, lane_id); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + const float m1 = buf_max_vec[ii]; + const float s1 = buf_sum_vec[ii]; + const float k1 = buf_kv_vec[ii]; + const float m2 = out_max_vec[ii]; + const float s2 = out_sum_vec[ii]; + const float k2 = out_kv_vec[ii]; + const float new_max = fmaxf(m1, m2); + const float new_s1 = s1 * expf(m1 - new_max); + const float new_s2 = s2 * expf(m2 - new_max); + const float new_sum = new_s1 + new_s2; + const float new_kv = (k1 * new_s1 + k2 * new_s2) / new_sum; + out_max_vec[ii] = new_max; + out_sum_vec[ii] = new_sum; + out_kv_vec[ii] = new_kv; + } + } + + if constexpr (kWrite) { + // For trailing-partial segments the load and store slots collapse to the + // segment's own chunk slot (the request keeps a single in-progress + // chunk's running state at any time), so we reuse `read_page_0`. 
+ const auto buf_store = kv_score_buffer + plan.read_page_0 * (kHeadDim * 3) + split_offset; + reinterpret_cast(buf_store + 0 * kHeadDim)[lane_id] = out_max_vec; + reinterpret_cast(buf_store + 1 * kHeadDim)[lane_id] = out_sum_vec; + reinterpret_cast(buf_store + 2 * kHeadDim)[lane_id] = out_kv_vec; + } else { + // Compact output: one row per compress plan, indexed by `global_pid`. + const auto out_ptr = kv_compressed_output + global_pid * kHeadDim + split_offset; + reinterpret_cast(out_ptr)[lane_id] = out_kv_vec; + } + } +} + +// --------------------------------------------------------------------------- +// Host wrapper: matches the c128_v2 / c4_v2 host API style (run_decode / +// run_prefill methods on a kernel-class template). We only expose `kHeadDim` +// + `kUsePDL`; the dtype is fixed to fp32 for the online state pool. +// --------------------------------------------------------------------------- + +template +struct FlashCompress128OnlineKernel { + static constexpr auto decode_kernel = flash_c128_online_decode_v2; + template + static constexpr auto prefill_kernel = flash_c128_online_prefill_v2; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kDecodeBlockSize = kHeadDim / 4; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView plan_d_) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv) + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv 
compressed output (sparse by batch_id) + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + + const auto plan_d = compress::verify_plan_d(plan_d_, B, device_); + const auto batch_size = static_cast(B.unwrap()); + if (batch_size == 0) return; + const auto params = Compress128OnlineDecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .plan_d = plan_d, + .batch_size = batch_size, + }; + LaunchKernel(batch_size, kDecodeBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView plan_c_, + const tvm::ffi::TensorView plan_w_) { + using namespace host; + + auto N = SymbolicSize{"num_q_tokens"}; + auto C = SymbolicSize{"num_c_plans"}; + auto W = SymbolicSize{"num_w_plans"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input (ragged) + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({C, kHeadDim}) // kv compressed output (compact, by plan_c index) + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + + // Both compress and write segments use PlanC layout. Both passes take + // their slot from read_page_0; read_page_1 is not read by the kernel. 
+ const auto plan_c = compress::verify_plan_c(plan_c_, C, device_); + const auto plan_w = compress::verify_plan_c(plan_w_, W, device_); + const auto device = device_.unwrap(); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(C.unwrap()); + const auto num_w = static_cast(W.unwrap()); + RuntimeCheck(num_q_tokens >= num_w, "invalid prefill plan: num_q < num_w"); + const auto params = Compress128OnlinePrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .plan_c = plan_c, + .plan_w = plan_w, + .num_compress = num_c, + .num_write = num_w, + }; + + // The two passes MUST be serialized in stream order: pass 1 reads slots + // that pass 2 may write to; running them in parallel would race. + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_kernel, params); + } + if (const auto num_w_blocks = num_w * kNumSplit) { + LaunchKernel(num_w_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_kernel, params); + } + } +}; + +} // namespace + +// =========================================================================== +// Plan builders. Mirrors the offline v2 pattern (`c_plan.cuh`): +// - Decode: a single GPU kernel reads seq_lens / req_to_token / +// req_pool_indices on device and emits the final PlanD tensor in one go. +// - Prefill: stage 0 (host, on CPU pinned memory) splits each batch's +// extend range into per-chunk segments and emits PlanC entries with the +// batch_id stashed in `read_page_0` as a placeholder. Stage 1 is a tiny +// GPU kernel that finalizes `read_page_0` to `req_to_token[rid][chunk_start]`, +// so the slot tensors never leave GPU memory. 
The online state pool keeps +// a single in-progress chunk per request, so each segment's load and +// store slot collapse to one value (the slot for the segment's own chunk), +// and `read_page_1` is unused. +// =========================================================================== + +namespace host::compress { + +using device::compress::CompressPlan; +using device::compress::DecodePlan; + +// --------------------------------------------------------------------------- +// Decode plan builder. +// --------------------------------------------------------------------------- + +struct OnlineDecodePlanParams { + DecodePlan* __restrict__ plan_d; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ req_pool_indices; + const int32_t* __restrict__ req_to_token; + const int64_t* __restrict__ full_to_swa; // (full_cache_size,) int64 + int64_t stride_r2t; + int32_t swa_page_size; + uint32_t batch_size; +}; + +__global__ void plan_c128_online_decode_kernel(const OnlineDecodePlanParams params) { + const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.batch_size) return; + const auto seq_len = static_cast(params.seq_lens[idx]); + const auto rid = params.req_pool_indices[idx]; + const int32_t chunk_start = static_cast((seq_len - 1u) / 128u * 128u); + const int32_t full_loc = params.req_to_token[rid * params.stride_r2t + chunk_start]; + const int32_t swa_loc = static_cast(params.full_to_swa[full_loc]); + const int32_t slot = swa_loc / params.swa_page_size; + params.plan_d[idx] = DecodePlan{ + .seq_len = seq_len, + .write_loc = slot, + .read_page_0 = slot, + .read_page_1 = -1, + }; +} + +/// \brief Build the decode plan tensor. Caller (Python) pre-allocates +/// `plan_d_dev` as a `(batch_size, 16)` device uint8 tensor; this routine +/// only fills it. 
See `plan_online_prefill` for the rationale (avoid +/// `ffi::empty` + dlpack roundtrip / PyTorch caching-allocator stream +/// tracking issue that surfaces as IMA in unrelated downstream kernels). +inline void plan_online_decode( + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView req_pool_indices, + const tvm::ffi::TensorView req_to_token, + const tvm::ffi::TensorView full_to_swa, + const tvm::ffi::TensorView plan_d_dev_, + const int32_t swa_page_size) { + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + auto seq_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(seq_dtype) + .with_device(device_) + .verify(seq_lens); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device_) + .verify(req_to_token); + TensorMatcher({-1}) // + .with_dtype() + .with_device(device_) + .verify(full_to_swa); + TensorMatcher({B, sizeof(DecodePlan)}) // + .with_dtype() + .with_device(device_) + .verify(plan_d_dev_); + RuntimeCheck(swa_page_size > 0); + + const auto batch_size = static_cast(B.unwrap()); + if (batch_size == 0) return; + + const auto device = device_.unwrap(); + constexpr uint32_t kBlockSize = 256; + const uint32_t num_blocks = host::div_ceil(batch_size, kBlockSize); + const auto stride_r2t = req_to_token.stride(0); + const auto params = OnlineDecodePlanParams{ + .plan_d = static_cast(plan_d_dev_.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .req_pool_indices = static_cast(req_pool_indices.data_ptr()), + .req_to_token = static_cast(req_to_token.data_ptr()), + .full_to_swa = static_cast(full_to_swa.data_ptr()), + .stride_r2t = stride_r2t, + .swa_page_size = swa_page_size, + .batch_size = batch_size, + }; + LaunchKernel(num_blocks, kBlockSize, device)(plan_c128_online_decode_kernel, params); +} + +// --------------------------------------------------------------------------- 
+// Prefill plan builder: host stage 0 + GPU stage 1. +// --------------------------------------------------------------------------- + +struct OnlinePrefillStage0Params { + CompressPlan* __restrict__ plan_c; + CompressPlan* __restrict__ plan_w; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ extend_lens; + uint32_t batch_size; + uint32_t num_q_tokens; +}; + +inline std::tuple _plan_prefill_partial(const OnlinePrefillStage0Params& p) { + uint32_t counter = 0; + uint32_t compress_count = 0; + uint32_t write_count = 0; + for (const auto i : irange(p.batch_size)) { + const uint32_t seq_len = static_cast(p.seq_lens[i]); + const uint32_t extend_len = static_cast(p.extend_lens[i]); + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + const uint32_t prefix_len = seq_len - extend_len; + const uint32_t end_pos = prefix_len + extend_len; + + uint32_t pos = prefix_len; + while (pos < end_pos) { + const uint32_t chunk_start = (pos / 128u) * 128u; + const uint32_t seg_end = std::min(end_pos, chunk_start + 128u); // exclusive + const uint32_t seg_len = seg_end - pos; + const uint32_t chunk_off = pos - chunk_start; + const uint32_t last_pos = seg_end - 1; + const uint32_t last_ragged = counter + (last_pos - prefix_len); + RuntimeCheck(last_ragged < (1u << 16), "PlanC.ragged_id is uint16; ragged ", last_ragged, " overflows"); + RuntimeCheck(seg_len <= 128u); + // Stash batch_id in `read_page_0` for stage 1 to translate. A + // chunk-aligned segment never loads, so we still need stage 1 to fill + // a slot in -- the kernel keys the load on `chunk_offset != 0`. 
+ const auto plan = CompressPlan{ + .seq_len = last_pos + 1u, + .ragged_id = static_cast(last_ragged), + .buffer_len = static_cast(seg_len), + .read_page_0 = static_cast(i), // batch_id placeholder + .read_page_1 = -1, // unused, kept so MSB layout is stable + }; + if (chunk_off + seg_len == 128u) { + // close-chunk segment + RuntimeCheck(compress_count < p.num_q_tokens); + p.plan_c[compress_count++] = plan; + } else { + // trailing partial segment + RuntimeCheck(write_count < p.num_q_tokens); + p.plan_w[write_count++] = plan; + } + pos = seg_end; + } + counter += extend_len; + } + RuntimeCheck(counter == p.num_q_tokens, "input size ", counter, " != num_q_tokens ", p.num_q_tokens); + return std::tuple{compress_count, write_count}; +} + +struct OnlinePrefillStage1Params { + CompressPlan* __restrict__ plan_c; + CompressPlan* __restrict__ plan_w; + const int64_t* __restrict__ req_pool_indices; // (batch_size,) + const int32_t* __restrict__ req_to_token; // (num_reqs, max_tokens) + const int64_t* __restrict__ full_to_swa; // (full_cache_size,) + int64_t stride_r2t; + int32_t swa_page_size; + uint32_t num_c; + uint32_t num_w; +}; + +__global__ void plan_c128_online_prefill_kernel(const OnlinePrefillStage1Params params) { + const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t total = params.num_c + params.num_w; + if (idx >= total) return; + + const bool is_compress = idx < params.num_c; + CompressPlan* const plan_ptr = is_compress ? 
¶ms.plan_c[idx] : ¶ms.plan_w[idx - params.num_c]; + auto plan = *plan_ptr; + const auto batch_id = plan.read_page_0; + const auto rid = params.req_pool_indices[batch_id]; + const int32_t position = static_cast(plan.seq_len - 1u); + const int32_t chunk_start = (position / 128) * 128; + const int32_t full_loc = params.req_to_token[rid * params.stride_r2t + chunk_start]; + const int32_t swa_loc = static_cast(params.full_to_swa[full_loc]); + plan.read_page_0 = swa_loc / params.swa_page_size; + *plan_ptr = plan; +} + +using OnlinePrefillPlan = tvm::ffi::Tuple; + +inline OnlinePrefillPlan plan_online_prefill( + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView extend_lens, + const tvm::ffi::TensorView req_pool_indices, + const tvm::ffi::TensorView req_to_token, + const tvm::ffi::TensorView full_to_swa, + const tvm::ffi::TensorView plan_c_pin, + const tvm::ffi::TensorView plan_w_pin, + const tvm::ffi::TensorView plan_c_dev_, + const tvm::ffi::TensorView plan_w_dev_, + const int32_t swa_page_size) { + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto cpu = SymbolicDevice{}; + auto device_ = SymbolicDevice{}; + cpu.set_options(); + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(cpu) + .verify(seq_lens) + .verify(extend_lens); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device_) + .verify(req_to_token); + TensorMatcher({-1}) // + .with_dtype() + .with_device(device_) + .verify(full_to_swa); + TensorMatcher({N, sizeof(CompressPlan)}) // + .with_dtype() + .with_device(cpu) + .verify(plan_c_pin) + .verify(plan_w_pin); + TensorMatcher({N, sizeof(CompressPlan)}) // + .with_dtype() + .with_device(device_) + .verify(plan_c_dev_) + .verify(plan_w_dev_); + + const auto stage0_params = OnlinePrefillStage0Params{ + .plan_c = static_cast(plan_c_pin.data_ptr()), + .plan_w = 
static_cast(plan_w_pin.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extend_lens = static_cast(extend_lens.data_ptr()), + .batch_size = static_cast(B.unwrap()), + .num_q_tokens = static_cast(N.unwrap()), + }; + + // Debug instrumentation: SGLANG_DEBUG_C128_ONLINE_GUARD=1 wraps stage 0 + // with redzone + post-write magic-check on the pin buffers, plus a strict + // upper-bound check on `batch_size` and `num_q_tokens`. If stage 0 has a + // CPU OOB this trips a clear panic at the offending byte instead of a + // delayed CUDA IMA from corrupted heap memory. + static const bool kGuard = []() { + const char* v = std::getenv("SGLANG_DEBUG_C128_ONLINE_GUARD"); + return v != nullptr && v[0] == '1'; + }(); + if (kGuard) { + RuntimeCheck(stage0_params.batch_size <= 65536u, "batch_size out of bound: ", stage0_params.batch_size); + RuntimeCheck(stage0_params.num_q_tokens <= 65536u, "num_q_tokens out of bound: ", stage0_params.num_q_tokens); + // Stamp the pin buffers with 0xAB so we can detect any byte still 0xAB + // beyond what stage 0 should have written (= OOB never reached, that's fine) + // or any byte BEYOND num_q_tokens*16 written to (= true OOB into + // adjacent allocation). + auto* pc = static_cast(plan_c_pin.data_ptr()); + auto* pw = static_cast(plan_w_pin.data_ptr()); + const auto bytes = static_cast(N.unwrap()) * sizeof(CompressPlan); + std::memset(pc, 0xAB, bytes); + std::memset(pw, 0xAB, bytes); + } + + const auto [num_c, num_w] = _plan_prefill_partial(stage0_params); + + if (kGuard) { + // Verify stage 0 wrote ONLY to the [0, num_c*16) and [0, num_w*16) prefix. 
+ auto* pc = static_cast(plan_c_pin.data_ptr()); + auto* pw = static_cast(plan_w_pin.data_ptr()); + const auto end_c = static_cast(num_c) * sizeof(CompressPlan); + const auto end_w = static_cast(num_w) * sizeof(CompressPlan); + const auto pin_bytes = static_cast(N.unwrap()) * sizeof(CompressPlan); + for (size_t k = end_c; k < pin_bytes; ++k) { + RuntimeCheck( + pc[k] == 0xAB, + "GUARD: plan_c_pin OOB write at byte ", + k, + " (num_c=", + num_c, + ", num_q_tokens=", + N.unwrap(), + ")"); + } + for (size_t k = end_w; k < pin_bytes; ++k) { + RuntimeCheck( + pw[k] == 0xAB, + "GUARD: plan_w_pin OOB write at byte ", + k, + " (num_w=", + num_w, + ", num_q_tokens=", + N.unwrap(), + ")"); + } + } + + const auto device = device_.unwrap(); + // Out-params pre-allocated by Python. Cast to typed pointers for use. + auto* const plan_c_dev_ptr = static_cast(plan_c_dev_.data_ptr()); + auto* const plan_w_dev_ptr = static_cast(plan_w_dev_.data_ptr()); + + if (const auto total = num_c + num_w) { + const auto stream = LaunchKernel::resolve_device(device); + // SGLANG_DEBUG_C128_ONLINE_SYNC_H2D=1 forces a synchronous H2D copy. + static const bool kSyncH2D = []() { + const char* v = std::getenv("SGLANG_DEBUG_C128_ONLINE_SYNC_H2D"); + return v != nullptr && v[0] == '1'; + }(); + // SGLANG_DEBUG_C128_ONLINE_NO_H2D=1 skips the H2D copy entirely (debug only). 
+ static const bool kNoH2D = []() { + const char* v = std::getenv("SGLANG_DEBUG_C128_ONLINE_NO_H2D"); + return v != nullptr && v[0] == '1'; + }(); + const auto copy_to_device = [stream](void* dst, void* src, int64_t count) { + if (kNoH2D) return; + const auto bytes = count * sizeof(CompressPlan); + if (kSyncH2D) { + RuntimeDeviceCheck(::cudaMemcpy(dst, src, bytes, ::cudaMemcpyHostToDevice)); + } else { + RuntimeDeviceCheck(::cudaMemcpyAsync(dst, src, bytes, ::cudaMemcpyHostToDevice, stream)); + } + }; + if (num_c) copy_to_device(plan_c_dev_ptr, plan_c_pin.data_ptr(), num_c); + if (num_w) copy_to_device(plan_w_dev_ptr, plan_w_pin.data_ptr(), num_w); + + const auto stage1_params = OnlinePrefillStage1Params{ + .plan_c = plan_c_dev_ptr, + .plan_w = plan_w_dev_ptr, + .req_pool_indices = static_cast(req_pool_indices.data_ptr()), + .req_to_token = static_cast(req_to_token.data_ptr()), + .full_to_swa = static_cast(full_to_swa.data_ptr()), + .stride_r2t = req_to_token.stride(0), + .swa_page_size = swa_page_size, + .num_c = num_c, + .num_w = num_w, + }; + constexpr uint32_t kBlockSize = 128; + const auto num_blocks = host::div_ceil(total, kBlockSize); + LaunchKernel(num_blocks, kBlockSize, device)(plan_c128_online_prefill_kernel, stage1_params); + } + return OnlinePrefillPlan{num_c, num_w}; +} + +} // namespace host::compress + +namespace { + +[[maybe_unused]] +constexpr auto& plan_compress_128_online_decode = host::compress::plan_online_decode; +[[maybe_unused]] +constexpr auto& plan_compress_128_online_prefill = host::compress::plan_online_prefill; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_v2.cuh new file mode 100644 index 0000000000..31353e6a15 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c128_v2.cuh @@ -0,0 +1,448 @@ +/** + * \brief Here's some dimension info for the main buffer used in C128 prefill and decode. 
+ * + * kv_buffer: [num_indices, 128, head_dim * 2] + * - last dimension layout: | kv | score | + * kv_input: [batch_size, head_dim * 2] + * kv_output: [batch_size, head_dim] + * score_bias (ape): [128, head_dim] + * plan_c/plan_w: [variable length] + * + * For prefill, batch_size = num_q_tokens + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using PlanD = device::compress::DecodePlan; +using PlanC = device::compress::CompressPlan; +using PlanW = device::compress::WritePlan; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int32_t kTileElements = 2; +/// \brief Each warp will handle this many elements (split along 128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kBlockSize = device::kWarpThreads * kNumWarps; +constexpr uint32_t kWriteBlockSize = 128; // one warp per write + +/// \brief Need to reduce register usage to increase occupancy +#define C128_KERNEL __global__ __launch_bounds__(kBlockSize, 2) +#define WRITE_KERNEL __global__ __launch_bounds__(kWriteBlockSize, 16) + +struct Compress128DecodeParams { + void* __restrict__ kv_buffer; + const void* __restrict__ kv_input; + void* __restrict__ kv_output; + const void* __restrict__ score_bias; + const PlanD* __restrict__ plan_d; + uint32_t batch_size; +}; + +struct Compress128PrefillParams { + void* __restrict__ kv_buffer; + const void* __restrict__ kv_input; + void* __restrict__ kv_output; + const void* __restrict__ score_bias; + const PlanC* __restrict__ plan_c; + const PlanW* __restrict__ plan_w; + uint32_t num_compress; + uint32_t num_write; +}; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return 
data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +struct C128Trait { + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr int64_t kHeadDim = kHeadDim_; + static constexpr int64_t kScoreOffset = kHeadDim; + static constexpr int64_t kElementSize = kHeadDim * 2; + static constexpr int64_t kPageElementSize = 128 * kElementSize; // page size = 128 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static_assert(kHeadDim % kTileDim == 0); +}; + +template +SGL_DEVICE void c128_forward( + const InFloat* kv_buf, // [128n, 128n + 127] + const InFloat* kv_src, // ragged pointer at position = 128n + 127 + OutFloat* kv_out, + const InFloat* score_bias, + const int32_t buffer_len) { + using namespace device; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory{lane_id, kWarpThreads}; + StorageIn kv[kElementsPerWarp]; + StorageIn score[kElementsPerWarp]; + StorageIn bias[kElementsPerWarp]; + const int32_t warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + const int32_t j = i + warp_offset; + bias[i] = gmem_in.load(score_bias + j * Trait::kHeadDim); + } + + const auto kv_start = kv_src - 127 * Trait::kElementSize; // point to start + +#pragma unroll + for (int32_t i = 0; i < kElementsPerWarp; ++i) { + const int32_t j = i + warp_offset; + __builtin_assume(j < 128); + const auto src = j < buffer_len ? 
kv_buf : kv_start; + kv[i] = gmem_in.load(src + j * Trait::kElementSize); + score[i] = gmem_in.load(src + j * Trait::kElementSize + Trait::kScoreOffset); + } + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + + float score_fp32[kTileElements][kElementsPerWarp]; + + // convert to fp32 and apply bias first +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[i][j] = cast(score[j][i]) + cast(bias[j][i]); + } + } + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + const auto& score = score_fp32[i]; + float max_value = score[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t 
kIteration = kReductionCount / kBlockSize; + + PDLTriggerSecondary(); + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. + static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = cast(global_product); + } +} + +template +SGL_DEVICE void c128_write_decode(InFloat* kv_buf, const InFloat* kv_src) { + using namespace device; + + using Storage = AlignedVector; + const auto gmem = tile::Memory::warp(); + + Storage data[2]; +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + data[i] = gmem.load(kv_src + Trait::kHeadDim * i); + } +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + 
gmem.store(kv_buf + Trait::kHeadDim * i, data[i]); + } +} + +template +C128_KERNEL void flash_c128_decode(const __grid_constant__ Compress128DecodeParams params) { + using namespace device; + using Trait = C128Trait; + + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t global_bid = blockIdx.x / Trait::kNumSplit; // batch id + const uint32_t global_sid = blockIdx.x % Trait::kNumSplit; // split id + const int64_t split_offset = global_sid * Trait::kTileDim; + if (global_bid >= params.batch_size) return; + + const auto plan = params.plan_d[global_bid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_output = static_cast(params.kv_output) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + const auto score_bias = static_cast(params.score_bias) + split_offset; + + const auto kv_src = kv_input + global_bid * Trait::kElementSize; + const auto kv_out = kv_output + global_bid * Trait::kHeadDim; + const auto kv_buf = kv_buffer + plan.read_page_1 * Trait::kPageElementSize; + const auto kv_dst = kv_buffer + plan.write_loc * Trait::kElementSize; + + PDLWaitPrimary(); + // the write warp must match the load warp in the following `c128_forward` + if (warp_id == kNumWarps - 1) { + c128_write_decode(kv_dst, kv_src); + } + if (plan.write_loc % 128 == 127) { + c128_forward(kv_buf, kv_src, kv_out, score_bias, 128); + } +} + +// compress kernel +template +C128_KERNEL void flash_c128_prefill(const __grid_constant__ Compress128PrefillParams params) { + using namespace device; + using Trait = C128Trait; + + const uint32_t global_pid = blockIdx.x / Trait::kNumSplit; // plan id + const uint32_t global_sid = blockIdx.x % Trait::kNumSplit; // split id + const int64_t split_offset = global_sid * Trait::kTileDim; + if (global_pid >= params.num_compress) return; + + const auto plan = params.plan_c[global_pid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_output = 
static_cast(params.kv_output) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + const auto score_bias = static_cast(params.score_bias) + split_offset; + if (plan.is_invalid()) return; + + const auto kv_src = kv_input + plan.ragged_id * Trait::kElementSize; + // Compact output: one row per compress plan, indexed by `global_pid`. + const auto kv_out = kv_output + global_pid * Trait::kHeadDim; + const auto kv_buf = kv_buffer + plan.read_page_1 * Trait::kPageElementSize; + PDLWaitPrimary(); + c128_forward(kv_buf, kv_src, kv_out, score_bias, plan.buffer_len); +} + +template +WRITE_KERNEL void write_c128_prefill(const __grid_constant__ Compress128PrefillParams params) { + using namespace device; + using Trait = C128Trait; + using StorageIn = AlignedVector; + + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_pid = global_wid / Trait::kNumSplit; // plan id + const uint32_t global_sid = global_wid % Trait::kNumSplit; // split id + // split the contiguous `kHeadDim * 2` into `kNumSplit` tiles + // each warp handles 1 contiguous tile (in contrast, decode handle the strided head_dim) + const int64_t split_offset = global_sid * (Trait::kTileDim * 2); + if (global_pid >= params.num_write) return; + + const auto plan = params.plan_w[global_pid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + if (plan.is_invalid()) return; + + // each warp will handle a contiguous region + const auto kv_src = kv_input + plan.ragged_id * Trait::kElementSize; + const auto kv_buf = kv_buffer + plan.write_loc * Trait::kElementSize; + const auto gmem = tile::Memory::warp(); + + PDLWaitPrimary(); + StorageIn data[2]; +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + data[i] = gmem.load(kv_src, i); + } + PDLTriggerSecondary(); +#pragma unroll + for (int32_t i 
= 0; i < 2; ++i) { + gmem.store(kv_buf, data[i], i); + } +} + +template +struct FlashCompress128Kernel { + static constexpr auto decode_kernel = flash_c128_decode; + static constexpr auto prefill_c_kernel = flash_c128_prefill; + static constexpr auto prefill_w_kernel = write_c128_prefill; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + using Trait = C128Trait; + + static void run_decode( + const tvm::ffi::TensorView kv_buffer, + const tvm::ffi::TensorView kv_input, + const tvm::ffi::TensorView kv_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView plan_d_) { + using namespace host; + + auto N = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 128, Trait::kElementSize}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_buffer); + TensorMatcher({N, Trait::kElementSize}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + + const auto plan_d = compress::verify_plan_d(plan_d_, N, device_); + const auto batch_size = static_cast(N.unwrap()); + const auto params = Compress128DecodeParams{ + .kv_buffer = kv_buffer.data_ptr(), + .kv_input = kv_input.data_ptr(), + .kv_output = kv_output.data_ptr(), + .score_bias = ape.data_ptr(), + .plan_d = plan_d, + .batch_size = batch_size, + }; + const uint32_t num_blocks = batch_size * kNumSplit; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_buffer, + const tvm::ffi::TensorView kv_input, + const tvm::ffi::TensorView kv_output, + const tvm::ffi::TensorView ape, + const 
tvm::ffi::TensorView plan_c_, + const tvm::ffi::TensorView plan_w_) { + using namespace host; + + auto N = SymbolicSize{"num_q_tokens"}; + auto C = SymbolicSize{"num_c_plans"}; + auto W = SymbolicSize{"num_w_plans"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 128, Trait::kElementSize}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_buffer); + TensorMatcher({N, Trait::kElementSize}) // kv score input (ragged) + .with_dtype() + .with_device(device_) + .verify(kv_input); + TensorMatcher({C, kHeadDim}) // kv compressed output (compact) + .with_dtype() + .with_device(device_) + .verify(kv_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + + const auto plan_c = compress::verify_plan_c(plan_c_, C, device_); + const auto plan_w = compress::verify_plan_w(plan_w_, W, device_); + const auto device = device_.unwrap(); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(C.unwrap()); + const auto num_w = static_cast(W.unwrap()); + const auto params = Compress128PrefillParams{ + .kv_buffer = kv_buffer.data_ptr(), + .kv_input = kv_input.data_ptr(), + .kv_output = kv_output.data_ptr(), + .score_bias = ape.data_ptr(), + .plan_c = plan_c, + .plan_w = plan_w, + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= num_w, "invalid prefill plan: num_q < num_w"); + if (const auto num_c_blocks = num_c * kNumSplit) { + constexpr auto kBlockSize_C = kBlockSize; + LaunchKernel(num_c_blocks, kBlockSize_C, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + constexpr uint32_t kWarpsPerWriteBlock = kWriteBlockSize / device::kWarpThreads; + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerWriteBlock)) { + constexpr auto kBlockSize_W = kWriteBlockSize; + LaunchKernel(num_w_blocks, kBlockSize_W, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace diff 
--git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4.cuh new file mode 100644 index 0000000000..145ab1fb08 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4.cuh @@ -0,0 +1,549 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using Plan4 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int kTileElements = 4; + +/// \brief Need to improve register usage to reduce latency +#define C4_KERNEL __global__ __launch_bounds__(128, 4) + +enum class PageMode { + RingBuffer = 8, + Page4Align = 4, +}; + +struct alignas(16) C4IndexBundle { + int32_t load_first_page; + int32_t load_second_page; + int32_t write_first_page; + int32_t last_position; +}; + +struct Compress4DecodeParams { + /** + * \brief Shape: `[num_indices, 8, head_dim * 4]` \n + * last dimension layout: + * | kv overlap | kv | score overlap | score | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 4]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[8, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \brief Shape: `[batch_size, 1]` */ + const int32_t* __restrict__ extra; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +struct Compress4PrefillParams { + /** + * \brief Shape: `[num_indices, 8, head_dim * 4]` \n + * last dimension layout: + * | kv overlap | kv | score overlap | score | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[num_q_tokens, head_dim * 4]` */ + 
const void* __restrict__ kv_score_input; + /** \brief Shape: `[num_q_tokens, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[8, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, 4]` */ + const C4IndexBundle* __restrict__ extra; + /** \brief The following part is plan info. */ + + const Plan4* __restrict__ compress_plan; + const Plan4* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +template +SGL_DEVICE void c4_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 4; + const auto gmem = tile::Memory::warp(); + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv overlap | [1] = kv | [2] = score overlap | [3] = score | + Storage kv_score[4]; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c4_forward( + const InFloat* kv_score_buf, + const InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t seq_len, + const int32_t window_len, + [[maybe_unused]] const InFloat* kv_score_overlap_buf = nullptr) { + using namespace device; + + const auto element_size = head_dim * 4; + const auto score_offset = head_dim * 2; + const auto overlap_stride = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory::warp(); + StorageIn kv[8]; + StorageIn score[8]; + StorageIn bias[8]; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + bias[i] = gmem_in.load(score_bias + i * head_dim); + } + +#pragma 
unroll + for (int32_t i = 0; i < 8; ++i) { + const bool is_overlap = i < 4; + const InFloat* src; + if (i < window_len) { + /// NOTE: `seq_len` must be a multiple of 4 here + if constexpr (kPaged) { + const auto kv_score_ptr = is_overlap ? kv_score_overlap_buf : kv_score_buf; + const int32_t k = i % 4; + src = kv_score_ptr + k * element_size; + } else { + const int32_t k = (seq_len + i) % 8; + src = kv_score_buf + k * element_size; + } + } else { + /// NOTE: k in [-7, 0]. We'll load from the ragged `kv_score_src` + const int32_t k = i - 7; + src = kv_score_src + k * element_size; + } + src += (is_overlap ? 0 : overlap_stride); + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + if (seq_len == 4) { + [[unlikely]]; + constexpr float kFloatNegInf = -1e9f; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv[i].fill(cast(0.0f)); + score[i].fill(cast(kFloatNegInf)); + } + } + + /// NOTE: part 2: safe online softmax + weighted sum + using StorageOut = AlignedVector; + const auto gmem_out = tile::Memory::warp(); + StorageOut result; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[8]; + +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + result[i] = cast(sum_product / sum_exp_value); + } + + gmem_out.store(kv_out, result); +} + +template +C4_KERNEL void flash_c4_decode(const __grid_constant__ Compress4DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim 
= kTileElements * kWarpThreads; // 128 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 4; // `* 4` due to overlap transform + score + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, extra, batch_size // decode info + ] = params; + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_bid = global_wid / kNumSplit; // batch id + const uint32_t global_sid = global_wid % kNumSplit; // split id + + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: `position` = `seq_len - 1`. 
To avoid underflow, we use `seq_len + page_size - 1` + if constexpr (kMode == PageMode::Page4Align) { + const auto index_prev = extra[global_bid]; + const auto kv_buf = kv_score_buffer + index * (kElementSize * 4) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 3) % 4); + if (seq_len % 4 == 0) { + const auto kv_overlap = kv_buf + (index_prev - index) * (kElementSize * 4); + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, 8, kv_overlap); + } + } else { + static_assert(kMode == PageMode::RingBuffer, "Unsupported PageMode"); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 8) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 7) % 8); + if (seq_len % 4 == 0) { + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, /*window_size=*/8); + } + } + + PDLTriggerSecondary(); +} + +template +C4_KERNEL void flash_c4_prefill(const __grid_constant__ Compress4PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 128 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 4; // `* 4` due to overlap transform + score + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, extra, compress_plan, write_plan, num_compress, num_write // prefill plan + ] = params; + + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_pid = global_wid / kNumSplit; // plan id + const uint32_t global_sid = global_wid % kNumSplit; // split id + + /// NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? 
write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + const auto seq_len = position + 1; + const int32_t index = indices[global_bid]; + + PDLWaitPrimary(); + + if constexpr (kMode == PageMode::Page4Align) { + const auto write_second_page = index; + const auto [load_first_page, load_second_page, write_first_page, last_pos] = extra[global_bid]; + if constexpr (kWrite) { + int32_t index; + if (position < static_cast(last_pos)) { + index = write_first_page; + } else { + index = write_second_page; + } + const auto kv_buf = kv_score_buffer + index * (kElementSize * 4) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 4); + } else { + int32_t index_overlap, index_normal; + if (window_len <= 4) { + index_overlap = load_second_page; + index_normal = load_second_page; // not used + } else { + index_overlap = load_first_page; + index_normal = load_second_page; + } + const auto kv_buf = kv_score_buffer + index_normal * (kElementSize * 4) + split_offset; + const auto kv_overlap = kv_score_buffer + index_overlap * (kElementSize * 4) + split_offset; + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, window_len, kv_overlap); + } + } else { + static_assert(kMode == PageMode::RingBuffer, "Unsupported PageMode"); + const auto kv_buf = 
kv_score_buffer + index * (kElementSize * 8) + split_offset; + if constexpr (kWrite) { + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 8); + } else { + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, window_len); + } + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress4Kernel { + template + static constexpr auto decode_kernel = flash_c4_decode; + template + static constexpr auto prefill_kernel = flash_c4_prefill; + template + static constexpr auto prefill_c_kernel = prefill_kernel; + template + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr uint32_t kBlockSize = 128; + static constexpr uint32_t kTileDim = kTileElements * device::kWarpThreads; + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWarpsPerBlock = kBlockSize / device::kWarpThreads; + + using Self = FlashCompress4Kernel; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional extra) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + const auto extra_ptr = _get_extra_pointer(B, device_, extra); + const auto page_size = extra_ptr != nullptr ? 
4 : 8; + + TensorMatcher({-1, page_size, kHeadDim * 4}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 4}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device_) + .verify(seq_lens); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress4DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extra = static_cast(extra_ptr), + .batch_size = batch_size, + }; + const auto kernel = extra_ptr != nullptr ? 
decode_kernel // + : decode_kernel; + const uint32_t num_blocks = div_ceil(batch_size * kNumSplit, kWarpsPerBlock); + LaunchKernel(num_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + const auto extra_ptr = _get_extra_pointer(B, device_, extra, /*is_prefill=*/true); + const auto page_size = extra_ptr != nullptr ? 4 : 8; + + TensorMatcher({-1, page_size, kHeadDim * 4}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 4}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = 
static_cast(Y.unwrap()); + const auto params = Compress4PrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .extra = static_cast(extra_ptr), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + if (const auto num_c_blocks = div_ceil(num_c * kNumSplit, kWarpsPerBlock)) { + const auto c_kernel = extra_ptr != nullptr ? prefill_c_kernel // + : prefill_c_kernel; + LaunchKernel(num_c_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerBlock)) { + const auto w_kernel = extra_ptr != nullptr ? prefill_w_kernel // + : prefill_w_kernel; + LaunchKernel(num_w_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(w_kernel, params); + } + } + + // some auxiliary functions + private: + static const void* _get_extra_pointer( + host::SymbolicSize& B, // batch_size + host::SymbolicDevice& device, + const tvm::ffi::Optional& extra, + bool is_prefill = false) { + // only have value when using page-aligned mode + if (!extra.has_value()) return nullptr; + const auto& extra_tensor = extra.value(); + /// NOTE: the metadata layout is different for prefill and decode: + /// for prefill, last 4 are: + /// load overlap | load normal | write overlap | last written page + /// for decode, last 1 is the write (also load) overlap + host::TensorMatcher({B, is_prefill ? 
4 : 1}) // extra tensor + .with_dtype() + .with_device(device) + .verify(extra_tensor); + const auto data_ptr = extra_tensor.data_ptr(); + host::RuntimeCheck(data_ptr != nullptr, "extra tensor data ptr is null"); + if (is_prefill) { + static_assert(alignof(C4IndexBundle) == 16); + host::RuntimeCheck(std::bit_cast(data_ptr) % 16 == 0, "extra tensor is not properly aligned"); + } + return data_ptr; + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4_v2.cuh new file mode 100644 index 0000000000..efa9f05100 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c4_v2.cuh @@ -0,0 +1,405 @@ +/** + * \brief Here's some dimension info for the main buffer used in C4 prefill and decode. + * + * kv_buffer: [num_indices, 8, head_dim * 4] + * - last dimension layout: | kv overlap | kv | score overlap | score | + * kv_input: [batch_size, head_dim * 4] + * kv_output: [batch_size, head_dim] + * score_bias (ape): [8, head_dim] + * plan_c/plan_w: [variable length] + * + * For prefill, batch_size = num_q_tokens + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include +#include + +namespace { + +using PlanD = device::compress::DecodePlan; +using PlanC = device::compress::CompressPlan; +using PlanW = device::compress::WritePlan; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int32_t kTileElements = 4; + +/// \brief Need to improve register usage to reduce latency +#define C4_KERNEL __global__ __launch_bounds__(128, 4) +#define WRITE_KERNEL __global__ __launch_bounds__(128, 16) + +struct Compress4DecodeParams { + void* __restrict__ kv_buffer; + const void* __restrict__ kv_input; + void* __restrict__ kv_output; + const void* __restrict__ score_bias; + const PlanD* __restrict__ plan_d; + uint32_t batch_size; +}; + +struct Compress4PrefillParams { + void* 
__restrict__ kv_buffer; + const void* __restrict__ kv_input; + void* __restrict__ kv_output; + const void* __restrict__ score_bias; + const PlanC* __restrict__ plan_c; + const PlanW* __restrict__ plan_w; + uint32_t num_compress; + uint32_t num_write; +}; + +template +struct C4Trait { + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 128 + static constexpr int64_t kHeadDim = kHeadDim_; + static constexpr int64_t kOverlapOffset = kHeadDim; + static constexpr int64_t kScoreOffset = kHeadDim * 2; + static constexpr int64_t kElementSize = kHeadDim * 4; + static constexpr int64_t kPageElementSize = 4 * kElementSize; // page size = 4 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static_assert(kHeadDim % kTileDim == 0); +}; + +template +SGL_DEVICE void c4_forward( + const InFloat* kv_buf_0, // overlap [4n - 4, 4n - 1] + const InFloat* kv_buf_1, // normal [4n + 0, 4n + 3] + const InFloat* kv_src, // ragged pointer at position = 4n + 3 + OutFloat* kv_out, + const InFloat* score_bias, + const bool should_overlap, + const int32_t buffer_len) { + using namespace device; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + /// NOTE: load one tile_dim (< head_dim) at at time + const auto gmem_in = tile::Memory::warp(); + StorageIn kv[8]; + StorageIn score[8]; + StorageIn bias[8]; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + bias[i] = gmem_in.load(score_bias + i * Trait::kHeadDim); + } + + if (should_overlap) { + const auto kv_start = kv_src - 7 * Trait::kElementSize; // point to start +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + const auto src = i < buffer_len ? 
kv_buf_0 : kv_start; + const auto base = src + i * Trait::kElementSize; + kv[i] = gmem_in.load(base); + score[i] = gmem_in.load(base + Trait::kScoreOffset); + } + } else { + [[unlikely]]; + constexpr float kFloatNegInf = -FLT_MAX; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv[i].fill(cast(0.0f)); + score[i].fill(cast(kFloatNegInf)); + } + } + + const auto kv_start = kv_src - 3 * Trait::kElementSize; // point to start +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + const auto src = i + 4 < buffer_len ? kv_buf_1 : kv_start; + const auto base = src + i * Trait::kElementSize + Trait::kOverlapOffset; + kv[i + 4] = gmem_in.load(base); + score[i + 4] = gmem_in.load(base + Trait::kScoreOffset); + } + + /// NOTE: part 2: safe online softmax + weighted sum + using StorageOut = AlignedVector; + const auto gmem_out = tile::Memory::warp(); + StorageOut result; + + // consume 32 fp registers + float score_fp32[kTileElements][8]; + + // convert to fp32 and apply bias first +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + for (int32_t j = 0; j < 8; ++j) { + score_fp32[i][j] = cast(score[j][i]) + cast(bias[j][i]); + } + } + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + const auto& score = score_fp32[i]; + float max_value = score[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < 8; ++j) { + const auto fp32_score = score[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + result[i] = cast(sum_product / sum_exp_value); + } + + // overlap the store with the next iteration's load + PDLTriggerSecondary(); + gmem_out.store(kv_out, result); +} + +template +SGL_DEVICE void c4_write_decode(InFloat* kv_buf, const InFloat* kv_src) { + using namespace device; + + 
using StorageIn = AlignedVector; + const auto gmem = tile::Memory::warp(); + + StorageIn data[4]; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + data[i] = gmem.load(kv_src + Trait::kHeadDim * i); + } +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + gmem.store(kv_buf + Trait::kHeadDim * i, data[i]); + } +} + +template +C4_KERNEL void flash_c4_decode(const __grid_constant__ Compress4DecodeParams params) { + using namespace device; + using Trait = C4Trait; + + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_bid = global_wid / Trait::kNumSplit; // batch id + const uint32_t global_sid = global_wid % Trait::kNumSplit; // split id + const int64_t split_offset = global_sid * Trait::kTileDim; + if (global_bid >= params.batch_size) return; + + const auto plan = params.plan_d[global_bid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_output = static_cast(params.kv_output) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + const auto score_bias = static_cast(params.score_bias) + split_offset; + + const auto kv_src = kv_input + global_bid * Trait::kElementSize; + const auto kv_out = kv_output + global_bid * Trait::kHeadDim; + const auto kv_buf_0 = kv_buffer + plan.read_page_0 * Trait::kPageElementSize; + const auto kv_buf_1 = kv_buffer + plan.read_page_1 * Trait::kPageElementSize; + const auto kv_dst = kv_buffer + plan.write_loc * Trait::kElementSize; + + PDLWaitPrimary(); + c4_write_decode(kv_dst, kv_src); + if (plan.seq_len % 4 == 0) { + const auto need_overlap = plan.seq_len > 4; + c4_forward(kv_buf_0, kv_buf_1, kv_src, kv_out, score_bias, need_overlap, 8); + } +} + +template +C4_KERNEL void flash_c4_prefill(const __grid_constant__ Compress4PrefillParams params) { + using namespace device; + using Trait = C4Trait; + + const uint32_t global_tid = blockIdx.x * blockDim.x + 
threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_pid = global_wid / Trait::kNumSplit; // plan id + const uint32_t global_sid = global_wid % Trait::kNumSplit; // split id + const int64_t split_offset = global_sid * Trait::kTileDim; + if (global_pid >= params.num_compress) return; + + const auto plan = params.plan_c[global_pid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_output = static_cast(params.kv_output) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + const auto score_bias = static_cast(params.score_bias) + split_offset; + if (plan.is_invalid()) return; + + const auto kv_src = kv_input + plan.ragged_id * Trait::kElementSize; + // Compact output: one row per compress plan, indexed by `global_pid`. + const auto kv_out = kv_output + global_pid * Trait::kHeadDim; + const auto kv_buf_0 = kv_buffer + plan.read_page_0 * Trait::kPageElementSize; + const auto kv_buf_1 = kv_buffer + plan.read_page_1 * Trait::kPageElementSize; + const bool need_overlap = plan.seq_len > 4; + PDLWaitPrimary(); + c4_forward(kv_buf_0, kv_buf_1, kv_src, kv_out, score_bias, need_overlap, plan.buffer_len); +} + +template +WRITE_KERNEL void write_c4_prefill(const __grid_constant__ Compress4PrefillParams params) { + using namespace device; + using Trait = C4Trait; + using StorageIn = AlignedVector; + + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_pid = global_wid / Trait::kNumSplit; // plan id + const uint32_t global_sid = global_wid % Trait::kNumSplit; // split id + // split the contiguous `kHeadDim * 4` into `kNumSplit` tiles + // each warp handles 1 contiguous tile (in contrast, decode handle the strided head_dim) + const int64_t split_offset = global_sid * (Trait::kTileDim * 4); + if (global_pid >= params.num_write) return; + + const auto plan = 
params.plan_w[global_pid]; + const auto kv_input = static_cast(params.kv_input) + split_offset; + const auto kv_buffer = static_cast(params.kv_buffer) + split_offset; + if (plan.is_invalid()) return; + + // each warp will handle a contiguous region + const auto kv_src = kv_input + plan.ragged_id * Trait::kElementSize; + const auto kv_buf = kv_buffer + plan.write_loc * Trait::kElementSize; + const auto gmem = tile::Memory::warp(); + + PDLWaitPrimary(); + StorageIn data[4]; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + data[i] = gmem.load(kv_src, i); + } + PDLTriggerSecondary(); +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + gmem.store(kv_buf, data[i], i); + } +} + +template +struct FlashCompress4Kernel { + static constexpr auto decode_kernel = flash_c4_decode; + static constexpr auto prefill_c_kernel = flash_c4_prefill; + static constexpr auto prefill_w_kernel = write_c4_prefill; + static constexpr uint32_t kBlockSize = 128; + static constexpr uint32_t kTileDim = kTileElements * device::kWarpThreads; + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWarpsPerBlock = kBlockSize / device::kWarpThreads; + using Trait = C4Trait; + + static void run_decode( + const tvm::ffi::TensorView kv_buffer, + const tvm::ffi::TensorView kv_input, + const tvm::ffi::TensorView kv_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView plan_d_) { + using namespace host; + + auto N = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 4, Trait::kElementSize}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_buffer); + TensorMatcher({N, Trait::kElementSize}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + + 
const auto plan_d = compress::verify_plan_d(plan_d_, N, device_); + const auto batch_size = static_cast(N.unwrap()); + const auto params = Compress4DecodeParams{ + .kv_buffer = kv_buffer.data_ptr(), + .kv_input = kv_input.data_ptr(), + .kv_output = kv_output.data_ptr(), + .score_bias = ape.data_ptr(), + .plan_d = plan_d, + .batch_size = batch_size, + }; + const uint32_t num_blocks = div_ceil(batch_size * kNumSplit, kWarpsPerBlock); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_buffer, + const tvm::ffi::TensorView kv_input, + const tvm::ffi::TensorView kv_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView plan_c_, + const tvm::ffi::TensorView plan_w_) { + using namespace host; + + auto N = SymbolicSize{"num_q_tokens"}; + auto C = SymbolicSize{"num_c_plans"}; + auto W = SymbolicSize{"num_w_plans"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 4, Trait::kElementSize}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_buffer); + TensorMatcher({N, Trait::kElementSize}) // kv score input (ragged) + .with_dtype() + .with_device(device_) + .verify(kv_input); + TensorMatcher({C, kHeadDim}) // kv compressed output (compact) + .with_dtype() + .with_device(device_) + .verify(kv_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + const auto plan_c = compress::verify_plan_c(plan_c_, C, device_); + const auto plan_w = compress::verify_plan_w(plan_w_, W, device_); + const auto device = device_.unwrap(); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(C.unwrap()); + const auto num_w = static_cast(W.unwrap()); + const auto params = Compress4PrefillParams{ + .kv_buffer = kv_buffer.data_ptr(), + .kv_input = kv_input.data_ptr(), + .kv_output = kv_output.data_ptr(), + .score_bias = ape.data_ptr(), + 
.plan_c = plan_c, + .plan_w = plan_w, + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= num_w, "invalid prefill plan: num_q < num_w"); + if (const auto num_c_blocks = div_ceil(num_c * kNumSplit, kWarpsPerBlock)) { + LaunchKernel(num_c_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerBlock)) { + LaunchKernel(num_w_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c_plan.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c_plan.cuh new file mode 100644 index 0000000000..3e4aaaf5f0 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/c_plan.cuh @@ -0,0 +1,839 @@ +#include +#include +#include + +#include +#include + +#include + +#include +#include + +#include +#include + +namespace host::compress { + +constexpr auto kDLUInt8 = DLDataType{.code = kDLUInt, .bits = 8, .lanes = 1}; + +using PlanC = CompressPlan; +using PlanW = WritePlan; +using PlanD = DecodePlan; + +using RID_T = int64_t; +using R2T_T = int32_t; +using F2S_T = int64_t; +using IDX_T = int64_t; + +/// NOTE: for the internal use, we pack the ragged and batch id, since both not exceed 65536 +SGL_DEVICE __host__ PlanW pack_w(uint32_t ragged_id, uint32_t batch_id, int32_t seq_len) { + return {static_cast(ragged_id | batch_id << 16), seq_len}; +} + +/// NOTE: for the internal use, we pack the ragged and batch id, since both not exceed 65536 +SGL_DEVICE uint2 unpack_w(PlanW plan) { + return {static_cast(plan.ragged_id), static_cast(plan.ragged_id >> 16)}; +} + +struct Prefill0Params { + PlanC* plan_c; + PlanW* plan_w; + const IDX_T* seq_lens_ptr; // [batch_size] + const IDX_T* extend_lens_ptr; // [batch_size] + uint32_t batch_size; + uint32_t num_q_tokens; + int32_t compress_ratio; + int32_t swa_page_size; + int32_t mtp_pad; +}; + 
+struct Prefill1Params { + PlanC* plan_c; + PlanW* plan_w; + const RID_T* rid_ptr; // [batch_size] + const R2T_T* r2t_ptr; // [num_reqs, stride_r2t] + const F2S_T* f2s_ptr; // [num_swa_slots] + int64_t stride_r2t; + uint32_t num_c; + uint32_t num_w; + uint32_t num_c_padded; + uint32_t num_w_padded; + uint32_t num_work; + int32_t swa_page_size; + int32_t ring_size; + int32_t compress_ratio; +}; + +struct DecodeParams { + PlanD* plan_d; + const RID_T* rid_ptr; // [batch_size] + const R2T_T* r2t_ptr; // [num_reqs, stride_r2t] + const F2S_T* f2s_ptr; // [num_swa_slots] + const IDX_T* seq_ptr; // [batch_size] + int64_t stride_r2t; + uint32_t batch_size; + int32_t swa_page_size; + int32_t ring_size; + int32_t compress_ratio; +}; + +struct Prefill1ParamsLegacy { + PlanC* plan_c; + PlanW* plan_w; + const RID_T* rid_ptr; // [batch_size] + uint32_t num_c; + uint32_t num_w; + uint32_t num_c_padded; + uint32_t num_w_padded; + uint32_t num_work; + int32_t compress_ratio; +}; + +struct DecodeParamsLegacy { + PlanD* plan_d; + const RID_T* rid_ptr; // [batch_size] + const IDX_T* seq_ptr; // [batch_size] + uint32_t batch_size; + int32_t compress_ratio; +}; + +inline constexpr uint32_t kMaxPrefillBatchSize = 1024; + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(device::kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { +#ifndef USE_ROCM + uint32_t n = __shfl_up_sync(device::kFullMask, val, offset); +#else + uint32_t n = __shfl_up(val, offset, 32); +#endif + if (lane_id >= offset) val += n; + } + return val; +} + +/// Warp-wide max/min for integer types. `device::warp::reduce_max` routes through +/// `dtype_trait::max` which is only specialized for FP types. 
+SGL_DEVICE uint32_t warp_reduce_max_u32(uint32_t val) { +#pragma unroll + for (uint32_t mask = 16; mask > 0; mask >>= 1) { +#ifndef USE_ROCM + val = max(val, __shfl_xor_sync(device::kFullMask, val, mask, 32)); +#else + val = max(val, __shfl_xor(val, mask, 32)); +#endif + } + return val; +} + +SGL_DEVICE uint32_t warp_reduce_min_u32(uint32_t val) { +#pragma unroll + for (uint32_t mask = 16; mask > 0; mask >>= 1) { +#ifndef USE_ROCM + val = min(val, __shfl_xor_sync(device::kFullMask, val, mask, 32)); +#else + val = min(val, __shfl_xor(val, mask, 32)); +#endif + } + return val; +} + +__global__ __launch_bounds__(1024, 1) // + void plan_compress_prefill_kernel0(const Prefill0Params params) { + using namespace device; + const auto tx = threadIdx.x; + const auto block_size = kMaxPrefillBatchSize; + constexpr auto kNumWarps = kMaxPrefillBatchSize / kWarpThreads; + const auto cr = params.compress_ratio; + const auto sps = params.swa_page_size; + const bool is_overlap = (cr == 4); + const int32_t window_size = cr * (is_overlap ? 
2 : 1); + + alignas(128) __shared__ uint32_t counter_c; + alignas(128) __shared__ uint32_t counter_w; + __shared__ int32_t s_seq_len[kMaxPrefillBatchSize]; + __shared__ int32_t s_prefix_len[kMaxPrefillBatchSize]; + __shared__ uint32_t warp_max[kNumWarps]; + __shared__ uint32_t warp_min[kNumWarps]; + __shared__ uint32_t s_max_extend; + __shared__ uint32_t s_min_extend; + + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // === Stage A: load per-batch fields, init shared scratch === + int32_t seq_len = 0, extend_len = 0, prefix_len = 0; + if (tx < params.batch_size) { + seq_len = static_cast(params.seq_lens_ptr[tx]); + extend_len = static_cast(params.extend_lens_ptr[tx]); + prefix_len = seq_len - extend_len; + s_seq_len[tx] = seq_len; + s_prefix_len[tx] = prefix_len; + } + if (tx == 0) { + counter_c = 0; + counter_w = 0; + } + if (tx < kNumWarps) { + warp_max[tx] = 0; + warp_min[tx] = 0xFFFFFFFFu; + } + + // === Stage B: min/max(extend_len) for MTP-uniform detection === + // For min, treat threads outside `batch_size` as +inf so they don't pull the min down. + const uint32_t e_for_max = static_cast(extend_len); + const uint32_t e_for_min = (tx < params.batch_size) ? e_for_max : 0xFFFFFFFFu; + warp_max[warp_id] = warp_reduce_max_u32(e_for_max); + warp_min[warp_id] = warp_reduce_min_u32(e_for_min); + __syncthreads(); + if (warp_id == 0) { + s_max_extend = warp_reduce_max_u32(warp_max[lane_id]); + s_min_extend = warp_reduce_min_u32(warp_min[lane_id]); + } + __syncthreads(); + + const auto num_q = params.num_q_tokens; + // MTP-uniform: every batch shares the same small extend_len `E`, so we can decompose + // a global token id `k` into (batch_id, j) = (k / E, k % E) and skip the per-batch loop. + const bool is_mtp_extend = (s_min_extend == s_max_extend) && (s_max_extend > 0) && (s_max_extend <= 32); + + // === Stage C: emit valid plans, slot allocation via shared-mem atomicAdd === + if (is_mtp_extend) { + // Path 1: token-driven. 
Each global token id maps to exactly one (batch_id, j). + const uint32_t E = s_max_extend; + for (uint32_t k = tx; k < num_q; k += block_size) { + const uint32_t batch_id = k / E; + const uint32_t j = k % E; + const int32_t pl = s_prefix_len[batch_id]; + const int32_t sl = s_seq_len[batch_id]; + const int32_t position = pl + static_cast(j); + const uint32_t ragged_id = k; + + if ((position + 1) % cr == 0) { + const int32_t buffer_len = window_size - min(static_cast(j) + 1, window_size); + const uint32_t out_idx = atomicAdd(&counter_c, 1u); + params.plan_c[out_idx] = { + .seq_len = static_cast(position + 1), + .ragged_id = static_cast(ragged_id), + .buffer_len = static_cast(buffer_len), + .read_page_0 = -1, + .read_page_1 = static_cast(batch_id), + }; + } + + const int32_t last_c_pos = (sl / cr) * cr; + const int32_t first_w_pos = min(last_c_pos - (is_overlap ? cr : 0), sl - params.mtp_pad); + bool do_write = position >= first_w_pos; + if (!do_write && is_overlap) do_write = (position % sps) >= (sps - cr); + if (do_write) { + const uint32_t out_idx = atomicAdd(&counter_w, 1u); + params.plan_w[out_idx] = pack_w(ragged_id, batch_id, position + 1); + } + } + } else { + // Path 2: general prefill (long extend_len). Iterate batches in an outer loop; + // the whole block sweeps each batch's tokens in parallel. + uint32_t base_e = 0; + for (uint32_t batch_id = 0; batch_id < params.batch_size; ++batch_id) { + const int32_t pl = s_prefix_len[batch_id]; + const int32_t sl = s_seq_len[batch_id]; + const int32_t el = sl - pl; + const int32_t last_c_pos = (sl / cr) * cr; + const int32_t first_w_pos = min(last_c_pos - (is_overlap ? 
cr : 0), sl - params.mtp_pad); + for (int32_t j = static_cast(tx); j < el; j += static_cast(block_size)) { + const int32_t position = pl + j; + const uint32_t ragged_id = base_e + static_cast(j); + + if ((position + 1) % cr == 0) { + const int32_t buffer_len = window_size - min(j + 1, window_size); + const uint32_t out_idx = atomicAdd(&counter_c, 1u); + params.plan_c[out_idx] = { + .seq_len = static_cast(position + 1), + .ragged_id = static_cast(ragged_id), + .buffer_len = static_cast(buffer_len), + .read_page_0 = -1, + .read_page_1 = static_cast(batch_id), + }; + } + + bool do_write = position >= first_w_pos; + if (!do_write && is_overlap) do_write = (position % sps) >= (sps - cr); + if (do_write) { + const uint32_t out_idx = atomicAdd(&counter_w, 1u); + params.plan_w[out_idx] = pack_w(ragged_id, static_cast(batch_id), position + 1); + } + } + base_e += static_cast(el); + } + } + __syncthreads(); + + // === Stage D: pad [counter_c, num_q) / [counter_w, num_q) with invalid === + const auto total_c = counter_c; + const auto total_w = counter_w; + for (uint32_t k = total_c + tx; k < num_q; k += block_size) { + params.plan_c[k] = PlanC::invalid(); + } + for (uint32_t k = total_w + tx; k < num_q; k += block_size) { + params.plan_w[k] = PlanW::invalid(); + } +} + +/// NOTE: stage 1 +__global__ void plan_compress_prefill_kernel_1(const Prefill1Params params) { + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.num_work) return; + auto plan_c = idx < params.num_c ? params.plan_c[idx] : PlanC::invalid(); + auto plan_w = idx < params.num_w ? params.plan_w[idx] : PlanW::invalid(); + + const auto compute_loc = [&](int32_t swa_loc) { + const auto swa_page = swa_loc / params.swa_page_size; + const auto ring_offset = swa_loc % params.ring_size; + return swa_page * params.ring_size + ring_offset; + }; + + if (!plan_c.is_invalid()) { // 1. in bound. 2. 
not masked + if (plan_c.buffer_len > 0) { + const auto batch_id = plan_c.read_page_1; + const auto rid = params.rid_ptr[batch_id]; + const auto mapping = params.r2t_ptr + rid * params.stride_r2t; + // `seq_len` should be ratio-aligned here + const auto position_1 = static_cast(plan_c.seq_len - 1); + // only used for c4, harmless for c128 + const auto position_0 = max(position_1 - params.compress_ratio, 0); + const auto raw_loc_0 = mapping[position_0]; + const auto raw_loc_1 = mapping[position_1]; + const auto swa_loc_0 = params.f2s_ptr[raw_loc_0]; + const auto swa_loc_1 = params.f2s_ptr[raw_loc_1]; + plan_c.read_page_0 = compute_loc(swa_loc_0) / params.compress_ratio; + plan_c.read_page_1 = compute_loc(swa_loc_1) / params.compress_ratio; + params.plan_c[idx] = plan_c; + } + } else if (idx < params.num_c_padded) { + params.plan_c[idx] = PlanC::invalid(); + } + + if (!plan_w.is_invalid()) { // 1. in bound. 2. not masked + const auto [ragged_id, batch_id] = unpack_w(plan_w); + const auto rid = params.rid_ptr[batch_id]; + const auto mapping = params.r2t_ptr + rid * params.stride_r2t; + // `seq_len` (`write_loc`) may not be aligned here + const auto position = static_cast(plan_w.write_loc - 1); + const auto raw_loc = mapping[position]; + const auto swa_loc = params.f2s_ptr[raw_loc]; + plan_w.ragged_id = ragged_id; + plan_w.write_loc = compute_loc(swa_loc); + params.plan_w[idx] = plan_w; + } else if (idx < params.num_w_padded) { + params.plan_w[idx] = PlanW::invalid(); + } +} + +__global__ void plan_compress_decode_kernel(const DecodeParams params) { + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.batch_size) return; + const auto rid = params.rid_ptr[idx]; + const auto mapping = params.r2t_ptr + rid * params.stride_r2t; + const auto compute_loc = [&](int32_t swa_loc) { + const auto swa_page = swa_loc / params.swa_page_size; + const auto ring_offset = swa_loc % params.ring_size; + return swa_page * params.ring_size + ring_offset; + }; + const 
auto seq_len = static_cast(params.seq_ptr[idx]); + const auto position_1 = static_cast(seq_len - 1); + const auto position_0 = max(position_1 - params.compress_ratio, 0); + const auto raw_loc_0 = mapping[position_0]; + const auto raw_loc_1 = mapping[position_1]; + const auto swa_loc_0 = params.f2s_ptr[raw_loc_0]; + const auto swa_loc_1 = params.f2s_ptr[raw_loc_1]; + const auto write_loc = compute_loc(swa_loc_1); + const auto read_page_0 = compute_loc(swa_loc_0) / params.compress_ratio; + const auto read_page_1 = write_loc / params.compress_ratio; + params.plan_d[idx] = { + .seq_len = static_cast(seq_len), + .write_loc = write_loc, + .read_page_0 = read_page_0, + .read_page_1 = read_page_1, + }; +} + +__global__ void plan_compress_prefill_legacy_kernel(const Prefill1ParamsLegacy params) { + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.num_work) return; + auto plan_c = idx < params.num_c ? params.plan_c[idx] : PlanC::invalid(); + auto plan_w = idx < params.num_w ? 
params.plan_w[idx] : PlanW::invalid(); + + /// Per-request ring buffer slot translation: + /// - c4: page = rid * 2 + (position / 4) % 2; slot = page * 4 + position % 4 + /// - c128: page = rid; slot = rid * 128 + position % 128 + const auto legacy_compute_page = [&](int32_t rid, int32_t position) { + if (params.compress_ratio == 4) return rid * 2 + ((position / 4) & 1); + return rid; // c128 + }; + const auto legacy_compute_loc = [&](int32_t rid, int32_t position) { + const auto remainder = position % params.compress_ratio; + return legacy_compute_page(rid, position) * params.compress_ratio + remainder; + }; + + if (!plan_c.is_invalid()) { + const auto batch_id = plan_c.read_page_1; + const auto rid = static_cast(params.rid_ptr[batch_id]); + // `seq_len` is ratio-aligned for compress events + const auto position_1 = static_cast(plan_c.seq_len) - 1; + const auto position_0 = max(position_1 - params.compress_ratio, 0); + plan_c.read_page_0 = legacy_compute_page(rid, position_0); + plan_c.read_page_1 = legacy_compute_page(rid, position_1); + params.plan_c[idx] = plan_c; + } else if (idx < params.num_c_padded) { + params.plan_c[idx] = PlanC::invalid(); + } + + if (!plan_w.is_invalid()) { + const auto [ragged_id, batch_id] = unpack_w(plan_w); + const auto rid = static_cast(params.rid_ptr[batch_id]); + // `write_loc` carries (position + 1) at this stage; may not be ratio-aligned + const auto position = static_cast(plan_w.write_loc) - 1; + plan_w.ragged_id = ragged_id; + plan_w.write_loc = legacy_compute_loc(rid, position); + params.plan_w[idx] = plan_w; + } else if (idx < params.num_w_padded) { + params.plan_w[idx] = PlanW::invalid(); + } +} + +__global__ void plan_compress_decode_legacy_kernel(const DecodeParamsLegacy params) { + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.batch_size) return; + /// Per-request ring buffer slot translation: + /// - c4: page = rid * 2 + (position / 4) % 2; slot = page * 4 + position % 4 + /// - c128: page 
= rid; slot = rid * 128 + position % 128 + const auto legacy_compute_page = [&](int32_t rid, int32_t position) { + if (params.compress_ratio == 4) return rid * 2 + ((position / 4) & 1); + return rid; // c128 + }; + const auto legacy_compute_loc = [&](int32_t rid, int32_t position) { + const auto remainder = position % params.compress_ratio; + return legacy_compute_page(rid, position) * params.compress_ratio + remainder; + }; + const auto rid = static_cast(params.rid_ptr[idx]); + const auto seq_len = static_cast(params.seq_ptr[idx]); + const auto position_1 = seq_len - 1; + const auto position_0 = max(position_1 - params.compress_ratio, 0); + const auto write_loc = legacy_compute_loc(rid, position_1); + const auto read_page_0 = legacy_compute_page(rid, position_0); + const auto read_page_1 = legacy_compute_page(rid, position_1); + params.plan_d[idx] = { + .seq_len = static_cast(seq_len), + .write_loc = write_loc, + .read_page_0 = read_page_0, + .read_page_1 = read_page_1, + }; +} + +using PrefillPlan = tvm::ffi::Tuple; + +/** + * \brief Build c4/c128 prefill plan tensors. CPU-resident. 
+ * Inputs (all CPU-resident): + * @param req_pool_indices `[batch_size]` int64_t + * @param req_to_token `[num_reqs, max_tokens_per_req]` int64_t + * @param full_to_swa `[num_swa_slots]` int64_t + * @param seq_lens `[batch_size]` int64 + * @param extend_lens `[batch_size]` int64 + * @param compress_plan `[num_q_tokens, 16]` uint8 (output) + * @param write_plan `[num_q_tokens, 8]` uint8 (output) + * @param compress_ratio 4 for c4, 128 for c128 + * @param use_cuda_graph Whether the plans will be used with cuda graph (affects padding) + * @return (compress plan tensor, write plan tensor) + */ +inline PrefillPlan plan_compress_prefill( + const tvm::ffi::TensorView req_pool_indices, // GPU + const tvm::ffi::TensorView req_to_token, // GPU + const tvm::ffi::TensorView full_to_swa, // GPU + const tvm::ffi::TensorView seq_lens, // CPU/GPU + const tvm::ffi::TensorView extend_lens, // CPU/GPU + const tvm::ffi::TensorView pin_buffer, // CPU + const uint32_t num_q_tokens, + const int32_t compress_ratio, + const int32_t swa_page_size, + const int32_t ring_size, + const bool use_cuda_graph) { + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto cpu_or_gpu = SymbolicDevice{}; + auto device_ = SymbolicDevice{}; + cpu_or_gpu.set_options(); + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device_) + .verify(req_to_token); + TensorMatcher({-1}) // + .with_dtype() + .with_device(device_) + .verify(full_to_swa); + TensorMatcher({B}) // + .with_dtype() + .with_device(cpu_or_gpu) + .verify(seq_lens) + .verify(extend_lens); + TensorMatcher({-1}) // + .with_dtype() + .with_device() + .verify(pin_buffer); + + const bool is_overlap = (compress_ratio == 4); + const int32_t window_size = compress_ratio * (is_overlap ? 
2 : 1); + + const auto seq_ptr = static_cast(seq_lens.data_ptr()); + const auto ext_ptr = static_cast(extend_lens.data_ptr()); + const auto rid_ptr = static_cast(req_pool_indices.data_ptr()); + const auto r2t_ptr = static_cast(req_to_token.data_ptr()); + const auto f2s_ptr = static_cast(full_to_swa.data_ptr()); + + const auto batch_size = static_cast(B.unwrap()); + constexpr auto kMaxTokens = static_cast(std::numeric_limits::max()); + RuntimeCheck(compress_ratio == 4 || compress_ratio == 128); + RuntimeCheck(batch_size <= num_q_tokens && num_q_tokens <= kMaxTokens); + // `swa_page_size` >= `ring_size` >= `compress_ratio` + RuntimeCheck(swa_page_size % ring_size == 0 && ring_size % compress_ratio == 0); + + const auto device = device_.unwrap(); + const auto stream = LaunchKernel::resolve_device(device); + + constexpr int32_t kMaxMTPDraftTokens = 4; + const auto mtp_pad = std::min(ring_size - compress_ratio, kMaxMTPDraftTokens); + + if (cpu_or_gpu.unwrap().device_type == kDLGPU) { + // GPU input path: kernel0 builds the (CPU-loop-equivalent) plan metadata directly + // on device, padding to num_q_tokens with invalid; kernel_1 then finalizes the + // SWA-translated read/write locations. Used for MTP / cuda-graph capture where + // a host sync would be expensive. 
+ RuntimeCheck(batch_size <= kMaxPrefillBatchSize, "GPU plan only support batch size up to ", kMaxPrefillBatchSize); + auto C = ffi::empty({num_q_tokens, sizeof(PlanC)}, kDLUInt8, device); + auto W = ffi::empty({num_q_tokens, sizeof(PlanW)}, kDLUInt8, device); + const auto params0 = Prefill0Params{ + .plan_c = static_cast(C.data_ptr()), + .plan_w = static_cast(W.data_ptr()), + .seq_lens_ptr = seq_ptr, + .extend_lens_ptr = ext_ptr, + .batch_size = batch_size, + .num_q_tokens = num_q_tokens, + .compress_ratio = compress_ratio, + .swa_page_size = swa_page_size, + .mtp_pad = mtp_pad, + }; + LaunchKernel(1, kMaxPrefillBatchSize, device)(plan_compress_prefill_kernel0, params0); + // kernel_1 sees the already-padded buffers, so num_c == num_w == num_padded == num_q_tokens. + const auto params1 = Prefill1Params{ + .plan_c = static_cast(C.data_ptr()), + .plan_w = static_cast(W.data_ptr()), + .rid_ptr = rid_ptr, + .r2t_ptr = r2t_ptr, + .f2s_ptr = f2s_ptr, + .stride_r2t = req_to_token.stride(0), + .num_c = num_q_tokens, + .num_w = num_q_tokens, + .num_c_padded = num_q_tokens, + .num_w_padded = num_q_tokens, + .num_work = num_q_tokens, + .swa_page_size = swa_page_size, + .ring_size = ring_size, + .compress_ratio = compress_ratio, + }; + const auto block_size_1 = 256; + const auto num_blocks_1 = div_ceil(params1.num_work, block_size_1); + LaunchKernel(num_blocks_1, block_size_1, device)(plan_compress_prefill_kernel_1, params1); + return PrefillPlan{std::move(C), std::move(W)}; + } + + // CPU input path: only here do we need the pinned scratch buffer. 
+ const auto pin_buffer_bytes = static_cast(pin_buffer.numel()) * sizeof(uint8_t); + RuntimeCheck(pin_buffer_bytes >= num_q_tokens * (sizeof(PlanC) + sizeof(PlanW))); + const auto plan_c_ptr = reinterpret_cast(pin_buffer.data_ptr()); + const auto plan_w_ptr = reinterpret_cast(plan_c_ptr + num_q_tokens); + + uint32_t counter = 0; + uint32_t counter_c = 0; + uint32_t counter_w = 0; + + const auto should_compress = [=](int32_t position) { return (position + 1) % compress_ratio == 0; }; + for (const auto i : irange(batch_size)) { + const int32_t seq_len = seq_ptr[i]; + const int32_t extend_len = ext_ptr[i]; + const int32_t prefix_len = seq_len - extend_len; + const int32_t last_c_pos = seq_len / compress_ratio * compress_ratio; + const int32_t first_w_pos = last_c_pos - (is_overlap ? compress_ratio : 0); + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + const auto should_write = [=](int32_t position) { + if (position >= first_w_pos) return true; + return is_overlap && position % swa_page_size >= (swa_page_size - compress_ratio); + }; + for (const auto j : irange(extend_len)) { + const int32_t position = prefix_len + j; + const int32_t ragged_id = counter + j; + if (should_compress(position)) { + const auto buffer_len = window_size - std::min(j + 1, window_size); + plan_c_ptr[counter_c++] = { + .seq_len = static_cast(position + 1), + .ragged_id = static_cast(ragged_id), + .buffer_len = static_cast(buffer_len), + // to be filled by kernel + .read_page_0 = -1, + .read_page_1 = static_cast(i), + }; + } + if (should_write(position)) { + plan_w_ptr[counter_w++] = pack_w(ragged_id, i, position + 1); + } + } + counter += extend_len; + } + RuntimeCheck(counter == num_q_tokens); + + const auto copy_to_device = [stream](void* cuda_ptr, auto* host_ptr, size_t count) { + const auto size_bytes = count * sizeof(*host_ptr); + RuntimeDeviceCheck(cudaMemcpyAsync(cuda_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice, stream)); + }; + const auto num_c_padded = use_cuda_graph ? 
num_q_tokens : counter_c; + const auto num_w_padded = use_cuda_graph ? num_q_tokens : counter_w; + auto C = ffi::empty({num_c_padded, sizeof(PlanC)}, kDLUInt8, device); + auto W = ffi::empty({num_w_padded, sizeof(PlanW)}, kDLUInt8, device); + copy_to_device(C.data_ptr(), plan_c_ptr, counter_c); + copy_to_device(W.data_ptr(), plan_w_ptr, counter_w); + const auto params = Prefill1Params{ + .plan_c = static_cast(C.data_ptr()), + .plan_w = static_cast(W.data_ptr()), + .rid_ptr = rid_ptr, + .r2t_ptr = r2t_ptr, + .f2s_ptr = f2s_ptr, + .stride_r2t = req_to_token.size(1), + .num_c = counter_c, + .num_w = counter_w, + .num_c_padded = num_c_padded, + .num_w_padded = num_w_padded, + .num_work = std::max(num_c_padded, num_w_padded), + .swa_page_size = swa_page_size, + .ring_size = ring_size, + .compress_ratio = compress_ratio, + }; + const auto block_size = 256; + const auto num_blocks = div_ceil(params.num_work, block_size); + LaunchKernel(num_blocks, block_size, device)(plan_compress_prefill_kernel_1, params); + return PrefillPlan{std::move(C), std::move(W)}; +} + +inline tvm::ffi::Tensor plan_compress_decode( + const tvm::ffi::TensorView req_pool_indices, // GPU + const tvm::ffi::TensorView req_to_token, // GPU + const tvm::ffi::TensorView full_to_swa, // GPU + const tvm::ffi::TensorView seq_lens, // CPU/GPU + const int32_t compress_ratio, + const int32_t swa_page_size, + const int32_t ring_size) { + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device_) + .verify(req_to_token); + TensorMatcher({-1}) // + .with_dtype() + .with_device(device_) + .verify(full_to_swa); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto device = device_.unwrap(); + auto D = 
ffi::empty({batch_size, sizeof(PlanD)}, kDLUInt8, device); + const auto params = DecodeParams{ + .plan_d = static_cast(D.data_ptr()), + .rid_ptr = static_cast(req_pool_indices.data_ptr()), + .r2t_ptr = static_cast(req_to_token.data_ptr()), + .f2s_ptr = static_cast(full_to_swa.data_ptr()), + .seq_ptr = static_cast(seq_lens.data_ptr()), + .stride_r2t = req_to_token.size(1), + .batch_size = batch_size, + .swa_page_size = swa_page_size, + .ring_size = ring_size, + .compress_ratio = compress_ratio, + }; + const auto block_size = 256; + const auto num_blocks = div_ceil(batch_size, block_size); + LaunchKernel(num_blocks, block_size, device)(plan_compress_decode_kernel, params); + return D; +} + +/** + * \brief Build c4/c128 prefill plan tensors for the legacy non-paged ring + * buffer. Uses only `req_pool_indices` to derive ring slots: + * - c4 (overlap): each request occupies 2 contiguous pages (8 token slots) + * - c128: each request occupies 1 page (128 token slots) + * + * Inputs: + * @param req_pool_indices `[batch_size]` int64 (GPU) + * @param seq_lens `[batch_size]` int64 (CPU) + * @param extend_lens `[batch_size]` int64 (CPU) + * @param pin_buffer pinned scratch (CPU uint8) + * @return (compress plan tensor, write plan tensor) + */ +inline PrefillPlan plan_compress_prefill_legacy( + const tvm::ffi::TensorView req_pool_indices, // GPU + const tvm::ffi::TensorView seq_lens, // CPU + const tvm::ffi::TensorView extend_lens, // CPU + const tvm::ffi::TensorView pin_buffer, // CPU + const uint32_t num_q_tokens, + const int32_t compress_ratio, + const bool use_cuda_graph) { + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({B}) // + .with_dtype() + .with_device() + .verify(seq_lens) + .verify(extend_lens); + TensorMatcher({-1}) // + .with_dtype() + .with_device() + .verify(pin_buffer); + + const auto 
pin_buffer_bytes = static_cast(pin_buffer.numel()) * sizeof(uint8_t); + RuntimeCheck(pin_buffer_bytes >= num_q_tokens * (sizeof(PlanC) + sizeof(PlanW))); + const auto plan_c_ptr = reinterpret_cast(pin_buffer.data_ptr()); + const auto plan_w_ptr = reinterpret_cast(plan_c_ptr + num_q_tokens); + + const bool is_overlap = (compress_ratio == 4); + const auto seq_ptr = static_cast(seq_lens.data_ptr()); + const auto ext_ptr = static_cast(extend_lens.data_ptr()); + const auto rid_ptr = static_cast(req_pool_indices.data_ptr()); + + const auto window_size = compress_ratio * (is_overlap ? 2 : 1); + const auto batch_size = static_cast(B.unwrap()); + constexpr auto kMaxTokens = static_cast(std::numeric_limits::max()); + RuntimeCheck(compress_ratio == 4 || compress_ratio == 128); + RuntimeCheck(batch_size <= num_q_tokens && num_q_tokens <= kMaxTokens); + + uint32_t counter = 0; + uint32_t counter_c = 0; + uint32_t counter_w = 0; + const auto should_compress = [=](int32_t position) { return (position + 1) % compress_ratio == 0; }; + for (const auto i : irange(batch_size)) { + const int32_t seq_len = seq_ptr[i]; + const int32_t extend_len = ext_ptr[i]; + const int32_t prefix_len = seq_len - extend_len; + const int32_t last_c_pos = seq_len / compress_ratio * compress_ratio; + const int32_t first_w_pos = last_c_pos - (is_overlap ? 
compress_ratio : 0); + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + const auto should_write = [=](int32_t position) { return position >= first_w_pos; }; + for (const auto j : irange(extend_len)) { + const int32_t position = prefix_len + j; + const int32_t ragged_id = counter + j; + if (should_compress(position)) { + const auto buffer_len = window_size - std::min(j + 1, window_size); + plan_c_ptr[counter_c++] = { + .seq_len = static_cast(position + 1), + .ragged_id = static_cast(ragged_id), + .buffer_len = static_cast(buffer_len), + // to be filled by kernel + .read_page_0 = -1, + .read_page_1 = static_cast(i), + }; + } + if (should_write(position)) { + plan_w_ptr[counter_w++] = pack_w(ragged_id, i, position + 1); + } + } + counter += extend_len; + } + RuntimeCheck(counter == num_q_tokens); + + const auto device = device_.unwrap(); + const auto stream = LaunchKernel::resolve_device(device); + const auto copy_to_device = [stream](void* cuda_ptr, auto* host_ptr, size_t count) { + const auto size_bytes = count * sizeof(*host_ptr); + RuntimeDeviceCheck(cudaMemcpyAsync(cuda_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice, stream)); + }; + const auto num_c_padded = use_cuda_graph ? num_q_tokens : counter_c; + const auto num_w_padded = use_cuda_graph ? 
num_q_tokens : counter_w; + auto C = ffi::empty({num_c_padded, sizeof(PlanC)}, kDLUInt8, device); + auto W = ffi::empty({num_w_padded, sizeof(PlanW)}, kDLUInt8, device); + copy_to_device(C.data_ptr(), plan_c_ptr, counter_c); + copy_to_device(W.data_ptr(), plan_w_ptr, counter_w); + const auto params = Prefill1ParamsLegacy{ + .plan_c = static_cast(C.data_ptr()), + .plan_w = static_cast(W.data_ptr()), + .rid_ptr = rid_ptr, + .num_c = counter_c, + .num_w = counter_w, + .num_c_padded = num_c_padded, + .num_w_padded = num_w_padded, + .num_work = std::max(num_c_padded, num_w_padded), + .compress_ratio = compress_ratio, + }; + const auto block_size = 256; + const auto num_blocks = div_ceil(params.num_work, block_size); + if (num_blocks > 0) { + LaunchKernel(num_blocks, block_size, device)(plan_compress_prefill_legacy_kernel, params); + } + return PrefillPlan{std::move(C), std::move(W)}; +} + +inline tvm::ffi::Tensor plan_compress_decode_legacy( + const tvm::ffi::TensorView req_pool_indices, // GPU + const tvm::ffi::TensorView seq_lens, // GPU + const int32_t compress_ratio) { + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(req_pool_indices); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(seq_lens); + RuntimeCheck(compress_ratio == 4 || compress_ratio == 128); + + const auto batch_size = static_cast(B.unwrap()); + const auto device = device_.unwrap(); + auto D = ffi::empty({batch_size, sizeof(PlanD)}, kDLUInt8, device); + const auto params = DecodeParamsLegacy{ + .plan_d = static_cast(D.data_ptr()), + .rid_ptr = static_cast(req_pool_indices.data_ptr()), + .seq_ptr = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + .compress_ratio = compress_ratio, + }; + const auto block_size = 256; + const auto num_blocks = div_ceil(batch_size, block_size); + LaunchKernel(num_blocks, block_size, 
device)(plan_compress_decode_legacy_kernel, params); + return D; +} + +} // namespace host::compress + +using namespace host::compress; // expose binding diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/common.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/common.cuh new file mode 100644 index 0000000000..46acaa9c46 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/common.cuh @@ -0,0 +1,208 @@ +#include +#include + +#include + +#include + +namespace host::compress { + +using PlanResult = tvm::ffi::Tuple; + +struct CompressParams { + PrefillPlan* __restrict__ compress_plan; + PrefillPlan* __restrict__ write_plan; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ extend_lens; + uint32_t batch_size; + uint32_t num_tokens; + uint32_t compress_ratio; + bool is_overlap; +}; + +inline constexpr uint32_t kBlockSize = 1024; + +#define PLAN_KERNEL __global__ __launch_bounds__(kBlockSize, 1) inline + +PLAN_KERNEL void plan_prefill_cuda(const __grid_constant__ CompressParams params) { + const auto &[ + compress_plan, write_plan, seq_lens, extend_lens, // pointers + batch_size, num_tokens, compress_ratio, is_overlap // values + ] = params; + + __shared__ uint32_t compress_counter; + __shared__ uint32_t write_counter; + + uint32_t batch_id = 0; + uint32_t counter = 0; + uint32_t extend_len = extend_lens[0]; + + const auto tid = threadIdx.x; + if (tid == 0) { + compress_counter = 0; + write_counter = 0; + } + __syncthreads(); + + for (uint32_t i = tid; i < num_tokens; i += blockDim.x) { + const uint32_t ragged_id = i; + uint32_t j = ragged_id - counter; + while (j >= extend_len) { + j -= extend_len; + batch_id += 1; + if (batch_id >= batch_size) [[unlikely]] + break; + counter += extend_len; + extend_len = extend_lens[batch_id]; + } + if (batch_id >= batch_size) [[unlikely]] + break; + const uint32_t seq_len = seq_lens[batch_id]; + const uint32_t extend_len = extend_lens[batch_id]; + const uint32_t prefix_len = seq_len - 
extend_len; + const uint32_t ratio = compress_ratio * (1 + is_overlap); + const uint32_t window_len = j + 1 < ratio ? ratio - (j + 1) : 0; + const uint32_t position = prefix_len + j; + const auto plan = PrefillPlan{ + .ragged_id = ragged_id, + .batch_id = batch_id, + .position = position, + .window_len = window_len, + }; + const uint32_t start_write_pos = [seq_len, compress_ratio, is_overlap] { + const uint32_t pos = seq_len / compress_ratio * compress_ratio; + if (!is_overlap) return pos; + return pos >= compress_ratio ? pos - compress_ratio : 0; + }(); + if ((position + 1) % compress_ratio == 0) { + const auto write_pos = atomicAdd(&compress_counter, 1); + compress_plan[write_pos] = plan; + } + if (position >= start_write_pos) { + const auto write_pos = atomicAdd(&write_counter, 1); + write_plan[write_pos] = plan; + } + } + __syncthreads(); + constexpr auto kInvalid = static_cast(-1); + const auto kInvalidPlan = PrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + const auto compress_count = compress_counter; + const auto write_count = write_counter; + for (uint32_t i = compress_count + tid; i < num_tokens; i += blockDim.x) { + compress_plan[i] = kInvalidPlan; + } + for (uint32_t i = write_count + tid; i < num_tokens; i += blockDim.x) { + write_plan[i] = kInvalidPlan; + } +} + +inline PlanResult plan_prefill_host(const CompressParams& params, const bool use_cuda_graph) { + const auto &[ + compress_ptr, write_ptr, seq_lens_ptr, extend_lens_ptr, // pointers + batch_size, num_tokens, compress_ratio, is_overlap // values + ] = params; + + uint32_t counter = 0; + uint32_t compress_counter = 0; + uint32_t write_counter = 0; + const auto ratio = compress_ratio * (1 + is_overlap); + for (const auto i : irange(batch_size)) { + const uint32_t seq_len = seq_lens_ptr[i]; + const uint32_t extend_len = extend_lens_ptr[i]; + const uint32_t prefix_len = seq_len - extend_len; + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + /// NOTE: `start_write_pos` must be a 
multiple of `compress_ratio` + const uint32_t start_write_pos = [seq_len, compress_ratio, is_overlap] { + const uint32_t pos = seq_len / compress_ratio * compress_ratio; + if (!is_overlap) return pos; + /// NOTE: to avoid unsigned integer underflow, don't use `pos - compress_ratio` + return pos >= compress_ratio ? pos - compress_ratio : 0; + }(); + /// NOTE: `position` is within [prefix_len, seq_len) + for (const auto j : irange(extend_len)) { + const uint32_t position = prefix_len + j; + const auto plan = PrefillPlan{ + .ragged_id = counter + j, + .batch_id = i, + .position = position, + .window_len = ratio - std::min(j + 1, ratio), + }; + RuntimeCheck(plan.is_valid(compress_ratio, is_overlap), "Internal error!"); + if ((position + 1) % compress_ratio == 0) { + compress_ptr[compress_counter++] = plan; + } + if (position >= start_write_pos) { + write_ptr[write_counter++] = plan; + } + } + counter += extend_len; + } + RuntimeCheck(counter == num_tokens, "input size ", counter, " != num_q_tokens ", num_tokens); + if (!use_cuda_graph) return PlanResult{compress_counter, write_counter}; + constexpr auto kInvalid = static_cast(-1); + constexpr auto kInvalidPlan = PrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + for (const auto i : irange(compress_counter, num_tokens)) { + compress_ptr[i] = kInvalidPlan; + } + for (const auto i : irange(write_counter, num_tokens)) { + write_ptr[i] = kInvalidPlan; + } + return PlanResult{num_tokens, num_tokens}; +} + +inline PlanResult plan_prefill( + const tvm::ffi::TensorView extend_lens, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const uint32_t compress_ratio, + const bool is_overlap, // for overlap transform, we have to keep 1 more extra window + const bool use_cuda_graph) { + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_tokens"}; + auto device = SymbolicDevice{}; + const bool is_cuda = [&] { + if 
(extend_lens.device().device_type == kDLCUDA) { + device.set_options(); + return true; + } else { + device.set_options(); + return false; + } + }(); + TensorMatcher({N}) // extend_lens and seq_lens + .with_dtype() + .with_device(device) + .verify(extend_lens) + .verify(seq_lens); + TensorMatcher({M, kPrefillPlanDim}) // compress_plan and write_plan + .with_dtype() + .with_device(device) + .verify(compress_plan) + .verify(write_plan); + + const auto params = CompressParams{ + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extend_lens = static_cast(extend_lens.data_ptr()), + .batch_size = static_cast(N.unwrap()), + .num_tokens = static_cast(M.unwrap()), + .compress_ratio = compress_ratio, + .is_overlap = is_overlap, + }; + + if (!is_cuda) return plan_prefill_host(params, use_cuda_graph); + /// NOTE: cuda kernel plan is naturally compatible with cuda graph + LaunchKernel(1, kBlockSize, device.unwrap())(plan_prefill_cuda, params); + return PlanResult{params.num_tokens, params.num_tokens}; +} + +} // namespace host::compress + +namespace { + +[[maybe_unused]] +constexpr auto& plan_compress_prefill = host::compress::plan_prefill; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope.cuh new file mode 100644 index 0000000000..d3953578b9 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope.cuh @@ -0,0 +1,254 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace { + +using Plan = device::compress::PrefillPlan; + +/// \brief common block size for memory-bound kernel +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct FusedNormRopeParams { + void* __restrict__ input; + const void* 
__restrict__ weight; + float eps; + uint32_t num_works; + const void* __restrict__ handle; + const float* __restrict__ freqs_cis; + uint32_t compress_ratio; +}; + +enum class ForwardMode { + CompressExtend = 0, + CompressDecode = 1, + DefaultForward = 2, +}; + +template +__global__ void fused_norm_rope(const __grid_constant__ FusedNormRopeParams params) { + using namespace device; + using enum ForwardMode; + + constexpr int64_t kMaxVecSize = 16 / sizeof(DType); + constexpr int64_t kVecSize = std::min(kMaxVecSize, kHeadDim / kWarpThreads); + constexpr int64_t kLocalSize = kHeadDim / (kWarpThreads * kVecSize); + constexpr int64_t kRopeVecSize = kRopeDim / (kWarpThreads * 2); + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + static_assert(kHeadDim % (kWarpThreads * kVecSize) == 0); + static_assert(kLocalSize * kVecSize * kWarpThreads == kHeadDim); + static_assert(kRopeDim % (kWarpThreads * 2) == 0); + static_assert(kRopeDim % (kVecSize * kLocalSize) == 0); + static_assert(kRopeSize <= kWarpThreads); + static_assert(kRopeVecSize == 1, "only support rope dim = 64"); + + const auto& [ + _input, _weight, eps, num_works, // norm + handle, freqs_cis, compress_ratio // rope + ] = params; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kNumWarps + warp_id; + + if (work_id >= num_works) return; + + DType* input; + int32_t position; + if constexpr (kMode == CompressExtend) { + const auto plan = static_cast(handle)[work_id]; + input = static_cast(_input) + plan.ragged_id * kHeadDim; + position = plan.position + 1 - compress_ratio; + if (plan.ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + } else if constexpr (kMode == CompressDecode) { + input = static_cast(_input) + work_id * kHeadDim; + const auto seq_len = static_cast(handle)[work_id]; + if (seq_len % compress_ratio != 0) return; + position = seq_len - compress_ratio; + } else if constexpr (kMode == DefaultForward) { + input = 
static_cast(_input) + work_id * kHeadDim; + position = static_cast(handle)[work_id]; + } else { + static_assert(host::dependent_false_v, "Unsupported Mode"); + } + + using Storage = AlignedVector; + __shared__ Storage s_rope_input[kNumWarps][kRopeSize]; + + // prefetch freq + const auto mem_freq = tile::Memory::warp(); + const auto freq = mem_freq.load(freqs_cis + position * kRopeDim); + + PDLWaitPrimary(); + + // part 1: norm + { + const auto gmem = tile::Memory::warp(); + Storage input_vec[kLocalSize]; + Storage weight_vec[kLocalSize]; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + input_vec[i] = gmem.load(input, i); + } + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + weight_vec[i] = gmem.load(_weight, i); + } + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto fp32_input = cast(input_vec[i][j]); + sum_of_squares += fp32_input * fp32_input; + } + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + eps); + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto fp32_input = cast(input_vec[i][j]); + const auto fp32_weight = cast(weight_vec[i][j]); + input_vec[i][j] = cast(fp32_input * norm_factor * fp32_weight); + } + } + + const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + if (i == kLocalSize - 1 && is_rope_lane) { + const auto rope_id = lane_id - (kWarpThreads - kRopeSize); + s_rope_input[warp_id][rope_id] = input_vec[i]; + } else { + gmem.store(input, input_vec[i], i); + } + } + + __syncwarp(); + } + + // part 2: rope + { + // mem elem = DType x 2 + using DTypex2_t = packed_t; + const auto mem_elem = tile::Memory::warp(); + const auto elem = mem_elem.load(s_rope_input[warp_id]); + const auto [x_real, x_imag] = cast(elem); + 
const auto [freq_real, freq_imag] = freq; + const fp32x2_t output = { + x_real * freq_real - x_imag * freq_imag, + x_real * freq_imag + x_imag * freq_real, + }; + mem_elem.store(input + (kHeadDim - kRopeDim), cast(output)); + } + + PDLTriggerSecondary(); +} + +template +struct FusedNormRopeKernel { + template + static constexpr auto fused_kernel = fused_norm_rope; + + static void forward( + const tvm::ffi::TensorView input, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView handle, + const tvm::ffi::TensorView freqs_cis, + int32_t _mode, + float eps, + uint32_t compress_ratio) { + using namespace host; + using enum ForwardMode; + + const auto mode = static_cast(_mode); + + auto B = SymbolicSize{"num_q_tokens"}; + auto N = SymbolicSize{"num_compress_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, kHeadDim}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({kHeadDim}) // weight + .with_dtype() + .with_device(device_) + .verify(weight); + TensorMatcher({-1, kRopeDim}) // freqs_cis + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + switch (mode) { + case CompressExtend: + TensorMatcher({N, compress::kPrefillPlanDim}) // plan + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio > 0); + break; + case CompressDecode: + TensorMatcher({N}) // seq_len + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio > 0); + break; + case DefaultForward: + TensorMatcher({N}) // position + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio == 0); + break; + default: + Panic("unsupported forward mode: ", static_cast(mode)); + } + + // launch kernel + const auto num_compress_tokens = static_cast(N.unwrap()); + if (num_compress_tokens == 0) return; + const auto params = FusedNormRopeParams{ + .input = input.data_ptr(), + .weight = weight.data_ptr(), + .eps = eps, + .num_works = 
num_compress_tokens, + .handle = handle.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .compress_ratio = compress_ratio, + }; + const auto num_blocks = div_ceil(num_compress_tokens, kNumWarps); + using KernelType = std::decay_t)>; + static constexpr KernelType kernel_table[3] = { + [static_cast(CompressExtend)] = fused_kernel, + [static_cast(CompressDecode)] = fused_kernel, + [static_cast(DefaultForward)] = fused_kernel, + }; + const auto kernel = kernel_table[static_cast(mode)]; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope_v2.cuh new file mode 100644 index 0000000000..a9cac17544 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/fused_norm_rope_v2.cuh @@ -0,0 +1,643 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +namespace { + +using PlanC = device::compress::CompressPlan; +using PlanD = device::compress::DecodePlan; +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::inv_scale_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +SGL_DEVICE uint8_t quant_fp4_e2m1(float x) { + const float ax = fminf(fabsf(x), 6.0f); + uint8_t idx = 0; + idx += ax > 0.25f; + idx += ax > 0.75f; + idx += ax > 1.25f; + idx += ax > 1.75f; + idx += ax > 2.5f; + idx += ax > 3.5f; + idx += ax > 5.0f; + if (x < 0.0f && idx != 0) idx |= 0x8; + return idx; +} + +constexpr uint32_t kBlockSize = 256; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct FusedNormRopeStoreParams { + void* __restrict__ input; + const void* __restrict__ handle; // plan decode / compress + const void* __restrict__ weight; + const float* __restrict__ freqs_cis; + const int32_t* __restrict__ out_loc; + uint8_t* __restrict__ kvcache; + float eps; + uint32_t 
compress_ratio; + uint32_t num_tokens; +}; + +enum class ForwardMode : bool { + CompressExtend = 0, + CompressDecode = 1, +}; + +#define INDEXER_KERNEL __global__ __launch_bounds__(kBlockSize, 8) +#define FLASHMLA_KERNEL __global__ __launch_bounds__(kBlockSize, 8) + +// ---------------------------------------------------------------------------- +// Indexer variant: kHeadDim = 128, 1 token per *warp* (8 tokens per block). +// Each warp's 32 lanes cover the full 128-elem head_dim (kVecSize = 4 each). +// Cache layout: 132 bytes/token (128 fp8 nope + 4 fp32 scale). +// ---------------------------------------------------------------------------- +template +INDEXER_KERNEL void fused_norm_rope_indexer(const __grid_constant__ FusedNormRopeStoreParams params) { + using namespace device; + using enum ForwardMode; + + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kVecSize = 4; + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + constexpr int64_t kPageBytes = 132ll << kPageBits; + static_assert(kHeadDim == kWarpThreads * kVecSize); + static_assert(kRopeDim == kWarpThreads * 2); + static_assert(kRopeSize <= kWarpThreads); + using Storage = AlignedVector; + using Float4 = AlignedVector; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kNumWarps + warp_id; + // Lanes whose 4-elem pack lies in the rope tail (= last `kRopeSize` packs). 
+ const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; + + if (work_id >= params.num_tokens) return; + + const auto input = static_cast(params.input) + work_id * kHeadDim; + int32_t position; + int32_t out_loc; + if constexpr (kMode == CompressExtend) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.is_invalid()) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[plan.ragged_id]; + } else if constexpr (kMode == CompressDecode) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.seq_len % params.compress_ratio != 0) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[work_id]; + } else { + static_assert(host::dependent_false_v, "Unsupported Mode"); + } + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + PDLWaitPrimary(); + Float4 data, freq; + + // part 1: norm + { + Storage input_vec, weight_vec; + input_vec.load(input, lane_id); + weight_vec.load(params.weight, lane_id); + if (is_rope_lane) freq.load(freqs_cis, lane_id - (kWarpThreads - kRopeSize)); + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + sum_of_squares += fp32_input * fp32_input; + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + const auto fp32_weight = cast(weight_vec[i]); + data[i] = fp32_input * norm_factor * fp32_weight; + } + } + + // part 2: rope (rope-lane only, 4 elems per lane = 2 (real, imag) pairs) + if (is_rope_lane) { + const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto y_real = data[2]; + const auto y_imag = data[3]; + const auto freq_x_real = freq[0]; + const auto freq_x_imag = freq[1]; + const auto freq_y_real = freq[2]; + const auto freq_y_imag = freq[3]; + 
data[0] = x_real * freq_x_real - x_imag * freq_x_imag; + data[1] = x_real * freq_x_imag + x_imag * freq_x_real; + data[2] = y_real * freq_y_real - y_imag * freq_y_imag; + data[3] = y_real * freq_y_imag + y_imag * freq_y_real; + } + + // part 3: hadamard transform + { + // Stage 1: butterfly (data[0], data[1]) and (data[2], data[3]). + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a1; + data[1] = a0 - a1; + data[2] = a2 + a3; + data[3] = a2 - a3; + } + // Stage 2: butterfly (data[0], data[2]) and (data[1], data[3]). + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a2; + data[1] = a1 + a3; + data[2] = a0 - a2; + data[3] = a1 - a3; + } + // Stages 3..7: cross-lane butterflies. Lower-lane (mask bit clear) keeps + // the sum, upper-lane (mask bit set) keeps the difference. shfl_xor is + // unsynchronized across early-returned lanes, but invalid-plan returns + // happen above for *all* lanes of a warp (work_id is warp-uniform), so + // the warp is intact here. +#pragma unroll + for (uint32_t mask = 1; mask < kWarpThreads; mask <<= 1) { +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { +#ifndef USE_ROCM + const float other = __shfl_xor_sync(kFullMask, data[i], mask, kWarpThreads); +#else + const float other = __shfl_xor(data[i], mask, kWarpThreads); +#endif + data[i] = (lane_id & mask) ? (other - data[i]) : (data[i] + other); + } + } + const float kHadamardScale = math::rsqrt(static_cast(kHeadDim)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] *= kHadamardScale; + } + + // part 4: per-warp UE8M0 quant + store. The whole warp emits one fp8 group + // (= 128 elements) plus a single fp32 scale, matching the indexer cache + // layout (`fused_store_indexer_cache`). 
+ { + using OutStorage = AlignedVector; + float local_max = math::abs(data[0]); +#pragma unroll + for (int i = 1; i < kVecSize; ++i) { + local_max = math::max(local_max, math::abs(data[i])); + } + const auto abs_max = warp::reduce_max(local_max); + const auto scale = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto inv_scale = 1.0f / scale; + const int32_t page = out_loc >> kPageBits; + const int32_t offset = out_loc & ((1 << kPageBits) - 1); + const auto page_ptr = params.kvcache + page * kPageBytes; + const auto value_ptr = page_ptr + offset * 128; + const auto scale_ptr = page_ptr + (128 << kPageBits) + offset * 4; + OutStorage result; + result[0] = pack_fp8(data[0] * inv_scale, data[1] * inv_scale); + result[1] = pack_fp8(data[2] * inv_scale, data[3] * inv_scale); + PDLTriggerSecondary(); + result.store(value_ptr, lane_id); + // The single fp32 scale is identical across all lanes -- write from any lane. + if (lane_id == 0) reinterpret_cast(scale_ptr)[0] = scale; + } +} + +template +INDEXER_KERNEL void fused_norm_rope_indexer_fp4(const __grid_constant__ FusedNormRopeStoreParams params) { + using namespace device; + using enum ForwardMode; + + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kVecSize = 4; + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + constexpr int64_t kPageBytes = 68ll << kPageBits; + static_assert(kHeadDim == kWarpThreads * kVecSize); + static_assert(kRopeDim == kWarpThreads * 2); + static_assert(kRopeSize <= kWarpThreads); + using Storage = AlignedVector; + using Float4 = AlignedVector; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kNumWarps + warp_id; + const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; + + if (work_id >= params.num_tokens) return; + + const auto input = static_cast(params.input) + work_id * kHeadDim; + int32_t position; + int32_t out_loc; + if constexpr (kMode == 
CompressExtend) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.is_invalid()) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[plan.ragged_id]; + } else if constexpr (kMode == CompressDecode) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.seq_len % params.compress_ratio != 0) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[work_id]; + } else { + static_assert(host::dependent_false_v, "Unsupported Mode"); + } + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + PDLWaitPrimary(); + Float4 data, freq; + + { + Storage input_vec, weight_vec; + input_vec.load(input, lane_id); + weight_vec.load(params.weight, lane_id); + if (is_rope_lane) freq.load(freqs_cis, lane_id - (kWarpThreads - kRopeSize)); + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + sum_of_squares += fp32_input * fp32_input; + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + const auto fp32_weight = cast(weight_vec[i]); + data[i] = fp32_input * norm_factor * fp32_weight; + } + } + + if (is_rope_lane) { + const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto y_real = data[2]; + const auto y_imag = data[3]; + const auto freq_x_real = freq[0]; + const auto freq_x_imag = freq[1]; + const auto freq_y_real = freq[2]; + const auto freq_y_imag = freq[3]; + data[0] = x_real * freq_x_real - x_imag * freq_x_imag; + data[1] = x_real * freq_x_imag + x_imag * freq_x_real; + data[2] = y_real * freq_y_real - y_imag * freq_y_imag; + data[3] = y_real * freq_y_imag + y_imag * freq_y_real; + } + + { + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a1; + data[1] 
= a0 - a1; + data[2] = a2 + a3; + data[3] = a2 - a3; + } + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a2; + data[1] = a1 + a3; + data[2] = a0 - a2; + data[3] = a1 - a3; + } +#pragma unroll + for (uint32_t mask = 1; mask < kWarpThreads; mask <<= 1) { +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const float other = __shfl_xor_sync(0xFFFFFFFFu, data[i], mask, kWarpThreads); + data[i] = (lane_id & mask) ? (other - data[i]) : (data[i] + other); + } + } + const float kHadamardScale = math::rsqrt(static_cast(kHeadDim)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] *= kHadamardScale; + } + + { + float local_max = math::abs(data[0]); +#pragma unroll + for (int i = 1; i < kVecSize; ++i) { + local_max = math::max(local_max, math::abs(data[i])); + } + local_max = warp::reduce_max<8>(local_max); + + const auto scale_raw = fmaxf(1e-4f, local_max) / 6.0f; + const auto scale_ue8m0 = static_cast(cast_to_ue8m0(scale_raw)); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + + const uint8_t packed0 = quant_fp4_e2m1(data[0] * inv_scale) | (quant_fp4_e2m1(data[1] * inv_scale) << 4); + const uint8_t packed1 = quant_fp4_e2m1(data[2] * inv_scale) | (quant_fp4_e2m1(data[3] * inv_scale) << 4); + const uint16_t packed = static_cast(packed0) | (static_cast(packed1) << 8); + + const int32_t page = out_loc >> kPageBits; + const int32_t offset = out_loc & ((1 << kPageBits) - 1); + const auto page_ptr = params.kvcache + page * kPageBytes; + const auto value_ptr = page_ptr + offset * 64; + const auto scale_ptr = page_ptr + (64 << kPageBits) + offset * 4; + + PDLTriggerSecondary(); + reinterpret_cast(value_ptr)[lane_id] = packed; + if ((lane_id & 7) == 0) static_cast(scale_ptr)[lane_id >> 3] = scale_ue8m0; + } +} + +// ---------------------------------------------------------------------------- +// FlashMLA variant: kHeadDim = 512, 1 token per *block* (256 threads). 
+// Each thread loads kVecSize=2 BF16, so 256 threads cover the full 512 elems. +// Cache layout: 584 bytes/token = 448 B fp8 nope + 128 B bf16 rope (64 elems = 32 bf16x2) + 8 B scale. +// ---------------------------------------------------------------------------- +template +FLASHMLA_KERNEL void fused_norm_rope_flashmla(const __grid_constant__ FusedNormRopeStoreParams params) { + using namespace device; + using enum ForwardMode; + + constexpr int64_t kHeadDim = 512; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kVecSize = 2; + // Last warp owns the rope tail. The remaining 7 warps each emit one + // 64-element fp8 group (own UE8M0 scale). + constexpr uint32_t kRopeWarp = kNumWarps - 1; + constexpr int64_t kPageBytes = host::div_ceil(584ll << kPageBits, 576) * 576; + static_assert(kHeadDim == kBlockSize * kVecSize); + static_assert(kRopeDim == kWarpThreads * kVecSize); + static_assert(kHeadDim - kRopeDim == kRopeWarp * kWarpThreads * kVecSize); + using Storage = AlignedVector; + using Float2 = AlignedVector; + + const auto tx = threadIdx.x; + const auto warp_id = tx / kWarpThreads; + const auto lane_id = tx % kWarpThreads; + const auto work_id = blockIdx.x; + + if (work_id >= params.num_tokens) return; + + const auto input = static_cast(params.input) + work_id * kHeadDim; + int32_t position; + int32_t out_loc; + if constexpr (kMode == CompressExtend) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.is_invalid()) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[plan.ragged_id]; + } else if constexpr (kMode == CompressDecode) { + const auto plan = static_cast(params.handle)[work_id]; + if (plan.seq_len % params.compress_ratio != 0) return; + position = plan.seq_len - params.compress_ratio; + out_loc = params.out_loc[work_id]; + } else { + static_assert(host::dependent_false_v, "Unsupported Mode"); + } + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + PDLWaitPrimary(); + Float2 data, freq; + + // part 1: 
norm. Each thread owns one 2-elem pack (`tx`-th pack of input). + // Sum of squares is reduced across the whole block via per-warp partials. + { + __shared__ float partial_sums[kNumWarps]; + + Storage input_vec, weight_vec; + input_vec.load(input, tx); + weight_vec.load(params.weight, tx); + if (warp_id == kRopeWarp) freq.load(freqs_cis, lane_id); + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + sum_of_squares += fp32_input * fp32_input; + } + + const auto warp_sum = warp::reduce_sum(sum_of_squares); + if (lane_id == 0) partial_sums[warp_id] = warp_sum; + __syncthreads(); + // Replicate the per-warp partial sums to a full warp and reduce. Every + // lane-group of `kNumWarps` lanes ends up with the global sum. + sum_of_squares = warp::reduce_sum(partial_sums[lane_id % kNumWarps]); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto fp32_input = cast(input_vec[i]); + const auto fp32_weight = cast(weight_vec[i]); + data[i] = fp32_input * norm_factor * fp32_weight; + } + } + + const int32_t page = out_loc >> kPageBits; + const int32_t offset = out_loc & ((1 << kPageBits) - 1); + const auto page_ptr = params.kvcache + page * kPageBytes; + const auto value_ptr = page_ptr + offset * 576; + + PDLTriggerSecondary(); + + // part 2: rope on the rope warp (BF16 store), or per-warp FP8 quant + store. + if (warp_id == kRopeWarp) { + // Each rope-warp lane owns exactly one (real, imag) pair within the rope + // tail. Apply rotation, downcast to BF16, write to the slot's rope region. 
+ const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto freq_real = freq[0]; + const auto freq_imag = freq[1]; + data[0] = x_real * freq_real - x_imag * freq_imag; + data[1] = x_real * freq_imag + x_imag * freq_real; + const auto result = cast(fp32x2_t{data[0], data[1]}); + const auto rope_ptr = value_ptr + 448; + reinterpret_cast(rope_ptr)[lane_id] = result; + } else { + // Non-rope warp: per-warp UE8M0 group (64 elems -> 64 fp8 + 1 scale byte). + // BF16 round-trip to match the precision of the non-fused path + // (which goes through quant_to_nope_fp8_rope_bf16_pack_triton with bf16 input). + const auto x = cast(cast(data[0])); + const auto y = cast(cast(data[1])); + const auto abs_max = warp::reduce_max(fmaxf(fabs(x), fabs(y))); + const auto scale_raw = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto scale_ue8m0 = cast_to_ue8m0(scale_raw); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + const auto result = pack_fp8(x * inv_scale, y * inv_scale); + const auto scale_ptr = page_ptr + (576 << kPageBits) + offset * 8; + reinterpret_cast(value_ptr)[tx] = result; + // All lanes in this warp produce the same scale byte; let lane 0 publish. + if (lane_id == 0) static_cast(scale_ptr)[warp_id] = scale_ue8m0; + } +} + +template +struct FusedNormRopeKernel { + static constexpr int32_t kLogPageSize = std::countr_zero(kPageSize); + static constexpr bool kIsIndexer = (kHeadDim == 128); + static constexpr int64_t kIndexerBytes = 132 * kPageSize; + static constexpr int64_t kFlashMLABytes = host::div_ceil(584 * kPageSize, 576) * 576; + static constexpr int64_t kPageBytes = kIsIndexer ? kIndexerBytes : kFlashMLABytes; + + /// TODO: Let's fix the config for now. 
+ static_assert(kRopeDim == 64 && (kHeadDim == 128 || kHeadDim == 512)); + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + + template + static constexpr auto select_kernel() { + if constexpr (kIsIndexer) { + return fused_norm_rope_indexer; + } else { + return fused_norm_rope_flashmla; + } + } + + template + static constexpr auto select_fp4_kernel() { + static_assert(kIsIndexer, "FP4 fused store is only defined for the indexer"); + return fused_norm_rope_indexer_fp4; + } + + static void forward( + const tvm::ffi::TensorView input, + const tvm::ffi::TensorView plan, + const tvm::ffi::TensorView weight, + const float eps, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView out_loc, + const tvm::ffi::TensorView kvcache, + const bool is_decode, + const uint32_t compress_ratio) { + using namespace host; + using enum ForwardMode; + + const auto mode = static_cast(is_decode); + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({N, kHeadDim}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({kHeadDim}) // weight + .with_dtype() + .with_device(device_) + .verify(weight); + TensorMatcher({-1, kRopeDim}) // freqs_cis + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + TensorMatcher({-1}) // out_loc + .with_dtype() + .with_device(device_) + .verify(out_loc); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(kvcache); + + switch (mode) { + case CompressExtend: + compress::verify_plan_c(plan, N, device_); + RuntimeCheck(out_loc.size(0) >= N.unwrap()); + break; + case CompressDecode: + compress::verify_plan_d(plan, N, device_); + RuntimeCheck(out_loc.size(0) == N.unwrap()); + break; + } + + const auto num_tokens = static_cast(N.unwrap()); + if (num_tokens == 0) return; + const auto params = FusedNormRopeStoreParams{ + .input = input.data_ptr(), + 
.handle = plan.data_ptr(), + .weight = weight.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .out_loc = static_cast(out_loc.data_ptr()), + .kvcache = static_cast(kvcache.data_ptr()), + .eps = eps, + .compress_ratio = compress_ratio, + .num_tokens = num_tokens, + }; + // Indexer packs `kNumWarps` tokens per block (warp-major); FlashMLA uses + // a whole block per token (cta-major sum-reduce over head_dim=512). + const uint32_t num_blocks = kIsIndexer ? div_ceil(num_tokens, kNumWarps) : num_tokens; + const auto device = device_.unwrap(); + const auto kernel = mode == CompressExtend ? select_kernel() : select_kernel(); + LaunchKernel(num_blocks, kBlockSize, device).enable_pdl(kUsePDL)(kernel, params); + } + + static void forward_fp4( + const tvm::ffi::TensorView input, + const tvm::ffi::TensorView plan, + const tvm::ffi::TensorView weight, + const float eps, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView out_loc, + const tvm::ffi::TensorView kvcache, + const bool is_decode, + const uint32_t compress_ratio) { + using namespace host; + using enum ForwardMode; + + static_assert(kIsIndexer, "FP4 fused store is only defined for the indexer"); + constexpr int64_t kFp4PageBytes = 68 * kPageSize; + const auto mode = static_cast(is_decode); + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({N, kHeadDim}).with_dtype().with_device(device_).verify(input); + TensorMatcher({kHeadDim}).with_dtype().with_device(device_).verify(weight); + TensorMatcher({-1, kRopeDim}).with_dtype().with_device(device_).verify(freqs_cis); + TensorMatcher({-1}).with_dtype().with_device(device_).verify(out_loc); + TensorMatcher({-1, -1}).with_strides({kFp4PageBytes, 1}).with_dtype().with_device(device_).verify(kvcache); + + switch (mode) { + case CompressExtend: + compress::verify_plan_c(plan, N, device_); + RuntimeCheck(out_loc.size(0) >= N.unwrap()); + break; + case CompressDecode: + 
compress::verify_plan_d(plan, N, device_); + RuntimeCheck(out_loc.size(0) == N.unwrap()); + break; + } + + const auto num_tokens = static_cast(N.unwrap()); + if (num_tokens == 0) return; + const auto params = FusedNormRopeStoreParams{ + .input = input.data_ptr(), + .handle = plan.data_ptr(), + .weight = weight.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .out_loc = static_cast(out_loc.data_ptr()), + .kvcache = static_cast(kvcache.data_ptr()), + .eps = eps, + .compress_ratio = compress_ratio, + .num_tokens = num_tokens, + }; + const uint32_t num_blocks = div_ceil(num_tokens, kNumWarps); + const auto device = device_.unwrap(); + const auto kernel = + mode == CompressExtend ? select_fp4_kernel() : select_fp4_kernel(); + LaunchKernel(num_blocks, kBlockSize, device).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hash_topk.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hash_topk.cuh new file mode 100644 index 0000000000..90dec3c117 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hash_topk.cuh @@ -0,0 +1,214 @@ +#include +#include + +#include +#include +#include + +#include + +#include +#include + +namespace { + +[[maybe_unused]] +SGL_DEVICE float act_sqrt_softplus(float x) { + const float softplus = fmaxf(x, 0.0f) + log1pf(expf(-fabsf(x))); + return sqrtf(softplus); +} + +struct MoEHashTopKParams { + const float* __restrict__ router_logits; + const int64_t* __restrict__ input_id; + const int32_t* __restrict__ tid2eid; + int32_t* __restrict__ topk_ids; + float* __restrict__ topk_weights; + uint32_t num_tokens; + uint32_t topk; + uint32_t num_routed_experts; + uint32_t num_shared_experts; + float routed_scaling_factor; +}; + +template +__global__ void moe_hash_topk_fused(const MoEHashTopKParams __grid_constant__ params) { + using namespace device; + const auto& [ + router_logits, input_id, tid2eid, topk_ids, topk_weights, // pointers + num_tokens, 
topk, num_routed_experts, num_shared_experts, routed_scaling_factor] = + params; + + const uint32_t topk_fused = topk + num_shared_experts; + const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t warp_id = tid / kWarpThreads; + const uint32_t lane_id = tid % kWarpThreads; + if (warp_id >= num_tokens) return; + // we can safely prefetch the token id + const auto token_id = input_id[warp_id]; + + PDLWaitPrimary(); + + float routed_weight = 0.0f; + int32_t expert_id = 0; + if (lane_id < topk) { + expert_id = tid2eid[token_id * topk + lane_id]; + routed_weight = Fn(router_logits[warp_id * num_routed_experts + expert_id]); + } + + const auto routed_sum = device::warp::reduce_sum(routed_weight); + if (lane_id < topk_fused) { + const bool is_shared = lane_id >= topk; + const auto output_offset = warp_id * topk_fused + lane_id; + topk_ids[output_offset] = is_shared ? num_routed_experts + lane_id - topk : expert_id; + topk_weights[output_offset] = is_shared ? 1.0f / routed_scaling_factor : routed_weight / routed_sum; + } + + PDLTriggerSecondary(); +} + +struct TopKParams { + int32_t* __restrict__ topk_ids; + // Exactly one is active: ntn_ptr == nullptr means use ntn_value. + const int32_t* __restrict__ ntn_ptr; + int32_t ntn_value; + int64_t stride; + uint32_t topk; + uint32_t num_tokens; +}; + +__global__ void mask_topk_ids_padded_region(const TopKParams __grid_constant__ params) { + const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t warp_id = tid / device::kWarpThreads; + const uint32_t lane_id = tid % device::kWarpThreads; + if (warp_id >= params.num_tokens || lane_id >= params.topk) return; + device::PDLWaitPrimary(); + const uint32_t num = (params.ntn_ptr != nullptr) // + ? 
static_cast(params.ntn_ptr[0]) + : static_cast(params.ntn_value); + if (warp_id >= num) params.topk_ids[warp_id * params.stride + lane_id] = -1; + device::PDLTriggerSecondary(); +} + +template +struct HashTopKKernel { + static constexpr auto kernel = moe_hash_topk_fused; + + static void + run(const tvm::ffi::TensorView router_logits, + const tvm::ffi::TensorView input_id, + const tvm::ffi::TensorView tid2eid, + const tvm::ffi::TensorView topk_weights, + const tvm::ffi::TensorView topk_ids, + float routed_scaling_factor) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto E = SymbolicSize{"num_routed_experts"}; + auto K = SymbolicSize{"topk_fused"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, E}) // + .with_dtype() + .with_device(device) + .verify(router_logits); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(input_id); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device) + .verify(tid2eid); + TensorMatcher({N, K}) // + .with_dtype() + .with_device(device) + .verify(topk_weights); + TensorMatcher({N, K}) // + .with_dtype() + .with_device(device) + .verify(topk_ids); + + const auto num_tokens = static_cast(N.unwrap()); + const auto topk_fused = static_cast(K.unwrap()); + const auto topk = static_cast(tid2eid.size(1)); + const auto shared_experts = topk_fused - topk; + RuntimeCheck(topk <= topk_fused, "HashTopKKernel requires topk <= topk_fused"); + RuntimeCheck(topk_fused <= device::kWarpThreads, "HashTopKKernel requires topk_fused <= warp size"); + + const auto params = MoEHashTopKParams{ + .router_logits = static_cast(router_logits.data_ptr()), + .input_id = static_cast(input_id.data_ptr()), + .tid2eid = static_cast(tid2eid.data_ptr()), + .topk_ids = static_cast(topk_ids.data_ptr()), + .topk_weights = static_cast(topk_weights.data_ptr()), + .num_tokens = num_tokens, + .topk = topk, + .num_routed_experts = static_cast(E.unwrap()), + .num_shared_experts = 
shared_experts, + .routed_scaling_factor = routed_scaling_factor, + }; + const auto kBlockSize = 128u; + const auto kNumWarps = kBlockSize / device::kWarpThreads; + const auto num_blocks = div_ceil(num_tokens, kNumWarps); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +// TODO this may not be related to *hash* topk, thus may move +struct MaskKernel { + static constexpr auto kernel = mask_topk_ids_padded_region; + + static void run(tvm::ffi::TensorView topk_ids, tvm::ffi::TensorView num_token_non_padded) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto K = SymbolicSize{"topk"}; + auto D = SymbolicSize{"stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + TensorMatcher({N, K}) // + .with_strides({D, 1}) + .with_dtype() + .with_device(device) + .verify(topk_ids); + RuntimeCheck(num_token_non_padded.numel() == 1, "num_token_non_padded should be a scalar"); + RuntimeCheck(K.unwrap() <= device::kWarpThreads, "MaskKernel requires topk <= warp size"); + const int32_t* ntn_ptr = nullptr; + int32_t ntn_value = 0; + const auto ntn_dev = num_token_non_padded.device().device_type; + if (ntn_dev == kDLCUDA) { + RuntimeCheck(is_type(num_token_non_padded.dtype()), "num_token_non_padded on CUDA must be int32"); + ntn_ptr = static_cast(num_token_non_padded.data_ptr()); + } else if (ntn_dev == kDLCPU) { + if (is_type(num_token_non_padded.dtype())) { + ntn_value = *static_cast(num_token_non_padded.data_ptr()); + } else if (is_type(num_token_non_padded.dtype())) { + ntn_value = static_cast(*static_cast(num_token_non_padded.data_ptr())); + } else { + RuntimeCheck(false, "num_token_non_padded on CPU must be int32 or int64"); + } + } else { + RuntimeCheck(false, "num_token_non_padded must be on CPU or CUDA"); + } + + const auto num_tokens = static_cast(N.unwrap()); + const auto params = TopKParams{ + .topk_ids = static_cast(topk_ids.data_ptr()), + .ntn_ptr = ntn_ptr, + .ntn_value = 
ntn_value, + .stride = static_cast(D.unwrap()), + .topk = static_cast(K.unwrap()), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 128u; + const auto kNumWarps = kBlockSize / device::kWarpThreads; + const auto num_blocks = div_ceil(num_tokens, kNumWarps); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(true)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hisparse_transfer.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hisparse_transfer.cuh new file mode 100644 index 0000000000..aefec24372 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/hisparse_transfer.cuh @@ -0,0 +1,82 @@ +#include +#include + +#include + +#include + +#include +#include + +#include + +namespace { + +/// NOTE: for offload to cpu kernel, we use persistent kernel +inline constexpr uint32_t kBlockSize = 1024; +inline constexpr uint32_t kBlockQuota = 4; + +#define OFFLOAD_KERNEL __global__ __launch_bounds__(kBlockSize, 1) + +struct OffloadParams { + void** gpu_caches; + void** cpu_caches; + const int64_t* gpu_indices; + const int64_t* cpu_indices; + uint32_t num_items; + uint32_t num_layers; +}; + +OFFLOAD_KERNEL void offload_to_cpu(const __grid_constant__ OffloadParams params) { + using namespace device::hisparse; + const auto [gpu_caches, cpu_caches, gpu_indices, cpu_indices, num_items, num_layers] = params; + const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x; + constexpr auto kNumWarps = (kBlockSize / 32) * kBlockQuota; + for (auto i = global_tid / 32; i < num_items; i += kNumWarps) { + const int32_t gpu_index = gpu_indices[i]; + const int32_t cpu_index = cpu_indices[i]; + for (auto j = 0u; j < num_layers; ++j) { + const auto gpu_cache = gpu_caches[j]; + const auto cpu_cache = cpu_caches[j]; + transfer_item( + /*dst_cache=*/cpu_cache, + /*src_cache=*/gpu_cache, + /*dst_index=*/cpu_index, + /*src_index=*/gpu_index); + } + } +} + +[[maybe_unused]] +void 
hisparse_transfer( + tvm::ffi::TensorView gpu_ptrs, + tvm::ffi::TensorView cpu_ptrs, + tvm::ffi::TensorView gpu_indices, + tvm::ffi::TensorView cpu_indices) { + using namespace host; + auto N = SymbolicSize{"num_items"}; + auto L = SymbolicSize{"num_layers"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({L}) // 1D cache pointers + .with_dtype() + .with_device(device_) + .verify(gpu_ptrs) + .verify(cpu_ptrs); + TensorMatcher({N}) // 1D indices + .with_dtype() + .with_device(device_) + .verify(gpu_indices) + .verify(cpu_indices); + const auto params = OffloadParams{ + .gpu_caches = static_cast(gpu_ptrs.data_ptr()), + .cpu_caches = static_cast(cpu_ptrs.data_ptr()), + .gpu_indices = static_cast(gpu_indices.data_ptr()), + .cpu_indices = static_cast(cpu_indices.data_ptr()), + .num_items = static_cast(N.unwrap()), + .num_layers = static_cast(L.unwrap()), + }; + LaunchKernel(kBlockQuota, kBlockSize, device_.unwrap())(offload_to_cpu, params); +} + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/main_norm_rope.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/main_norm_rope.cuh new file mode 100644 index 0000000000..8fc8d0821d --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/main_norm_rope.cuh @@ -0,0 +1,845 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::inv_scale_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +SGL_DEVICE uint8_t quant_fp4_e2m1(float x) { + const float ax = fminf(fabsf(x), 6.0f); + uint8_t idx = 0; + idx += ax > 0.25f; + idx += ax > 0.75f; + idx += ax > 1.25f; + idx += ax > 1.75f; + idx += ax > 2.5f; + idx += ax > 3.5f; + idx += ax > 5.0f; + if (x < 0.0f && idx != 0) idx |= 0x8; + return idx; +} + +// 4 warps per block: warp-per-(token, head) work-item dispatch (Q kernel). 
+constexpr uint32_t kFusedQBlockSize = 128; +constexpr uint32_t kFusedQNumWarps = kFusedQBlockSize / device::kWarpThreads; + +// 8 warps per block: block-per-token work-item dispatch (K kernel). +constexpr uint32_t kFusedKBlockSize = 256; +constexpr uint32_t kFusedKNumWarps = kFusedKBlockSize / device::kWarpThreads; + +#define Q_KERNEL __global__ __launch_bounds__(kFusedQBlockSize, 16) +#define K_KERNEL __global__ __launch_bounds__(kFusedKBlockSize, 8) + +// ============================================================================ +// Q kernel: warp-per-(token, head) rmsnorm-self + RoPE + write to q_out. +// ============================================================================ + +struct FusedQNormRopeParams { + const void* __restrict__ q_input; // (B, num_q_heads, kHeadDim) DType + void* __restrict__ q_output; // (B, num_q_heads, kHeadDim) DType + const float* __restrict__ freqs_cis; // (max_pos, kRopeDim) fp32 (re/im interleaved) + const void* __restrict__ positions; // (B,) PosT + int64_t q_input_stride_batch; + int64_t q_output_stride_batch; + uint32_t batch_size; + uint32_t num_q_heads; + float eps; +}; + +template +Q_KERNEL void fused_q_norm_rope(const __grid_constant__ FusedQNormRopeParams params) { + using namespace device; + + constexpr int64_t kMaxVecSize = 16 / sizeof(DType); + constexpr int64_t kVecSize = std::min(kMaxVecSize, kHeadDim / kWarpThreads); + constexpr int64_t kLocalSize = kHeadDim / (kWarpThreads * kVecSize); + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + static_assert(kHeadDim % (kWarpThreads * kVecSize) == 0); + static_assert(kLocalSize * kVecSize * kWarpThreads == kHeadDim); + static_assert(kRopeDim % kVecSize == 0); + static_assert(kRopeSize <= kWarpThreads); + static_assert(kRopeDim == kWarpThreads * 2, "1 (real, imag) pair per lane"); + + using Storage = AlignedVector; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * 
+ kFusedQNumWarps + warp_id; + + const uint32_t total_works = params.batch_size * params.num_q_heads; + if (work_id >= total_works) return; + + const uint32_t batch_id = work_id / params.num_q_heads; + const uint32_t head_id = work_id % params.num_q_heads; + const auto input_ptr = + static_cast(params.q_input) + batch_id * params.q_input_stride_batch + head_id * kHeadDim; + const auto output_ptr = + static_cast(params.q_output) + batch_id * params.q_output_stride_batch + head_id * kHeadDim; + const auto position = static_cast(static_cast(params.positions)[batch_id]); + + __shared__ Storage s_rope[kFusedQNumWarps][kRopeSize]; + + // `position` is read above, before the PDL gate. `mem_freq` takes no pointer, so + // nothing is fetched here; the freq pair is loaded after the gate, in part 1. + const auto mem_freq = tile::Memory{lane_id, kWarpThreads}; + + PDLWaitPrimary(); + + // part 1: rmsnorm-self (no weight). + const auto gmem = tile::Memory{lane_id, kWarpThreads}; + Storage input_vec[kLocalSize]; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + input_vec[i] = gmem.load(input_ptr, i); + } + + const auto freq = mem_freq.load(params.freqs_cis + position * kRopeDim); + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto x = cast(input_vec[i][j]); + sum_of_squares += x * x; + } + } + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto x = cast(input_vec[i][j]); + input_vec[i][j] = cast(x * norm_factor); + } + } + + // Stash the rope tail (last kRopeSize lanes' last tile) into shared memory; + // write nope tiles to gmem directly. 
+ const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + if (i == kLocalSize - 1 && is_rope_lane) { + const auto rope_id = lane_id - (kWarpThreads - kRopeSize); + s_rope[warp_id][rope_id] = input_vec[i]; + } else { + gmem.store(output_ptr, input_vec[i], i); + } + } + __syncwarp(); + + PDLTriggerSecondary(); + + // part 2: RoPE on all 32 lanes -- one (real, imag) bf16x2 pair per lane. + using DType2 = packed_t; + const auto mem_elem = tile::Memory{lane_id, kWarpThreads}; + const auto elem = mem_elem.load(s_rope[warp_id]); + const auto [x_real, x_imag] = cast(elem); + const auto [freq_real, freq_imag] = freq; + const fp32x2_t rotated = { + x_real * freq_real - x_imag * freq_imag, + x_real * freq_imag + x_imag * freq_real, + }; + mem_elem.store(output_ptr + (kHeadDim - kRopeDim), cast(rotated)); +} + +template +struct FusedQNormRopeKernel { + template + static constexpr auto kernel = fused_q_norm_rope; + + static void forward( + const tvm::ffi::TensorView q_input, + const tvm::ffi::TensorView q_output, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView positions, + float eps) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto H = SymbolicSize{"num_q_heads"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, H, kHeadDim}) // + .with_strides({-1, kHeadDim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_input); + TensorMatcher({B, H, kHeadDim}) // + .with_strides({-1, kHeadDim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_output); + TensorMatcher({-1, kRopeDim}) // + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(pos_dtype) + .with_device(device_) + .verify(positions); + + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_heads = static_cast(H.unwrap()); + if (batch_size == 0) return; + + const auto params = 
FusedQNormRopeParams{ + .q_input = q_input.data_ptr(), + .q_output = q_output.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .q_input_stride_batch = q_input.stride(0), + .q_output_stride_batch = q_output.stride(0), + .batch_size = batch_size, + .num_q_heads = num_q_heads, + .eps = eps, + }; + const auto total_works = batch_size * num_q_heads; + const auto num_blocks = div_ceil(total_works, kFusedQNumWarps); + const auto k_int32 = kernel; + const auto k_int64 = kernel; + const auto k = pos_dtype.is_type() ? k_int32 : k_int64; + LaunchKernel(num_blocks, kFusedQBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(k, params); + } +}; + +// ============================================================================ +// K kernel: block-per-token rmsnorm (with kv_weight) + RoPE + FlashMLA store. +// ============================================================================ + +struct FusedKNormRopeFlashMLAParams { + const void* __restrict__ kv; // (B, kHeadDim) DType + const void* __restrict__ kv_weight; // (kHeadDim,) DType + const float* __restrict__ freqs_cis; // (max_pos, kRopeDim) fp32 + const void* __restrict__ positions; // (B,) PosT + const int32_t* __restrict__ out_loc; // (B,) int32 -> cache slot id + uint8_t* __restrict__ kvcache; // (npages, kPageBytes) uint8 + // Row stride for `kv` in elements. Required because the upstream caller often + // passes `qkv_a[..., q_lora_rank:]`, a non-contiguous slice whose stride[0] + // equals `q_lora_rank + kHeadDim` rather than `kHeadDim`. 
+ int64_t kv_stride_batch; + uint32_t batch_size; + float eps; +}; + +template +K_KERNEL void fused_k_norm_rope_flashmla(const __grid_constant__ FusedKNormRopeFlashMLAParams params) { + using namespace device; + + constexpr int64_t kVecSize = 2; + constexpr uint32_t kRopeWarp = kFusedKNumWarps - 1; + constexpr int64_t kPageBytes = host::div_ceil(584ll << kPageBits, 576) * 576; + static_assert(kHeadDim == kFusedKBlockSize * kVecSize); + static_assert(kRopeDim == kWarpThreads * kVecSize); + static_assert(kHeadDim - kRopeDim == kRopeWarp * kWarpThreads * kVecSize); + using Storage = AlignedVector; + using Float2 = AlignedVector; + + const auto tx = threadIdx.x; + const auto warp_id = tx / kWarpThreads; + const auto lane_id = tx % kWarpThreads; + const auto work_id = blockIdx.x; + if (work_id >= params.batch_size) return; + + const auto input_ptr = static_cast(params.kv) + work_id * params.kv_stride_batch; + const auto position = static_cast(static_cast(params.positions)[work_id]); + const auto out_loc = params.out_loc[work_id]; + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + PDLWaitPrimary(); + Float2 data, freq; + + // part 1: norm. Each thread owns one 2-elem pack (the `tx`-th). + // Sum-of-squares is reduced block-wide via per-warp partials. + { + __shared__ float partial_sums[kFusedKNumWarps]; + + Storage input_vec, weight_vec; + input_vec.load(input_ptr, tx); + weight_vec.load(params.kv_weight, tx); + if (warp_id == kRopeWarp) freq.load(freqs_cis, lane_id); + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto x = cast(input_vec[i]); + sum_of_squares += x * x; + } + const auto warp_sum = warp::reduce_sum(sum_of_squares); + if (lane_id == 0) partial_sums[warp_id] = warp_sum; + __syncthreads(); + // Replicate the per-warp partial sums onto all lanes of one warp and + // reduce. Every group of `kFusedKNumWarps` lanes ends up with the + // global sum. 
+ sum_of_squares = warp::reduce_sum(partial_sums[lane_id % kFusedKNumWarps]); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const auto x = cast(input_vec[i]); + const auto w = cast(weight_vec[i]); + data[i] = x * norm_factor * w; + } + } + + const int32_t page = out_loc >> kPageBits; + const int32_t offset = out_loc & ((1 << kPageBits) - 1); + const auto page_ptr = params.kvcache + page * kPageBytes; + const auto value_ptr = page_ptr + offset * 576; + + PDLTriggerSecondary(); + + // part 2: rope on warp 7 (BF16 store), per-warp UE8M0 quant + store on warps 0..6. + if (warp_id == kRopeWarp) { + const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto freq_real = freq[0]; + const auto freq_imag = freq[1]; + data[0] = x_real * freq_real - x_imag * freq_imag; + data[1] = x_real * freq_imag + x_imag * freq_real; + const auto result = cast(fp32x2_t{data[0], data[1]}); + const auto rope_ptr = value_ptr + 448; + reinterpret_cast(rope_ptr)[lane_id] = result; + } else { + const auto x = data[0]; + const auto y = data[1]; + const auto abs_max = warp::reduce_max(fmaxf(fabs(x), fabs(y))); + const auto scale_raw = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto scale_ue8m0 = cast_to_ue8m0(scale_raw); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + const auto result = pack_fp8(x * inv_scale, y * inv_scale); + const auto scale_ptr = page_ptr + (576 << kPageBits) + offset * 8; + reinterpret_cast(value_ptr)[tx] = result; + if (lane_id == 0) static_cast(scale_ptr)[warp_id] = scale_ue8m0; + } +} + +template +struct FusedKNormRopeFlashMLAKernel { + static constexpr int32_t kLogPageSize = std::countr_zero(kPageSize); + static constexpr int64_t kPageBytes = host::div_ceil(584 * kPageSize, 576) * 576; + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogPageSize == kPageSize); + static_assert(kHeadDim == 512 
&& kRopeDim == 64, "FlashMLA layout requires (512, 64)"); + + template + static constexpr auto kernel = fused_k_norm_rope_flashmla; + + static void forward( + const tvm::ffi::TensorView kv, + const tvm::ffi::TensorView kv_weight, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView positions, + const tvm::ffi::TensorView out_loc, + const tvm::ffi::TensorView kvcache, + float eps) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, kHeadDim}) // + .with_strides({-1, 1}) + .with_dtype() + .with_device(device_) + .verify(kv); + TensorMatcher({kHeadDim}) // + .with_dtype() + .with_device(device_) + .verify(kv_weight); + TensorMatcher({-1, kRopeDim}) // + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(pos_dtype) + .with_device(device_) + .verify(positions); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(out_loc); + TensorMatcher({-1, -1}) // + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(kvcache); + + const auto batch_size = static_cast(B.unwrap()); + if (batch_size == 0) return; + + const auto params = FusedKNormRopeFlashMLAParams{ + .kv = kv.data_ptr(), + .kv_weight = kv_weight.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .out_loc = static_cast(out_loc.data_ptr()), + .kvcache = static_cast(kvcache.data_ptr()), + .kv_stride_batch = kv.stride(0), + .batch_size = batch_size, + .eps = eps, + }; + const auto k_int32 = kernel; + const auto k_int64 = kernel; + const auto k = pos_dtype.is_type() ? 
k_int32 : k_int64; + LaunchKernel(batch_size, kFusedKBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(k, params); + } +}; + +// ============================================================================ +// Indexer Q kernel: warp-per-(token, head) RoPE + Hadamard + fp8 act-quant. +// ============================================================================ + +struct FusedQIndexerRopeHadamardQuantParams { + const void* __restrict__ q_input; // (B, num_heads, 128) DType + void* __restrict__ q_fp8; // (B, num_heads, 128) fp8_e4m3 + // weights_out[b, h] = weight[b, h] * weight_scale * q_scale[b, h]. + // q_scale is computed internally and not exposed -- the only consumer of + // it is `weights_out`. + const void* __restrict__ weight; // (B, num_heads) DType + float* __restrict__ weights_out; // (B, num_heads) fp32 (== (B, H, 1) flat) + float weight_scale; // scalar c4_indexer.weight_scale + const float* __restrict__ freqs_cis; // (max_pos, 64) fp32 + const void* __restrict__ positions; // (B,) PosT + uint32_t batch_size; + uint32_t num_heads; +}; + +template +Q_KERNEL void fused_q_indexer_rope_hadamard_quant(const __grid_constant__ FusedQIndexerRopeHadamardQuantParams params) { + using namespace device; + + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kVecSize = 4; + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; // = 16 + static_assert(kHeadDim == kWarpThreads * kVecSize); + static_assert(kRopeDim == kWarpThreads * 2); + static_assert(kRopeSize <= kWarpThreads); + + using Storage = AlignedVector; + using Float4 = AlignedVector; + using OutStorage = AlignedVector; // 4 fp8 / lane + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kFusedQNumWarps + warp_id; + // Last `kRopeSize` lanes own the rope tail; their 4-elem packs cover the + // trailing kRopeDim elements. 
+ const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; + + const uint32_t total_works = params.batch_size * params.num_heads; + if (work_id >= total_works) return; + + const uint32_t batch_id = work_id / params.num_heads; + const auto input_ptr = static_cast(params.q_input) + work_id * kHeadDim; + const auto position = static_cast(static_cast(params.positions)[batch_id]); + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + // Every lane loads the same weight scalar for this (token, head) work item + // right after the PDL gate below; there is no lane-0 guard on the load. + // Weight is (B, num_heads) DType, one scalar per warp. The multiply + store + // into `weights_out` happens in the final quant block, once the scale is known. + + PDLWaitPrimary(); + Float4 data, freq; + const auto weight_val = cast(static_cast(params.weight)[work_id]); + + // part 1: load (no norm). Each lane owns a 4-elem pack. + { + Storage input_vec; + input_vec.load(input_ptr, lane_id); + if (is_rope_lane) freq.load(freqs_cis, lane_id - (kWarpThreads - kRopeSize)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + data[i] = cast(input_vec[i]); + } + } + + // part 2: rope on rope lanes only (4 elems / lane = 2 (real, imag) pairs). + if (is_rope_lane) { + const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto y_real = data[2]; + const auto y_imag = data[3]; + const auto fxr = freq[0]; + const auto fxi = freq[1]; + const auto fyr = freq[2]; + const auto fyi = freq[3]; + data[0] = x_real * fxr - x_imag * fxi; + data[1] = x_real * fxi + x_imag * fxr; + data[2] = y_real * fyr - y_imag * fyi; + data[3] = y_real * fyi + y_imag * fyr; + } + + PDLTriggerSecondary(); + + // part 3: 128-point Hadamard (2 local stages + 5 cross-lane shfl_xor stages). + // Same recipe as `fused_norm_rope_indexer`; see comments there for the + // butterfly invariants and the early-return safety argument. 
+ { + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a1; + data[1] = a0 - a1; + data[2] = a2 + a3; + data[3] = a2 - a3; + } + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a2; + data[1] = a1 + a3; + data[2] = a0 - a2; + data[3] = a1 - a3; + } +#pragma unroll + for (uint32_t mask = 1; mask < kWarpThreads; mask <<= 1) { +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const float other = __shfl_xor_sync(0xFFFFFFFFu, data[i], mask, kWarpThreads); + data[i] = (lane_id & mask) ? (other - data[i]) : (data[i] + other); + } + } + const float kHadamardScale = math::rsqrt(static_cast(kHeadDim)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] *= kHadamardScale; + } + + { + float local_max = math::abs(data[0]); +#pragma unroll + for (int i = 1; i < kVecSize; ++i) { + local_max = math::max(local_max, math::abs(data[i])); + } + const auto abs_max = warp::reduce_max(local_max); + const auto scale = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto inv_scale = 1.0f / scale; + OutStorage result; + result[0] = pack_fp8(data[0] * inv_scale, data[1] * inv_scale); + result[1] = pack_fp8(data[2] * inv_scale, data[3] * inv_scale); + + // q_fp8 row pointer: 128 fp8 / row = 32 OutStorage / row, one per lane. 
+ auto out_row = static_cast(params.q_fp8) + work_id * kHeadDim; + result.store(out_row, lane_id); + params.weights_out[work_id] = weight_val * params.weight_scale * scale; + } +} + +template +struct FusedQIndexerRopeHadamardQuantKernel { + template + static constexpr auto kernel = fused_q_indexer_rope_hadamard_quant; + + static void forward( + const tvm::ffi::TensorView q_input, + const tvm::ffi::TensorView q_fp8, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView weights_out, + double weight_scale, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView positions) { + using namespace host; + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + + auto B = SymbolicSize{"batch_size"}; + auto H = SymbolicSize{"num_heads"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + // Caller path is `wq_b(q_lora).view(-1, H, D)` -> contiguous; the kernel + // assumes a flat `(B*H, kHeadDim)` layout for both q_input and q_fp8. + // Pin the head/innermost strides; assert the batch stride below. 
+ TensorMatcher({B, H, kHeadDim}) // + .with_strides({-1, kHeadDim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_input); + TensorMatcher({B, H, kHeadDim}) // + .with_strides({-1, kHeadDim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_fp8); + TensorMatcher({B, H}) // + .with_dtype() + .with_device(device_) + .verify(weight); + TensorMatcher({B, H, 1}) // + .with_dtype() + .with_device(device_) + .verify(weights_out); + TensorMatcher({-1, kRopeDim}) // + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(pos_dtype) + .with_device(device_) + .verify(positions); + + const auto batch_size = static_cast(B.unwrap()); + const auto num_heads = static_cast(H.unwrap()); + if (batch_size == 0) return; + + // The kernel computes row pointers as `base + work_id * kHeadDim`, so + // both inputs must be contiguous in (batch, head, elem) order. + const int64_t expected_batch_stride = static_cast(num_heads) * kHeadDim; + RuntimeCheck( + q_input.stride(0) == expected_batch_stride, + "q_input must be contiguous (B, H, kHeadDim); got stride[0]=", + q_input.stride(0)); + RuntimeCheck( + q_fp8.stride(0) == expected_batch_stride, + "q_fp8 must be contiguous (B, H, kHeadDim); got stride[0]=", + q_fp8.stride(0)); + + const auto params = FusedQIndexerRopeHadamardQuantParams{ + .q_input = q_input.data_ptr(), + .q_fp8 = q_fp8.data_ptr(), + .weight = weight.data_ptr(), + .weights_out = static_cast(weights_out.data_ptr()), + .weight_scale = static_cast(weight_scale), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .batch_size = batch_size, + .num_heads = num_heads, + }; + const auto total_works = batch_size * num_heads; + const auto num_blocks = div_ceil(total_works, kFusedQNumWarps); + const auto k_int32 = kernel; + const auto k_int64 = kernel; + const auto k = pos_dtype.is_type() ? 
k_int32 : k_int64; + LaunchKernel(num_blocks, kFusedQBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(k, params); + } +}; + +struct FusedQIndexerRopeHadamardFp4QuantParams { + const void* __restrict__ q_input; + void* __restrict__ q_fp4; + int32_t* __restrict__ q_sf; + const void* __restrict__ weight; + float* __restrict__ weights_out; + float weight_scale; + const float* __restrict__ freqs_cis; + const void* __restrict__ positions; + uint32_t batch_size; + uint32_t num_heads; +}; + +template +Q_KERNEL void +fused_q_indexer_rope_hadamard_fp4_quant(const __grid_constant__ FusedQIndexerRopeHadamardFp4QuantParams params) { + using namespace device; + + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kVecSize = 4; + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + static_assert(kHeadDim == kWarpThreads * kVecSize); + static_assert(kRopeDim == kWarpThreads * 2); + static_assert(kRopeSize <= kWarpThreads); + + using Storage = AlignedVector; + using Float4 = AlignedVector; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kFusedQNumWarps + warp_id; + const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; + + const uint32_t total_works = params.batch_size * params.num_heads; + if (work_id >= total_works) return; + + const uint32_t batch_id = work_id / params.num_heads; + const auto input_ptr = static_cast(params.q_input) + work_id * kHeadDim; + const auto position = static_cast(static_cast(params.positions)[batch_id]); + const auto freqs_cis = params.freqs_cis + position * kRopeDim; + + PDLWaitPrimary(); + Float4 data, freq; + const auto weight_val = cast(static_cast(params.weight)[work_id]); + + { + Storage input_vec; + input_vec.load(input_ptr, lane_id); + if (is_rope_lane) freq.load(freqs_cis, lane_id - (kWarpThreads - kRopeSize)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + data[i] = cast(input_vec[i]); + } + 
} + + if (is_rope_lane) { + const auto x_real = data[0]; + const auto x_imag = data[1]; + const auto y_real = data[2]; + const auto y_imag = data[3]; + const auto fxr = freq[0]; + const auto fxi = freq[1]; + const auto fyr = freq[2]; + const auto fyi = freq[3]; + data[0] = x_real * fxr - x_imag * fxi; + data[1] = x_real * fxi + x_imag * fxr; + data[2] = y_real * fyr - y_imag * fyi; + data[3] = y_real * fyi + y_imag * fyr; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] = cast(cast(data[i])); + } + + PDLTriggerSecondary(); + + { + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a1; + data[1] = a0 - a1; + data[2] = a2 + a3; + data[3] = a2 - a3; + } + { + const float a0 = data[0], a1 = data[1], a2 = data[2], a3 = data[3]; + data[0] = a0 + a2; + data[1] = a1 + a3; + data[2] = a0 - a2; + data[3] = a1 - a3; + } +#pragma unroll + for (uint32_t mask = 1; mask < kWarpThreads; mask <<= 1) { +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + const float other = __shfl_xor_sync(0xFFFFFFFFu, data[i], mask, kWarpThreads); + data[i] = (lane_id & mask) ? 
(other - data[i]) : (data[i] + other); + } + } + const float kHadamardScale = math::rsqrt(static_cast(kHeadDim)); +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] *= kHadamardScale; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) + data[i] = cast(cast(data[i])); + } + + { + float local_max = math::abs(data[0]); +#pragma unroll + for (int i = 1; i < kVecSize; ++i) { + local_max = math::max(local_max, math::abs(data[i])); + } + local_max = warp::reduce_max<8>(local_max); + const auto scale_raw = fmaxf(1e-4f, local_max) / 6.0f; + const auto scale_ue8m0 = static_cast(cast_to_ue8m0(scale_raw)); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + const uint8_t packed0 = quant_fp4_e2m1(data[0] * inv_scale) | (quant_fp4_e2m1(data[1] * inv_scale) << 4); + const uint8_t packed1 = quant_fp4_e2m1(data[2] * inv_scale) | (quant_fp4_e2m1(data[3] * inv_scale) << 4); + const uint16_t packed = static_cast(packed0) | (static_cast(packed1) << 8); + auto out_row = static_cast(params.q_fp4) + work_id * (kHeadDim / 2); + reinterpret_cast(out_row)[lane_id] = packed; + if ((lane_id & 7) == 0) { + reinterpret_cast(params.q_sf + work_id)[lane_id >> 3] = scale_ue8m0; + } + params.weights_out[work_id] = weight_val * params.weight_scale; + } +} + +template +struct FusedQIndexerRopeHadamardFp4QuantKernel { + template + static constexpr auto kernel = fused_q_indexer_rope_hadamard_fp4_quant; + + static void forward( + const tvm::ffi::TensorView q_input, + const tvm::ffi::TensorView q_fp4, + const tvm::ffi::TensorView q_sf, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView weights_out, + double weight_scale, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView positions) { + using namespace host; + constexpr int64_t kHeadDim = 128; + constexpr int64_t kRopeDim = 64; + constexpr int64_t kFp4Dim = kHeadDim / 2; + + auto B = SymbolicSize{"batch_size"}; + auto H = SymbolicSize{"num_heads"}; + auto device_ = SymbolicDevice{}; + 
device_.set_options(); + + TensorMatcher({B, H, kHeadDim}) + .with_strides({-1, kHeadDim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_input); + TensorMatcher({B, H, kFp4Dim}) + .with_strides({-1, kFp4Dim, 1}) + .with_dtype() + .with_device(device_) + .verify(q_fp4); + TensorMatcher({B, H}).with_dtype().with_device(device_).verify(q_sf); + TensorMatcher({B, H}).with_dtype().with_device(device_).verify(weight); + TensorMatcher({B, H, 1}).with_dtype().with_device(device_).verify(weights_out); + TensorMatcher({-1, kRopeDim}).with_dtype().with_device(device_).verify(freqs_cis); + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}).with_dtype(pos_dtype).with_device(device_).verify(positions); + + const auto batch_size = static_cast(B.unwrap()); + const auto num_heads = static_cast(H.unwrap()); + if (batch_size == 0) return; + + const int64_t expected_q_stride = static_cast(num_heads) * kHeadDim; + const int64_t expected_fp4_stride = static_cast(num_heads) * kFp4Dim; + RuntimeCheck(q_input.stride(0) == expected_q_stride, "q_input must be contiguous"); + RuntimeCheck(q_fp4.stride(0) == expected_fp4_stride, "q_fp4 must be contiguous"); + RuntimeCheck(q_sf.stride(0) == static_cast(num_heads) && q_sf.stride(1) == 1, "q_sf must be contiguous"); + + const auto params = FusedQIndexerRopeHadamardFp4QuantParams{ + .q_input = q_input.data_ptr(), + .q_fp4 = q_fp4.data_ptr(), + .q_sf = static_cast(q_sf.data_ptr()), + .weight = weight.data_ptr(), + .weights_out = static_cast(weights_out.data_ptr()), + .weight_scale = static_cast(weight_scale), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .batch_size = batch_size, + .num_heads = num_heads, + }; + const auto total_works = batch_size * num_heads; + const auto num_blocks = div_ceil(total_works, kFusedQNumWarps); + const auto k_int32 = kernel; + const auto k_int64 = kernel; + const auto k = pos_dtype.is_type() ? 
k_int32 : k_int64; + LaunchKernel(num_blocks, kFusedQBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(k, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh new file mode 100644 index 0000000000..7d5f97824b --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh @@ -0,0 +1,219 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct MegaMoEPreDispatchParams { + const bf16_t* __restrict__ x; // [num_tokens, hidden] + const int32_t* __restrict__ topk_idx; // [num_tokens, top_k] + const float* __restrict__ topk_weights; // [num_tokens, top_k] + + fp8_e4m3_t* __restrict__ buf_x; // [padded_max, hidden] + int32_t* __restrict__ buf_x_sf; // contiguous int32 [P, G/4]; see layout comment + int64_t* __restrict__ buf_topk_idx; // [padded_max, top_k] + float* __restrict__ buf_topk_weights; // [padded_max, top_k] + + uint32_t num_tokens; + uint32_t padded_max; + uint32_t hidden; + uint32_t num_groups; // hidden / group_size + uint32_t top_k; +}; + +// kGroupSize must match sglang_per_token_group_quant_fp8_ue8m0(group_size=). 
+template +__global__ __launch_bounds__(1024, 2) void // + mega_moe_pre_dispatch_kernel(const MegaMoEPreDispatchParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kVecElems = 8; // 8 bf16 = 16B load per thread + static_assert(kGroupSize % kVecElems == 0, "group_size must be a multiple of 8"); + constexpr uint32_t kThreadsPerGroup = kGroupSize / kVecElems; + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + + const uint32_t bid = blockIdx.x; + const uint32_t tid = threadIdx.x; + + PDLWaitPrimary(); + if (bid < params.num_tokens) { + // ---- Quantize path: one CTA per valid token ---- + + const uint32_t token_id = bid; + const auto token_in = params.x + static_cast(token_id) * params.hidden; + const auto token_out = params.buf_x + static_cast(token_id) * params.hidden; + + InputVec in_vec; + in_vec.load(token_in, tid); + + float local_max = 0.0f; + float vals[kVecElems]; +#pragma unroll + for (uint32_t i = 0; i < kVecElems / 2; ++i) { + const auto [v0, v1] = cast(in_vec[i]); + vals[2 * i + 0] = v0; + vals[2 * i + 1] = v1; + local_max = fmaxf(local_max, fmaxf(fabsf(v0), fabsf(v1))); + } + + // Absmax across the kThreadsPerGroup threads that cover one group. + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + const float raw_scale = absmax / math::FP8_E4M3_MAX; + const uint32_t ue8m0_exp = cast_to_ue8m0(raw_scale); + // 2^-ue8m0_exp as fp32 (equivalent to 1 / __uint_as_float(ue8m0 << 23)). + const float inv_scale = __uint_as_float((127u + 127u - ue8m0_exp) << 23); + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < kVecElems / 2; ++i) { + out_vec[i] = pack_fp8(vals[2 * i + 0] * inv_scale, vals[2 * i + 1] * inv_scale); + } + out_vec.store(token_out, tid); + + // One thread per group writes its UE8M0 byte into the contiguous + // row-major int32-packed layout: byte address = t*num_groups + g + // (see layout comment at the top of the file). 
+ const uint32_t group_id = tid / kThreadsPerGroup; + const uint32_t within_group_id = tid % kThreadsPerGroup; + if (within_group_id == 0 && group_id < params.num_groups) { + const uint32_t byte_off = token_id * params.num_groups + group_id; + reinterpret_cast(params.buf_x_sf)[byte_off] = static_cast(ue8m0_exp); + } + + // Copy this token's topk row (no alignment assumptions; top_k is small). + if (tid < params.top_k) { + const uint32_t off = token_id * params.top_k + tid; + params.buf_topk_idx[off] = params.topk_idx[off]; + params.buf_topk_weights[off] = params.topk_weights[off]; + } + } else { + // ---- Pad path: trailing blocks fill [num_tokens, padded_max) with (-1, 0) ---- + const uint32_t copy_bid = bid - params.num_tokens; + const uint32_t pad_base = params.num_tokens * params.top_k; + const uint32_t slot = pad_base + copy_bid * blockDim.x + tid; + const uint32_t total_slots = params.padded_max * params.top_k; + + if (slot < total_slots) { + params.buf_topk_idx[slot] = -1; + params.buf_topk_weights[slot] = 0.0f; + } + } + PDLTriggerSecondary(); +} + +// ---- Host wrapper +// ------------------------------------------------------------------------------------------------------------------------ + +template +struct MegaMoEPreDispatchKernel { + static_assert(kGroupSize == 32 || kGroupSize == 64 || kGroupSize == 128, "unsupported group_size"); + static constexpr auto kernel = mega_moe_pre_dispatch_kernel(kGroupSize), kUsePDL>; + + static void + run(const tvm::ffi::TensorView x, + const tvm::ffi::TensorView topk_idx, + const tvm::ffi::TensorView topk_weights, + const tvm::ffi::TensorView buf_x, + const tvm::ffi::TensorView buf_x_sf, + const tvm::ffi::TensorView buf_topk_idx, + const tvm::ffi::TensorView buf_topk_weights) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto P = SymbolicSize{"padded_max"}; + auto H = SymbolicSize{"hidden"}; + auto K = SymbolicSize{"top_k"}; + auto G4 = 
SymbolicSize{"num_groups_div_4"}; + device.set_options(); + + TensorMatcher({M, H}) // input x + .with_dtype() + .with_device(device) + .verify(x); + TensorMatcher({M, K}) // topk_idx + .with_dtype() + .with_device(device) + .verify(topk_idx); + TensorMatcher({M, K}) // topk_weights + .with_dtype() + .with_device(device) + .verify(topk_weights); + TensorMatcher({P, H}) // buf.x + .with_dtype() + .with_device(device) + .verify(buf_x); + // buf.x_sf is the contiguous row-major int32 view from DeepGEMM's mega + // symm buffer (DeepGEMM/csrc/apis/mega.hpp): shape (P, G/4), strides + // (G/4, 1). No explicit strides required -> TensorMatcher enforces + // is_contiguous(). + TensorMatcher({P, G4}) // buf_x_sf + .with_dtype() + .with_device(device) + .verify(buf_x_sf); + TensorMatcher({P, K}) // buf.topk_idx + .with_dtype() + .with_device(device) + .verify(buf_topk_idx); + TensorMatcher({P, K}) // buf.topk_weights + .with_dtype() + .with_device(device) + .verify(buf_topk_weights); + + const auto num_tokens = static_cast(M.unwrap()); + const auto padded_max = static_cast(P.unwrap()); + const auto hidden = static_cast(H.unwrap()); + const auto top_k = static_cast(K.unwrap()); + const auto num_groups_div_4 = static_cast(G4.unwrap()); + + RuntimeCheck(num_tokens <= padded_max, "num_tokens must not exceed padded_max"); + RuntimeCheck(hidden % kGroupSize == 0, "hidden must be a multiple of group_size"); + const auto num_groups = hidden / static_cast(kGroupSize); + RuntimeCheck(num_groups == num_groups_div_4 * 4u, "num_groups must be a multiple of 4"); + RuntimeCheck(hidden % 8u == 0, "hidden must be a multiple of 8 (16B bf16 loads)"); + const auto num_threads = hidden / 8u; + RuntimeCheck(num_threads <= 1024, "hidden too large for single-block-per-row quant"); + RuntimeCheck(num_threads >= top_k, "top_k must fit into one quant CTA"); + + const auto pad_slots = (padded_max - num_tokens) * top_k; + const uint32_t num_pad_blocks = pad_slots == 0 ? 
0u : ((pad_slots + num_threads - 1u) / num_threads); + const auto num_total_blocks = num_tokens + num_pad_blocks; + + const auto params = MegaMoEPreDispatchParams{ + .x = static_cast(x.data_ptr()), + .topk_idx = static_cast(topk_idx.data_ptr()), + .topk_weights = static_cast(topk_weights.data_ptr()), + .buf_x = static_cast(buf_x.data_ptr()), + .buf_x_sf = static_cast(buf_x_sf.data_ptr()), + .buf_topk_idx = static_cast(buf_topk_idx.data_ptr()), + .buf_topk_weights = static_cast(buf_topk_weights.data_ptr()), + .num_tokens = num_tokens, + .padded_max = padded_max, + .hidden = hidden, + .num_groups = num_groups, + .top_k = top_k, + }; + + if (num_total_blocks == 0) return; + LaunchKernel(num_total_blocks, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/paged_mqa_metadata.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/paged_mqa_metadata.cuh new file mode 100644 index 0000000000..38be975558 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/paged_mqa_metadata.cuh @@ -0,0 +1,119 @@ +#include +#include + +#include +#include + +#include +#include + +namespace { + +constexpr uint32_t kBlockSize = 1024; +constexpr uint32_t kSplitKV = 256; // const for both SM90 and SM100 + +struct MetadataParams { + /// NOTE: batch_size > 0 + uint32_t batch_size; + uint32_t num_sm; + const uint32_t* __restrict__ context_lens; + uint32_t* __restrict__ schedule_metadata; + bool use_smem = true; +}; + +__global__ __launch_bounds__(kBlockSize, 1) // + void smxx_paged_mqa_logits_metadata(const MetadataParams params) { + using namespace device; + extern __shared__ uint32_t s_length[]; + static constexpr auto kNumWarps = kBlockSize / kWarpThreads; + static_assert(kNumWarps == kWarpThreads); + + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + __shared__ uint32_t s_warp_sum[kNumWarps]; + + 
uint32_t local_sum = 0; + for (uint32_t i = tx; i < params.batch_size; i += kBlockSize) { + const auto length = params.context_lens[i]; + local_sum += (length + kSplitKV - 1) / kSplitKV; + if (params.use_smem) s_length[i] = length; + } + + s_warp_sum[warp_id] = warp::reduce_sum(local_sum); + __syncthreads(); + + const auto global_sum = warp::reduce_sum(s_warp_sum[lane_id]); + if (lane_id != 0) return; + + const auto length_ptr = params.use_smem ? s_length : params.context_lens; + + const auto avg = global_sum / params.num_sm; + const auto ret = global_sum % params.num_sm; + uint32_t q = 0; + uint32_t num_work = (length_ptr[0] + kSplitKV - 1) / kSplitKV; + uint32_t sum_work = num_work; + for (auto i = warp_id; i <= params.num_sm; i += kNumWarps) { + const auto target = i * avg + min(i, ret); + while (sum_work <= target) { + if (++q >= params.batch_size) break; + num_work = (length_ptr[q] + kSplitKV - 1) / kSplitKV; + sum_work += num_work; + } + if (q >= params.batch_size) { + params.schedule_metadata[2 * i + 0] = params.batch_size; + params.schedule_metadata[2 * i + 1] = 0; + } else { + // sum > target && (sum - length) <= target + params.schedule_metadata[2 * i + 0] = q; + params.schedule_metadata[2 * i + 1] = target - (sum_work - num_work); + } + } +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +struct IndexerMetadataKernel { + static constexpr auto kMaxBatchSizeInSmem = 16384 * 2; // 128 KB smem + static void run(tvm::ffi::TensorView seq_lens, tvm::ffi::TensorView metadata) { + using namespace host; + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_sm"}; + auto device = SymbolicDevice{}; + device.set_options(); + TensorMatcher({N}) // + .with_dtype() + .with_device(device)
+ .verify(seq_lens); + TensorMatcher({M, 2}) // + .with_dtype() + .with_device(device) + .verify(metadata); + const auto batch_size = static_cast(N.unwrap()); + const auto num_sm = static_cast(M.unwrap()) - 1; + RuntimeCheck(num_sm <= 1024); + const auto use_smem = batch_size <= kMaxBatchSizeInSmem; + const auto params = MetadataParams{ + .batch_size = batch_size, + .num_sm = num_sm, + .context_lens = static_cast(seq_lens.data_ptr()), + .schedule_metadata = static_cast(metadata.data_ptr()), + .use_smem = use_smem, + }; + constexpr auto kernel = smxx_paged_mqa_logits_metadata; + setup_kernel_smem_once(); + const auto smem = use_smem ? (batch_size + 1) * sizeof(uint32_t) : 0; + LaunchKernel(1, kBlockSize, device.unwrap(), smem)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/rope.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/rope.cuh new file mode 100644 index 0000000000..2239d3972d --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/rope.cuh @@ -0,0 +1,169 @@ +#include +#include + +#include +#include +#include + +#include + +#include + +namespace { + +using DType = bf16_t; +constexpr int64_t kRopeDim = 64; +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct FusedQKRopeParams { + void* __restrict__ q; + void* __restrict__ k; + const float* __restrict__ freqs_cis; + const void* __restrict__ positions; + int64_t q_stride_batch; + int64_t k_stride_batch; + int64_t q_stride_head; + int64_t k_stride_head; + uint32_t num_q_heads; + uint32_t num_k_heads; + uint32_t batch_size; +}; + +template +__global__ __launch_bounds__(kBlockSize, 16) // + void deepseek_rope_kernel(const __grid_constant__ FusedQKRopeParams param) { + using namespace device; + using DType2 = packed_t; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto global_warp_id = blockIdx.x * kNumWarps + 
warp_id; + + const auto& [ + q, k, freqs_cis, positions, // + q_stride_batch, k_stride_batch, q_stride_head, k_stride_head, // + num_q_heads, num_k_heads, batch_size + ] = param; + + const auto num_total_heads = num_q_heads + num_k_heads; + const auto head_id = global_warp_id % num_total_heads; + const auto batch_id = global_warp_id / num_total_heads; + if (batch_id >= batch_size) return; + + const auto position = static_cast(positions)[batch_id]; + const auto is_q = head_id < num_q_heads; + const auto local_head = is_q ? head_id : (head_id - num_q_heads); + const auto stride_batch = is_q ? q_stride_batch : k_stride_batch; + const auto stride_head = is_q ? q_stride_head : k_stride_head; + const auto base_ptr = is_q ? q : k; + const auto input = static_cast(pointer::offset(base_ptr, batch_id * stride_batch, local_head * stride_head)); + + const auto freq_ptr = reinterpret_cast(freqs_cis + position * kRopeDim); + const auto [f_real, f_imag] = freq_ptr[lane_id]; + PDLWaitPrimary(); + + const auto data = input[lane_id]; + const auto [x_real, x_imag] = cast(data); + fp32x2_t output; + if constexpr (kInverse) { + // (a + bi) * (c - di) = (ac + bd) + (bc - ad)i + output = { + x_real * f_real + x_imag * f_imag, + x_imag * f_real - x_real * f_imag, + }; + } else { + // (a + bi) * (c + di) = (ac - bd) + (ad + bc)i + output = { + x_real * f_real - x_imag * f_imag, + x_real * f_imag + x_imag * f_real, + }; + } + input[lane_id] = cast(output); + + PDLTriggerSecondary(); +} + +template +struct FusedQKRopeKernel { + // 4 kernel variants: {forward, inverse} x {int32, int64} + static constexpr auto kernel_fwd_i32 = deepseek_rope_kernel; + static constexpr auto kernel_fwd_i64 = deepseek_rope_kernel; + static constexpr auto kernel_inv_i32 = deepseek_rope_kernel; + static constexpr auto kernel_inv_i64 = deepseek_rope_kernel; + + static void forward( + const tvm::ffi::TensorView q, + const tvm::ffi::Optional k, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView 
positions, + bool inverse) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto Q = SymbolicSize{"num_q_heads"}; + auto K = SymbolicSize{"num_k_heads"}; + constexpr auto D = kRopeDim; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, Q, D}) // + .with_strides({-1, -1, 1}) + .with_dtype() + .with_device(device_) + .verify(q); + if (k.has_value()) { + TensorMatcher({B, K, D}) // + .with_strides({-1, -1, 1}) + .with_dtype() + .with_device(device_) + .verify(k.value()); + } else { + K.set_value(0); + } + TensorMatcher({-1, D}) // + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(pos_dtype) + .with_device(device_) + .verify(positions); + const bool pos_i32 = pos_dtype.is_type(); + + const auto batch_size = static_cast(B.unwrap()); + if (batch_size == 0) return; + + const auto num_q_heads = static_cast(Q.unwrap()); + const auto num_k_heads = static_cast(K.unwrap()); + const auto num_total_heads = num_q_heads + num_k_heads; + const auto total_warps = batch_size * num_total_heads; + const auto num_blocks = div_ceil(total_warps, kNumWarps); + + const auto elem_size = static_cast(sizeof(DType)); + const auto params = FusedQKRopeParams{ + .q = q.data_ptr(), + .k = k ? k.value().data_ptr() : nullptr, + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .q_stride_batch = q.stride(0) * elem_size, + .k_stride_batch = k ? k.value().stride(0) * elem_size : 0, + .q_stride_head = q.stride(1) * elem_size, + .k_stride_head = k ? k.value().stride(1) * elem_size : 0, + .num_q_heads = num_q_heads, + .num_k_heads = num_k_heads, + .batch_size = batch_size, + }; + + // dispatch: {inverse} x {pos_i32} + using KernelType = decltype(kernel_fwd_i32); + const KernelType kernel = + inverse ? (pos_i32 ? kernel_inv_i32 : kernel_inv_i64) : (pos_i32 ? 
kernel_fwd_i32 : kernel_fwd_i64); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh new file mode 100644 index 0000000000..be0e759445 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh @@ -0,0 +1,540 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct SiluMulQuantVarlenParams { + const bf16_t* __restrict__ input; + fp8_e4m3_t* __restrict__ output; + float* __restrict__ output_scale; + const int32_t* __restrict__ masked_m; + float swiglu_limit; // only read when kApplySwigluLimit=true + int64_t hidden_dim; + uint32_t num_tokens; + uint32_t num_experts; +}; + +constexpr uint32_t kMaxExperts = 256; + +struct alignas(16) CTAWork { + uint32_t expert_id; + uint32_t expert_token_id; + bool valid; +}; + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(device::kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset); + if (lane_id >= offset) val += n; + } + return val; +} + +template +SGL_DEVICE fp32x2_t silu_and_mul(DType2 gate, DType2 up, float limit) { + using namespace device; + // refer to as implementation. 
TL;DR: must clamp in bf16 + // https://github.com/deepseek-ai/DeepGEMM/blob/7f2a703ed51ac1f7af07f5e1453b2d3267d37d50/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh#L984-L997 + if constexpr (kApplySwigluLimit) { + static_assert(std::is_same_v); + gate = __hmin2(gate, {limit, limit}); + up = __hmax2(up, {-limit, -limit}); + up = __hmin2(up, {limit, limit}); + } + const auto [g0, g1] = cast(gate); + const auto [u0, u1] = cast(up); + const auto silu0 = g0 / (1.0f + __expf(-g0)); + const auto silu1 = g1 / (1.0f + __expf(-g1)); + const float val0 = silu0 * u0; + const float val1 = silu1 * u1; + if constexpr (kPrecise) { // I don't know if we should enable this? + return {val0, val1}; + } else { + return cast(cast(fp32x2_t{val0, val1})); + } +} + +[[maybe_unused]] +SGL_DEVICE CTAWork get_work(const SiluMulQuantVarlenParams& params) { + // Preconditions: + // 1. blockDim.x >= params.num_experts + // 2. params.num_experts <= kMaxExperts + using namespace device; + static_assert(kWarpThreads == 32); + + static __shared__ uint32_t s_warp_sum[32]; + static __shared__ CTAWork result; + + result.valid = false; + + const uint32_t tx = threadIdx.x; + const uint32_t lane_id = tx % kWarpThreads; + const uint32_t warp_id = tx / kWarpThreads; + + const uint32_t val = tx < params.num_experts ? params.masked_m[tx] : 0u; + + // Per-warp inclusive scan of masked_m. + const uint32_t warp_inclusive = warp_inclusive_sum(lane_id, val); + const uint32_t warp_exclusive = warp_inclusive - val; + + // Write each warp total. + if (lane_id == kWarpThreads - 1) s_warp_sum[warp_id] = warp_inclusive; + __syncthreads(); + const auto tmp_val = lane_id < warp_id ? 
s_warp_sum[lane_id] : 0u; + const auto prefix_exclusive = warp::reduce_sum(tmp_val) + warp_exclusive; + const auto bx = blockIdx.x; + if (prefix_exclusive <= bx && bx < prefix_exclusive + val) { + result = {tx, bx - prefix_exclusive, true}; + } + __syncthreads(); + return result; +} + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_quant_varlen_kernel(const SiluMulQuantVarlenParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kGroupSize = 128u; + constexpr uint32_t kWorkThreads = 16u; + // each thread will handle 8 elements + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + static_assert(8 * kWorkThreads == 128, "Invalid tiling"); + static_assert(!(kTransposed && !kScaleUE8M0), "transposed layout only supports ue8m0"); + + const auto [expert_id, token_id, valid] = get_work(params); + + if (!valid) return; + + const auto work_id = threadIdx.x / kWorkThreads; + + const auto offset = expert_id * params.num_tokens + token_id; + const auto input = params.input + offset * params.hidden_dim * 2; + const auto output = params.output + offset * params.hidden_dim; + [[maybe_unused]] + const auto output_scale = [&] { + const auto num_groups = params.hidden_dim / kGroupSize; + if constexpr (kTransposed) { + const auto base = reinterpret_cast(params.output_scale); + // Physical layout is [E, G//4, N] int32. Each int32 packs 4 consecutive + // group scales for the same token, so the byte address is: + // expert_offset + (group/4)*N*4 + token*4 + group%4 + return base + expert_id * num_groups * params.num_tokens + (work_id / 4u) * (params.num_tokens * 4u) + + token_id * 4u + (work_id % 4u); + } else { + return params.output_scale + offset * num_groups + work_id; + } + }(); + + PDLWaitPrimary(); + + InputVec gate_vec, up_vec; + if constexpr (kSwizzle) { + // gran=8 interleaved: every 16-element chunk on the N axis is + // [gate[0..7], up[0..7]]. 
Each thread handles 8 consecutive output + // elements, so its gate chunk lives at vec index 2*threadIdx.x and its + // up chunk at 2*threadIdx.x+1. + gate_vec.load(input, threadIdx.x * 2); + up_vec.load(input, threadIdx.x * 2 + 1); + } else { + gate_vec.load(input, threadIdx.x); + up_vec.load(input, threadIdx.x + blockDim.x); + } + + float local_max = 0.0f; + float results[8]; + +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const auto [x, y] = silu_and_mul(gate_vec[i], up_vec[i], params.swiglu_limit); + results[2 * i + 0] = x; + results[2 * i + 1] = y; + local_max = fmaxf(local_max, fmaxf(fabsf(x), fabsf(y))); + } + + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + float scale; + uint32_t ue8m0_exp; + + if constexpr (kScaleUE8M0) { + const float raw_scale = absmax / math::FP8_E4M3_MAX; + ue8m0_exp = cast_to_ue8m0(raw_scale); + scale = __uint_as_float(ue8m0_exp << 23); + } else { + scale = absmax / math::FP8_E4M3_MAX; + } + const auto inv_scale = 1.0f / scale; + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const float scaled_val0 = results[2 * i + 0] * inv_scale; + const float scaled_val1 = results[2 * i + 1] * inv_scale; + out_vec[i] = pack_fp8(scaled_val0, scaled_val1); + } + + PDLTriggerSecondary(); + + out_vec.store(output, threadIdx.x); + if constexpr (kTransposed) { + *output_scale = ue8m0_exp; + } else { + *output_scale = scale; + } +} + +struct SiluAndMulClampParams { + const void* __restrict__ input; + void* __restrict__ output; + float swiglu_limit; +}; + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_clamp_kernel(const SiluAndMulClampParams __grid_constant__ params) { + using namespace device; + static_assert(sizeof(DType) == 2, "only fp16/bf16 supported"); + using DType2 = packed_t; + constexpr auto kVecSize = 16 / sizeof(DType); + static_assert(kVecSize % 2 == 0 && kVecSize > 0); + using Vec = AlignedVector; + const auto bid = 
blockIdx.x; + const auto tile = tile::Memory::cta(); + const float limit = params.swiglu_limit; + + PDLWaitPrimary(); + const auto gate = tile.load(params.input, bid * 2 + 0); + const auto up = tile.load(params.input, bid * 2 + 1); + Vec out; + +#pragma unroll + for (uint32_t i = 0; i < kVecSize / 2; ++i) { + out[i] = cast(silu_and_mul(cast(gate[i]), cast(up[i]), limit)); + } + + tile.store(params.output, out, bid); + PDLTriggerSecondary(); +} + +// ---- Host wrapper +// ------------------------------------------------------------------------------------------------------------------------ + +template +struct SiluAndMulMaskedPostQuantKernel { + static_assert(kGroupSize == 128); + static constexpr auto kernel_normal = + silu_mul_quant_varlen_kernel; + static constexpr auto kernel_transposed = + silu_mul_quant_varlen_kernel; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView output_scale, + const tvm::ffi::TensorView masked_m, + const uint32_t topk, + const bool transposed, + const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto E = SymbolicSize{"num_experts"}; + auto T = SymbolicSize{"num_tokens_padded"}; + auto D = SymbolicSize{"hidden_dim x 2"}; + auto N = SymbolicSize{"hidden_dim"}; + auto G = SymbolicSize{"num_groups"}; + device.set_options(); + + TensorMatcher({E, T, D}) // input + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({E, T, N}) // output + .with_dtype() + .with_device(device) + .verify(output); + if (!transposed) { + TensorMatcher({E, T, G}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + } else { + RuntimeCheck(kScaleUE8M0, "transposed layout only supports scale_ue8m0=true"); + auto G_ = SymbolicSize{"G // 4"}; + TensorMatcher({E, G_, T}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + G.set_value(G_.unwrap() * 4); + } + TensorMatcher({E}) // + .with_dtype() + 
.with_device(device) + .verify(masked_m); + + const auto num_experts = static_cast(E.unwrap()); + const auto num_tokens = static_cast(T.unwrap()); + const auto num_groups = static_cast(G.unwrap()); + const auto hidden_dim = N.unwrap(); + + RuntimeCheck(D.unwrap() == 2 * hidden_dim, "invalid dimension"); + RuntimeCheck(hidden_dim % kGroupSize == 0); + RuntimeCheck(num_experts <= kMaxExperts, "num_experts exceeds maximum (256)"); + RuntimeCheck(num_groups * kGroupSize == hidden_dim, "invalid num_groups"); + + const auto params = SiluMulQuantVarlenParams{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .output_scale = static_cast(output_scale.data_ptr()), + .masked_m = static_cast(masked_m.data_ptr()), + .swiglu_limit = static_cast(swiglu_limit), + .hidden_dim = hidden_dim, + .num_tokens = num_tokens, + .num_experts = num_experts, + }; + + const auto num_threads = hidden_dim / 8; + RuntimeCheck(num_threads % device::kWarpThreads == 0); + RuntimeCheck(num_threads >= num_experts); + const auto kernel = transposed ? 
kernel_transposed : kernel_normal; + LaunchKernel(num_tokens * topk, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +template +struct SiluAndMulClampKernel { + static constexpr auto kernel = silu_mul_clamp_kernel; + + static void run(const tvm::ffi::TensorView input, const tvm::ffi::TensorView output, const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"gate_up_dim"}; // 2 * out_dim + auto H = SymbolicSize{"out_dim"}; + device.set_options(); + + TensorMatcher({M, D}) // input (gate || up) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({M, H}) // output + .with_dtype() + .with_device(device) + .verify(output); + RuntimeCheck(D.unwrap() == 2 * H.unwrap(), "input last dim must be 2 * output last dim"); + + constexpr uint32_t kVecSize = 16 / sizeof(DType); + const auto out_dim = static_cast(H.unwrap()); + const auto num_tokens = static_cast(M.unwrap()); + RuntimeCheck(out_dim % kVecSize == 0, "out_dim must be divisible by vector size"); + const auto num_threads = out_dim / kVecSize; + RuntimeCheck(num_threads <= 1024, "out_dim too large for single-block-per-row launch"); + + const auto params = SiluAndMulClampParams{ + .input = input.data_ptr(), + .output = output.data_ptr(), + .swiglu_limit = static_cast(swiglu_limit), + }; + LaunchKernel(num_tokens, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +struct SiluMulQuantContigParams { + const bf16_t* __restrict__ input; + fp8_e4m3_t* __restrict__ output; + float* __restrict__ output_scale; + float swiglu_limit; // only read when kApplySwigluLimit=true + int64_t hidden_dim; + uint32_t num_tokens; + uint32_t scale_row_stride_int32; // only used when kTransposed=true +}; + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_quant_contig_kernel(const SiluMulQuantContigParams __grid_constant__ 
params) { + using namespace device; + + constexpr uint32_t kGroupSize = 128u; + constexpr uint32_t kWorkThreads = 16u; + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + static_assert(8 * kWorkThreads == 128, "Invalid tiling"); + static_assert(!(kTransposed && !kScaleUE8M0), "transposed layout only supports ue8m0"); + + const auto token_id = blockIdx.x; + const auto work_id = threadIdx.x / kWorkThreads; + + const auto input = params.input + token_id * params.hidden_dim * 2; + const auto output = params.output + token_id * params.hidden_dim; + [[maybe_unused]] + const auto output_scale = [&] { + const auto num_groups = params.hidden_dim / kGroupSize; + if constexpr (kTransposed) { + // Physical layout is (G//4_pad, M_pad) int32; each int32 packs 4 + // consecutive UE8M0 exponents for the same token. Byte address: + // (work_id / 4) * M_pad * 4 + token * 4 + (work_id % 4). + const auto base = reinterpret_cast(params.output_scale); + return base + (work_id / 4u) * (params.scale_row_stride_int32 * 4u) + token_id * 4u + (work_id % 4u); + } else { + return params.output_scale + token_id * num_groups + work_id; + } + }(); + + PDLWaitPrimary(); + + InputVec gate_vec, up_vec; + if constexpr (kSwizzle) { + gate_vec.load(input, threadIdx.x * 2); + up_vec.load(input, threadIdx.x * 2 + 1); + } else { + gate_vec.load(input, threadIdx.x); + up_vec.load(input, threadIdx.x + blockDim.x); + } + + float local_max = 0.0f; + float results[8]; + +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const auto [x, y] = silu_and_mul(gate_vec[i], up_vec[i], params.swiglu_limit); + results[2 * i + 0] = x; + results[2 * i + 1] = y; + local_max = fmaxf(local_max, fmaxf(fabsf(x), fabsf(y))); + } + + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + float scale; + uint32_t ue8m0_exp; + + if constexpr (kScaleUE8M0) { + const float raw_scale = absmax / math::FP8_E4M3_MAX; + ue8m0_exp = cast_to_ue8m0(raw_scale); + scale = 
__uint_as_float(ue8m0_exp << 23); + } else { + scale = absmax / math::FP8_E4M3_MAX; + } + const auto inv_scale = 1.0f / scale; + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const float scaled_val0 = results[2 * i + 0] * inv_scale; + const float scaled_val1 = results[2 * i + 1] * inv_scale; + out_vec[i] = pack_fp8(scaled_val0, scaled_val1); + } + + PDLTriggerSecondary(); + + out_vec.store(output, threadIdx.x); + if constexpr (kTransposed) { + *output_scale = ue8m0_exp; + } else { + *output_scale = scale; + } +} + +template +struct SiluAndMulContigPostQuantKernel { + static_assert(kGroupSize == 128); + static constexpr auto kernel_normal = + silu_mul_quant_contig_kernel; + static constexpr auto kernel_transposed = + silu_mul_quant_contig_kernel; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView output_scale, + const bool transposed, + const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_dim x 2"}; + auto N = SymbolicSize{"hidden_dim"}; + auto G = SymbolicSize{"num_groups"}; + device.set_options(); + + TensorMatcher({M, D}) // input (gate/up, natural or gran=8 interleaved on last dim) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({M, N}) // fp8 output + .with_dtype() + .with_device(device) + .verify(output); + + const auto hidden_dim = N.unwrap(); + RuntimeCheck(D.unwrap() == 2 * hidden_dim, "invalid dimension"); + RuntimeCheck(hidden_dim % kGroupSize == 0); + const auto num_groups = static_cast(hidden_dim / kGroupSize); + + uint32_t scale_row_stride_int32 = 0; + if (!transposed) { + G.set_value(num_groups); + TensorMatcher({M, G}) // (M, G) fp32 natural row-major + .with_dtype() + .with_device(device) + .verify(output_scale); + } else { + RuntimeCheck(kScaleUE8M0, "transposed layout only supports scale_ue8m0=true"); + 
RuntimeCheck(num_groups % 4 == 0, "transposed layout requires num_groups % 4 == 0"); + auto G_ = SymbolicSize{"G // 4"}; + G_.set_value(num_groups / 4); + auto M_pad = SymbolicSize{"M padded"}; + TensorMatcher({M, G_}) // `.transpose(-1,-2)[:M,:]` view of (G//4_pad, M_pad) int32 + .with_strides({int64_t{1}, M_pad}) // col-major transposed + .with_dtype() + .with_device(device) + .verify(output_scale); + scale_row_stride_int32 = static_cast(M_pad.unwrap()); + } + + const auto num_tokens = static_cast(M.unwrap()); + + const auto params = SiluMulQuantContigParams{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .output_scale = static_cast(output_scale.data_ptr()), + .swiglu_limit = static_cast(swiglu_limit), + .hidden_dim = hidden_dim, + .num_tokens = num_tokens, + .scale_row_stride_int32 = scale_row_stride_int32, + }; + + const auto num_threads = hidden_dim / 8; + RuntimeCheck(num_threads % device::kWarpThreads == 0); + const auto kernel = transposed ? 
kernel_transposed : kernel_normal; + LaunchKernel(num_tokens, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/store.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/store.cuh new file mode 100644 index 0000000000..49f6f55963 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/store.cuh @@ -0,0 +1,205 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::inv_scale_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct FusedStoreCacheParam { + const void* __restrict__ input; + void* __restrict__ cache; + const void* __restrict__ indices; + uint32_t num_tokens; +}; + +template +__global__ void fused_store_flashmla_cache(const __grid_constant__ FusedStoreCacheParam param) { + using namespace device; + + /// NOTE: 584 = 576 + 8 + constexpr int64_t kPageBytes = host::div_ceil(584 << kPageBits, 576) * 576; + + // each warp handles 64 elements, 8 warps, each block handles 1 row + const auto& [input, cache, indices, num_tokens] = param; + const uint32_t bid = blockIdx.x; + const uint32_t tid = threadIdx.x; + const uint32_t wid = tid / 32; + + PDLWaitPrimary(); + + // prefetch the index + const auto index = static_cast(indices)[bid]; + // always load the value from input (don't store if invalid) + using Float2 = packed_t; + const auto elems = static_cast(input)[tid + bid * 256]; + if (wid != 7) { + const auto [x, y] = cast(elems); + const auto abs_max = warp::reduce_max(fmaxf(fabs(x), fabs(y))); + const auto scale_raw = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto scale_ue8m0 = cast_to_ue8m0(scale_raw); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + const auto result = pack_fp8(x * inv_scale, y * inv_scale); + const int32_t page = index >> kPageBits; + 
const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 576); + const auto scale_ptr = pointer::offset(page_ptr, 576 << kPageBits, offset * 8); + static_cast(value_ptr)[tid] = result; + static_cast(scale_ptr)[wid] = scale_ue8m0; + } else { + const auto result = cast(elems); + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 576, 448); + static_cast(value_ptr)[tid - 7 * 32] = result; + } + + PDLTriggerSecondary(); +} + +template +__global__ void fused_store_indexer_cache(const __grid_constant__ FusedStoreCacheParam param) { + using namespace device; + + /// NOTE: 132 = 128 + 4 + constexpr int64_t kPageBytes = 132 << kPageBits; + + // each warp handles 128 elements, 1 warp, each block handles multiple rows + const auto& [input, cache, indices, num_tokens] = param; + const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto global_wid = global_tid / 32; + const auto lane_id = threadIdx.x % 32; + + if (global_wid >= num_tokens) return; + + PDLWaitPrimary(); + + // prefetch the index + const auto index = static_cast(indices)[global_wid]; + // always load the value from input (don't store if invalid) + using Float2 = packed_t; + using InStorage = AlignedVector; + using OutStorage = AlignedVector; + const auto elems = static_cast(input)[global_tid]; + const auto [x0, x1] = cast(elems[0]); + const auto [y0, y1] = cast(elems[1]); + const auto local_max = fmaxf(fmaxf(fabs(x0), fabs(x1)), fmaxf(fabs(y0), fabs(y1))); + const auto abs_max = warp::reduce_max(local_max); + // use normal fp32 scale + const auto scale = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto inv_scale = 1.0f / scale; + const int32_t page = index >> kPageBits; + const int32_t offset = 
index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 128); + const auto scale_ptr = pointer::offset(page_ptr, 128 << kPageBits, offset * 4); + OutStorage result; + result[0] = pack_fp8(x0 * inv_scale, x1 * inv_scale); + result[1] = pack_fp8(y0 * inv_scale, y1 * inv_scale); + static_cast(value_ptr)[lane_id] = result; + static_cast(scale_ptr)[0] = scale; + + PDLTriggerSecondary(); +} + +template +struct FusedStoreCacheFlashMLAKernel { + static constexpr int32_t kLogSize = std::countr_zero(kPageSize); + static constexpr int64_t kPageBytes = host::div_ceil(584 * kPageSize, 576) * 576; + static constexpr auto kernel = fused_store_flashmla_cache; + + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogSize == kPageSize); + + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({N, 512}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(cache); + TensorMatcher({N}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + const auto num_tokens = static_cast(N.unwrap()); + const auto params = FusedStoreCacheParam{ + .input = input.data_ptr(), + .cache = cache.data_ptr(), + .indices = indices.data_ptr(), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 256; + const auto num_blocks = num_tokens; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +template +struct FusedStoreCacheIndexerKernel { + static constexpr int32_t kLogSize = std::countr_zero(kPageSize); + static constexpr int64_t kPageBytes = 132 * kPageSize; + 
static constexpr auto kernel = fused_store_indexer_cache; + + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogSize == kPageSize); + + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({N, 128}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(cache); + TensorMatcher({N}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + const auto num_tokens = static_cast(N.unwrap()); + const auto params = FusedStoreCacheParam{ + .input = input.data_ptr(), + .cache = cache.data_ptr(), + .indices = indices.data_ptr(), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 128; + const auto num_blocks = div_ceil(num_tokens * 32, kBlockSize); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v1.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v1.cuh new file mode 100644 index 0000000000..b1ccd24b20 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v1.cuh @@ -0,0 +1,340 @@ +#include +#include + +#include + +#include +#include + +#include +#include + +namespace { + +#ifndef SGL_TOPK +#define SGL_TOPK 512 +#endif + +constexpr uint32_t kTopK = SGL_TOPK; +constexpr uint32_t kTopKBlockSize = SGL_TOPK; +constexpr uint32_t kSMEM = 16 * 1024 * sizeof(uint32_t); // 64KB (bytes) + +struct TopKParams { + const float* __restrict__ scores; + const int32_t* __restrict__ seq_lens; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + int32_t* __restrict__ raw_indices; // optional: output raw abs 
position indices before page transform + const int64_t score_stride; + const int64_t page_table_stride; + uint32_t page_bits; +}; + +SGL_DEVICE uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +SGL_DEVICE uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) { + const uint32_t mask = (1u << page_bits) - 1u; + return (page_table[i >> page_bits] << page_bits) | (i & mask); +} + +[[maybe_unused]] +SGL_DEVICE void naive_transform( + const float* __restrict__, // unused + const int32_t* __restrict__ page_table, + int32_t* __restrict__ indices, + int32_t* __restrict__ raw_indices, // optional: output raw abs position indices + const uint32_t length, + const uint32_t page_bits) { + static_assert(kTopK <= kTopKBlockSize); + if (const auto tx = threadIdx.x; tx < length) { + indices[tx] = page_to_indices(page_table, tx, page_bits); + if (raw_indices != nullptr) { + raw_indices[tx] = tx; + } + } else if (kTopK == kTopKBlockSize || tx < kTopK) { + indices[tx] = -1; // fill invalid indices to -1 + if (raw_indices != nullptr) { + raw_indices[tx] = -1; + } + } +} + +[[maybe_unused]] +SGL_DEVICE void radix_topk(const float* __restrict__ input, int32_t* __restrict__ output, const uint32_t length) { + constexpr uint32_t RADIX = 256; + constexpr uint32_t BLOCK_SIZE = kTopKBlockSize; + constexpr uint32_t SMEM_INPUT_SIZE = kSMEM / (2 * sizeof(int32_t)); + + alignas(128) __shared__ uint32_t _s_histogram_buf[2][RADIX + 32]; + alignas(128) __shared__ uint32_t s_counter; + alignas(128) __shared__ uint32_t s_threshold_bin_id; + alignas(128) __shared__ uint32_t s_num_input[2]; + alignas(128) __shared__ int32_t 
s_last_remain; + + extern __shared__ uint32_t s_input_idx[][kSMEM / (2 * sizeof(int32_t))]; + + const uint32_t tx = threadIdx.x; + uint32_t remain_topk = kTopK; + auto& s_histogram = _s_histogram_buf[0]; + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int32_t i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (tx < RADIX) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = _s_histogram_buf[k][tx]; + if (tx + j < RADIX) { + value += _s_histogram_buf[k][tx + j]; + } + _s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + if (remain_topk == 0) { + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const uint32_t bin = convert_to_uint8(input[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw_input = input[idx]; + const uint32_t bin = convert_to_uint8(raw_input); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (pos < SMEM_INPUT_SIZE) { + [[likely]] s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = 
(bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto raw_num_input = s_num_input[r_idx]; + const auto num_input = raw_num_input < SMEM_INPUT_SIZE ? raw_num_input : SMEM_INPUT_SIZE; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = remain_topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + + if (remain_topk == 0) { + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + output[kTopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (pos < SMEM_INPUT_SIZE) { + /// NOTE: (dark) fuse the histogram computation here + [[likely]] s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = 
(bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +template +__global__ void topk_transform_kernel(const __grid_constant__ TopKParams params) { + const auto &[ + scores, seq_lens, page_table, page_indices, raw_indices, // pointers + score_stride, page_table_stride, page_bits // sizes + ] = params; + const uint32_t work_id = blockIdx.x; + + /// NOTE: dangerous prefetch seq_len before PDL wait + const uint32_t seq_len = seq_lens[work_id]; + const auto score_ptr = scores + work_id * score_stride; + const auto page_ptr = page_table + work_id * page_table_stride; + const auto indices_ptr = page_indices + work_id * kTopK; + const auto raw_indices_ptr = raw_indices != nullptr ? raw_indices + work_id * kTopK : nullptr; + + device::PDLWaitPrimary(); + + if (seq_len <= kTopK) { + naive_transform(score_ptr, page_ptr, indices_ptr, raw_indices_ptr, seq_len, page_bits); + } else { + __shared__ int32_t s_topk_indices[kTopK]; + radix_topk(score_ptr, s_topk_indices, seq_len); + static_assert(kTopK <= kTopKBlockSize); + const auto tx = threadIdx.x; + if (kTopK == kTopKBlockSize || tx < kTopK) { + indices_ptr[tx] = page_to_indices(page_ptr, s_topk_indices[tx], page_bits); + if (raw_indices_ptr != nullptr) { + raw_indices_ptr[tx] = s_topk_indices[tx]; + } + } + } + + device::PDLTriggerSecondary(); +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +template +struct TopKKernel { + static constexpr auto kernel = topk_transform_kernel; + + static void transform( + const tvm::ffi::TensorView scores, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView page_table, + const tvm::ffi::TensorView page_indices, + const uint32_t page_size, + 
const tvm::ffi::Optional raw_indices) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto S = SymbolicSize{"score_stride"}; + auto P = SymbolicSize{"page_table_stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({B, -1}) // strided scores + .with_strides({S, 1}) + .with_dtype() + .with_device(device) + .verify(scores); + TensorMatcher({B}) // seq_lens, must be contiguous + .with_dtype() + .with_device(device) + .verify(seq_lens); + TensorMatcher({B, -1}) // strided page table + .with_strides({P, 1}) + .with_dtype() + .with_device(device) + .verify(page_table); + TensorMatcher({B, kTopK}) // output, must be contiguous + .with_dtype() + .with_device(device) + .verify(page_indices); + + int32_t* raw_indices_ptr = nullptr; + if (raw_indices.has_value()) { + TensorMatcher({B, kTopK}) // optional raw indices output, must be contiguous + .with_dtype() + .with_device(device) + .verify(raw_indices.value()); + raw_indices_ptr = static_cast(raw_indices.value().data_ptr()); + } + + RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2"); + const auto page_bits = static_cast(std::countr_zero(page_size)); + const auto batch_size = static_cast(B.unwrap()); + const auto params = TopKParams{ + .scores = static_cast(scores.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .page_table = static_cast(page_table.data_ptr()), + .page_indices = static_cast(page_indices.data_ptr()), + .raw_indices = raw_indices_ptr, + .score_stride = S.unwrap(), + .page_table_stride = P.unwrap(), + .page_bits = page_bits, + }; + constexpr auto kSMEM_ = kSMEM + sizeof(int32_t); // align up a little + setup_kernel_smem_once(); + LaunchKernel(batch_size, kTopKBlockSize, device.unwrap(), kSMEM_).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v2.cuh b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v2.cuh new file mode 100644 index 
0000000000..8c4a526575 --- /dev/null +++ b/lightllm/third_party/sglang_jit/csrc/deepseek_v4/topk_v2.cuh @@ -0,0 +1,493 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace { + +#ifndef SGL_TOPK +#define SGL_TOPK 512 +#endif + +inline constexpr uint32_t K = SGL_TOPK; + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +namespace impl = device::top512; +using Large = impl::ClusterTopK; +using Medium = impl::StreamingTopK; +using Small = impl::RegisterTopK; + +using Metadata = Large::Metadata; +constexpr uint32_t kBlockSize = impl::kBlockSize; +constexpr uint32_t kNumClusters = 15; // based on hardware limits +constexpr uint32_t kClusterSize = Large::kClusterSize; +constexpr uint32_t kMax2PassLength = Small::kMax2PassLength; +constexpr uint32_t kMaxSupportedLength = Large::kMaxLength; + +/// Common metadata lives at metadata[0] (first row of the [batch_size+1, 4] tensor). +/// Per-item metadata starts at metadata[1..batch_size]. The plan kernel writes both. 
+struct alignas(16) GlobalMetadata { + uint32_t cluster_threshold; // decided per-batch in plan kernel + uint32_t num_cluster_items; // N = number of items routed to the cluster path + uint32_t reserved[2]; +}; +static_assert(sizeof(GlobalMetadata) == sizeof(Metadata), "layout: row 0 must occupy one Metadata-sized slot"); + +// optimize occupancy for prefill +#define SMALL_TOPK_KERNEL __global__ __launch_bounds__(kBlockSize, 2) +// cluster at y dim +#define LARGE_CLUSTER __cluster_dims__(1, kClusterSize, 1) +// stage-1 is persistent cluster, and shared memory usage is huge (can not 2) +#define LARGE_TOPK_STAGE_1 __global__ __launch_bounds__(kBlockSize, 1) LARGE_CLUSTER +// stage-2 is non-persistent non-cluster, with less shared memory and higher occupancy +#define LARGE_TOPK_STAGE_2 __global__ __launch_bounds__(kBlockSize, 2) +// fused into 1 stage when batch-size <= kNumPersistentClusters +#define FUSED_COMBINE_KERNEL __global__ __launch_bounds__(kBlockSize, 1) LARGE_CLUSTER +// plan runs once as a single block before the combine kernels +#define PLAN_KERNEL __global__ __launch_bounds__(kBlockSize, 1) + +struct TopKParams { + const uint32_t* __restrict__ seq_lens; + const float* __restrict__ scores; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + int64_t score_stride; + int64_t page_table_stride; + uint8_t* __restrict__ workspace; // [batch, kWorkspaceBytes] -- internally allocated + /// Pointer to the full metadata tensor: metadata[0] is GlobalMetadata, metadata[1..] + /// are per-item entries (at most kNumClusters * rounds of them). 
+ const Metadata* __restrict__ metadata = nullptr; + int64_t workspace_stride; // bytes per batch + uint32_t batch_size; + uint32_t page_bits; + + SGL_DEVICE const float* get_scores(const uint32_t batch_id) const { + return scores + batch_id * score_stride; + } + SGL_DEVICE impl::TransformParams get_transform(const uint32_t batch_id, int32_t* indices) const { + return { + .page_table = page_table + batch_id * page_table_stride, + .indices_in = indices, + .indices_out = page_indices + batch_id * K, + .page_bits = page_bits, + }; + } + SGL_DEVICE const GlobalMetadata& get_global_metadata() const { + return *reinterpret_cast(metadata); + } + SGL_DEVICE const Metadata& get_item_metadata(uint32_t work_id) const { + return metadata[1 + work_id]; // +1 to skip the GlobalMetadata row + } +}; + +SGL_DEVICE uint2 partition_work(uint32_t length, uint32_t rank) { + constexpr uint32_t kTMAAlign = 4; + const auto total_units = (length + kTMAAlign - 1) / kTMAAlign; + const auto base = total_units / kClusterSize; + const auto extra = total_units % kClusterSize; + const auto local_units = base + (rank < extra ? 1u : 0u); + const auto offset_units = rank * base + min(rank, extra); + const auto offset = offset_units * kTMAAlign; + const auto finish = min(offset + local_units * kTMAAlign, length); + return {offset, finish - offset}; +} + +/// Persistent scheduler. A single block: +/// 1. Decides a cluster_threshold from the real seq_lens distribution (or +/// uses the caller-supplied `static_cluster_threshold` when non-zero). +/// 2. Writes that threshold + N into metadata[0] (the GlobalMetadata row). +/// 3. Compacts items with seq_len > threshold into metadata[1..N+1), laid out +/// to match the persistent consumer's round-robin stride (kNumClusters). +/// Entries for clusters that get no work are zero-filled. 
+PLAN_KERNEL void topk_plan( + const uint32_t* __restrict__ seq_lens, + Metadata* __restrict__ metadata, + const uint32_t batch_size, + const uint32_t static_cluster_threshold) { + // Candidate thresholds, strictly increasing. Picked to give the auto-heuristic + // reasonable granularity without needing a full sort. Must all be >= kMax2PassLength. + + struct Pair { + uint32_t threshold; + uint32_t max_batch_size; + }; + /// NOTE: only tuned on B200 + constexpr Pair kCandidates[] = { + {32768, 30}, + {40960, 45}, + {49152, 45}, + {65536, 60}, + {98304, 60}, + {131072, 75}, + {196608, 90}, + {262144, 105}, + }; + constexpr uint32_t kNumCandidates = std::size(kCandidates); + constexpr uint32_t kMinBatchSize = kCandidates[0].max_batch_size; + static_assert(kCandidates[0].threshold == kMax2PassLength); + static_assert(kCandidates[kNumCandidates - 1].threshold == kMaxSupportedLength); + + __shared__ uint32_t s_count; // final N after compaction + __shared__ uint32_t s_counts[kNumCandidates]; + __shared__ uint32_t s_threshold; + + const auto tx = threadIdx.x; + if (tx == 0) s_count = 0; + if (tx < kNumCandidates) s_counts[tx] = 0; + __syncthreads(); + + // --- Phase 1: decide threshold ------------------------------------------ + if (static_cluster_threshold > 0) { + if (tx == 0) s_threshold = static_cluster_threshold; + } else if (batch_size <= kMinBatchSize) { + if (tx == 0) s_threshold = kMax2PassLength; // always prefer cluster + } else { + // Count items above each candidate threshold. Monotonically non-increasing in T. + for (uint32_t i = tx; i < batch_size; i += kBlockSize) { + const uint32_t sl = seq_lens[i]; + assert(sl <= kMaxSupportedLength); + uint32_t count = 0; +#pragma unroll + for (uint32_t j = 0; j < kNumCandidates; ++j) { + count += (sl > kCandidates[j].threshold ? 
1 : 0); + } + if (count > 0) { + atomicAdd(&s_counts[count - 1], 1); + } + } + __syncthreads(); + if (tx == 0) { + uint32_t accum = 0; + uint32_t chosen = kMaxSupportedLength; +#pragma unroll + for (uint32_t i = 0; i < kNumCandidates; ++i) { + const auto j = kNumCandidates - 1 - i; + accum += s_counts[j]; + /// NOTE: `accum` increasing, while `max_batch_size` decreasing + if (accum > kCandidates[j].max_batch_size) break; + chosen = kCandidates[j].threshold; + } + s_threshold = chosen; + } + } + __syncthreads(); + // sanity check: below 2 pass threshold, must fits in small path + const auto cluster_threshold = max(s_threshold, kMax2PassLength); + + // --- Phase 2: compact items with seq_len > threshold into metadata[1..] - + // Per-item rows live at metadata[1 + pos]; metadata[0] is the GlobalMetadata row. + for (uint32_t i = tx; i < batch_size; i += kBlockSize) { + const uint32_t sl = seq_lens[i]; + if (sl > cluster_threshold) { + const auto pos = atomicAdd(&s_count, 1); + metadata[1 + pos] = {i, sl, false}; + } + } + __syncthreads(); + const auto N = s_count; + + // --- Phase 3: has_next + sentinels + GlobalMetadata --------------------- + for (uint32_t i = tx; i < N; i += kBlockSize) { + if (i + kNumClusters < N) metadata[1 + i].has_next = true; + } + // Zero-fill the first kNumClusters sentinel slots that got no valid entry. + if (tx < kNumClusters && tx >= N) metadata[1 + tx] = {0, 0, false}; + // Write global metadata (row 0). 
+ if (tx == 0) { + auto* g = reinterpret_cast(metadata); + *g = { + .cluster_threshold = cluster_threshold, + .num_cluster_items = N, + .reserved = {0, 0}, + }; + } +} + +SMALL_TOPK_KERNEL void // short context +topk_short_transform(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + const auto batch_id = blockIdx.x; + const auto seq_len = params.seq_lens[batch_id]; + const auto transform = params.get_transform(batch_id, s_topk_indices); + // trivial case + if (seq_len <= K) { + impl::trivial_transform(transform, seq_len, K); + } else { + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem, /*use_pdl=*/true); + device::PDLTriggerSecondary(); + Small::transform(transform); + } +} + +LARGE_TOPK_STAGE_1 void // long context, middle to large batch size +topk_combine_preprocess(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + uint32_t work_id = blockIdx.x; + uint32_t batch_id; + uint32_t seq_len; + bool has_next; + uint32_t length; + uint32_t offset; + const auto cluster_rank = blockIdx.y; + + const auto prefetch_metadata = [&] { + const auto metadata = params.get_item_metadata(work_id); + batch_id = metadata.batch_id; + seq_len = metadata.seq_len; + has_next = metadata.has_next; + work_id += kNumClusters; // advance to the next item for this cluster + }; + const auto launch_prologue = [&] { + const auto partition = partition_work(seq_len, cluster_rank); + offset = partition.x; + length = partition.y; + Large::stage1_prologue(params.get_scores(batch_id) + offset, length, smem); + }; + + device::PDLWaitPrimary(); + device::PDLTriggerSecondary(); + + prefetch_metadata(); + if (seq_len == 0) return; + Large::stage1_init(smem); + launch_prologue(); + while (true) { + const auto this_length = length; + const auto this_offset = offset; + const auto need_prefetch = has_next; + const auto 
transform = params.get_transform(batch_id, s_topk_indices); + const auto ws = params.workspace + batch_id * params.workspace_stride; + if (need_prefetch) prefetch_metadata(); + Large::stage1(s_topk_indices, this_length, smem, /*reuse=*/true); + if (need_prefetch) launch_prologue(); + Large::stage1_epilogue(transform, this_offset, ws, smem); + if (!need_prefetch) break; + } +} + +LARGE_TOPK_STAGE_2 void // long context, middle to large batch size +topk_combine_transform(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + const auto batch_id = blockIdx.x; + const auto seq_len = params.seq_lens[batch_id]; + const auto cluster_threshold = params.get_global_metadata().cluster_threshold; + const auto transform = params.get_transform(batch_id, s_topk_indices); + if (seq_len <= K) { + impl::trivial_transform(transform, seq_len, K); + } else if (seq_len <= kMax2PassLength) { + if (seq_len <= Small::kMax1PassLength) { + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem); + } else { + __syncwarp(); + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem); + } + Small::transform(transform); + } else if (seq_len <= cluster_threshold) { + Medium::run(params.get_scores(batch_id), seq_len, s_topk_indices, smem); + Medium::transform(transform, smem); + } else { + const auto ws = params.workspace + batch_id * params.workspace_stride; + device::PDLWaitPrimary(); + Large::transform(transform, ws, smem); + } +} + +FUSED_COMBINE_KERNEL void // long context, small batch size +topk_fused_transform(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + const auto batch_id = blockIdx.x; + const auto cluster_rank = blockIdx.y; + const auto seq_len = params.seq_lens[batch_id]; + const auto transform = params.get_transform(batch_id, s_topk_indices); + if (seq_len <= K) { + if (cluster_rank != 0) 
return; // only first rank work + impl::trivial_transform(transform, seq_len, K); + } else if (seq_len <= Small::kMax1PassLength) { + if (cluster_rank != 0) return; // only first rank work + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem, /*use_pdl=*/true); + Small::transform(transform); + } else { + const auto [offset, length] = partition_work(seq_len, cluster_rank); + const auto ws = params.workspace + batch_id * params.workspace_stride; + Large::stage1_init(smem); + device::PDLWaitPrimary(); + Large::stage1_prologue(params.get_scores(batch_id) + offset, length, smem); + Large::stage1(s_topk_indices, length, smem); + Large::stage1_epilogue(transform, offset, ws, smem); + cooperative_groups::this_cluster().sync(); + if (cluster_rank != 0) return; // only first rank do the stage-2 + Large::transform(transform, ws, smem); + } +} + +struct CombinedTopKKernel { + static constexpr auto kStage1SMEM = sizeof(Large::Smem) + 128; + static constexpr auto kStage2SMEM = std::max(sizeof(Small::Smem), sizeof(Medium::Smem)) + 128; + + static void plan( // + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView metadata, + const uint32_t static_cluster_threshold) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto Bp1 = SymbolicSize{"batch_size_plus_1"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(seq_lens); + TensorMatcher({Bp1, 4}) // + .with_dtype() + .with_device(device_) + .verify(metadata); + + const auto batch_size = static_cast(B.unwrap()); + RuntimeCheck(Bp1.unwrap() == B.unwrap() + 1); + if (batch_size <= kNumClusters) return; // metadata unused in fused path + + const auto device = device_.unwrap(); + constexpr auto kernel = topk_plan; + LaunchKernel(1, kBlockSize, device)( // + kernel, + static_cast(seq_lens.data_ptr()), + static_cast(metadata.data_ptr()), + batch_size, + static_cluster_threshold); + } + + static void 
transform( + const tvm::ffi::TensorView scores, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView page_table, + const tvm::ffi::TensorView page_indices, + const uint32_t page_size, + const tvm::ffi::TensorView workspace, + const tvm::ffi::TensorView metadata) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto Bp1 = SymbolicSize{"batch_size_plus_1"}; + auto L = SymbolicSize{"max_seq_len"}; + auto S = SymbolicSize{"score_stride"}; + auto P = SymbolicSize{"page_table_stride"}; + auto W = SymbolicSize{"workspace_stride"}; + constexpr auto D = Large::kWorkspaceInts; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, L}) // + .with_strides({S, 1}) + .with_dtype() + .with_device(device_) + .verify(scores); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(seq_lens); + TensorMatcher({B, -1}) // + .with_strides({P, 1}) + .with_dtype() + .with_device(device_) + .verify(page_table); + TensorMatcher({B, K}) // + .with_dtype() + .with_device(device_) + .verify(page_indices); + TensorMatcher({B, D}) // + .with_strides({W, 1}) + .with_dtype() + .with_device(device_) + .verify(workspace); + TensorMatcher({Bp1, 4}) // + .with_dtype() + .with_device(device_) + .verify(metadata); + + const auto page_bits = static_cast(std::countr_zero(page_size)); + const auto batch_size = static_cast(B.unwrap()); + const auto max_seq_len = static_cast(L.unwrap()); + const auto device = device_.unwrap(); + RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2"); + RuntimeCheck(S.unwrap() % 4 == 0, "score_stride must be a multiple of 4 (TMA 16-byte alignment)"); + RuntimeCheck(Bp1.unwrap() == B.unwrap() + 1, "invalid metadata shape"); + + // NOTE: this should be fixed later + // RuntimeCheck(max_seq_len <= kMaxSupportedLength, max_seq_len, " exceeds the maximum supported length"); + + const auto params = TopKParams{ + .seq_lens = static_cast(seq_lens.data_ptr()), + .scores = 
static_cast(scores.data_ptr()), + .page_table = static_cast(page_table.data_ptr()), + .page_indices = static_cast(page_indices.data_ptr()), + .score_stride = S.unwrap(), + .page_table_stride = P.unwrap(), + .workspace = static_cast(workspace.data_ptr()), + .metadata = static_cast(metadata.data_ptr()), + .workspace_stride = W.unwrap() * static_cast(sizeof(int32_t)), + .batch_size = batch_size, + .page_bits = page_bits, + }; + + if (max_seq_len <= Small::kMax1PassLength) { + // All items fit in the short path -- no stage-1 needed + constexpr auto kernel = topk_short_transform; + setup_kernel_smem_once(); + LaunchKernel(batch_size, kBlockSize, device, kStage2SMEM) // + .enable_pdl(true)(kernel, params); + } else { + // Some items may be large -- launch stage-1 + main + if (batch_size <= kNumClusters) { + // can fuse into 1 stage + constexpr auto kernel = topk_fused_transform; + constexpr auto kSMEM = std::max(kStage1SMEM, kStage2SMEM); + setup_kernel_smem_once(); + LaunchKernel({batch_size, kClusterSize}, kBlockSize, device, kSMEM) + .enable_cluster({1, kClusterSize}) + .enable_pdl(true)(kernel, params); + } else { + // stage 1 + stage 2 + constexpr auto kernel_stage_1 = topk_combine_preprocess; + setup_kernel_smem_once(); + const auto num_clusters = std::min(batch_size, kNumClusters); + LaunchKernel({num_clusters, kClusterSize}, kBlockSize, device, kStage1SMEM) + .enable_cluster({1, kClusterSize}) + .enable_pdl(true)(kernel_stage_1, params); + constexpr auto kernel_stage_2 = topk_combine_transform; + setup_kernel_smem_once(); + LaunchKernel(batch_size, kBlockSize, device, kStage2SMEM) // + .enable_pdl(true)(kernel_stage_2, params); + } + } + } +}; + +} // namespace diff --git a/lightllm/third_party/sglang_jit/dsv4/__init__.py b/lightllm/third_party/sglang_jit/dsv4/__init__.py new file mode 100644 index 0000000000..507b225167 --- /dev/null +++ b/lightllm/third_party/sglang_jit/dsv4/__init__.py @@ -0,0 +1,8 @@ +from .elementwise import fused_k_norm_rope_flashmla, 
fused_q_norm_rope +from .topk import topk_transform_512 + +__all__ = [ + "fused_k_norm_rope_flashmla", + "fused_q_norm_rope", + "topk_transform_512", +] diff --git a/lightllm/third_party/sglang_jit/dsv4/elementwise.py b/lightllm/third_party/sglang_jit/dsv4/elementwise.py new file mode 100644 index 0000000000..07011b0479 --- /dev/null +++ b/lightllm/third_party/sglang_jit/dsv4/elementwise.py @@ -0,0 +1,215 @@ +from typing import Optional, Tuple + +import torch + +from lightllm.third_party.sglang_jit.jit_utils import ( + cache_once, + is_arch_support_pdl, + load_jit, + make_cpp_args, +) +from lightllm.third_party.sglang_jit.runtime_utils import is_hip + +from .utils import make_name + +_is_hip = is_hip() + + +@cache_once +def _jit_fused_rope_module(): + args = make_cpp_args(is_arch_support_pdl()) + return load_jit( + make_name("fused_rope"), + *args, + cuda_files=["deepseek_v4/rope.cuh"], + cuda_wrappers=[("forward", f"FusedQKRopeKernel<{args}>::forward")], + ) + + +@cache_once +def _jit_main_q_norm_rope_module( + dtype: torch.dtype, + head_dim: int, + rope_dim: int, +): + """Main MLA path Q kernel: rmsnorm-self + RoPE, warp per (token, head).""" + args = make_cpp_args(dtype, head_dim, rope_dim, is_arch_support_pdl()) + return load_jit( + make_name("main_q_norm_rope"), + *args, + cuda_files=["deepseek_v4/main_norm_rope.cuh"], + cuda_wrappers=[ + ("forward", f"FusedQNormRopeKernel<{args}>::forward"), + ], + ) + + +@cache_once +def _jit_main_k_norm_rope_flashmla_module( + dtype: torch.dtype, + head_dim: int, + rope_dim: int, + page_size: int, +): + """Main MLA path K kernel: rmsnorm + RoPE + write to FlashMLA paged cache.""" + args = make_cpp_args(dtype, head_dim, rope_dim, page_size, is_arch_support_pdl()) + return load_jit( + make_name("main_k_norm_rope_flashmla"), + *args, + cuda_files=["deepseek_v4/main_norm_rope.cuh"], + cuda_wrappers=[ + ("forward", f"FusedKNormRopeFlashMLAKernel<{args}>::forward"), + ], + ) + + +@cache_once +def 
_jit_main_q_indexer_rope_hadamard_quant_module(dtype: torch.dtype): + """C4 indexer Q kernel: RoPE + 128-pt Hadamard + fp8 act-quant""" + args = make_cpp_args(dtype, is_arch_support_pdl()) + return load_jit( + make_name("main_q_indexer_rope_hadamard_quant"), + *args, + cuda_files=["deepseek_v4/main_norm_rope.cuh"], + cuda_wrappers=[ + ("forward", f"FusedQIndexerRopeHadamardQuantKernel<{args}>::forward"), + ], + ) + + +@cache_once +def _jit_main_q_indexer_rope_hadamard_fp4_quant_module(dtype: torch.dtype): + args = make_cpp_args(dtype, is_arch_support_pdl()) + return load_jit( + make_name("main_q_indexer_rope_hadamard_fp4_quant"), + *args, + cuda_files=["deepseek_v4/main_norm_rope.cuh"], + cuda_wrappers=[ + ("forward", f"FusedQIndexerRopeHadamardFp4QuantKernel<{args}>::forward"), + ], + ) + + +def fused_rope_inplace( + q: torch.Tensor, + k: Optional[torch.Tensor], + freqs_cis: torch.Tensor, + positions: torch.Tensor, + inverse: bool = False, +) -> None: + """Apply rotary embeddings to both Q and K in a single fused CUDA kernel. 
+ + Args: + q: [batch_size, num_q_heads, rope_dim] bfloat16 + k: [batch_size, num_k_heads, rope_dim] bfloat16 or None + freqs_cis: [max_seq_len, rope_dim // 2] complex64 (full table) + positions: [batch_size] int32 or int64, indices into freqs_cis + inverse: if True, apply inverse rotation (conjugate freqs) + """ + if _is_hip: + from sglang.srt.layers.deepseek_v4_rope import apply_rotary_emb_triton + + apply_rotary_emb_triton(q, freqs_cis, positions=positions, inverse=inverse) + if k is not None: + apply_rotary_emb_triton(k, freqs_cis, positions=positions, inverse=inverse) + return + + freqs_real = torch.view_as_real(freqs_cis).flatten(-2).contiguous() + module = _jit_fused_rope_module() + module.forward(q, k, freqs_real, positions, inverse) + + +def fused_q_norm_rope( + q_input: torch.Tensor, + q_output: torch.Tensor, + eps: float, + freqs_cis: torch.Tensor, + positions: torch.Tensor, +) -> None: + freqs_real = torch.view_as_real(freqs_cis).flatten(-2) + head_dim = q_input.shape[-1] + rope_dim = freqs_real.shape[-1] + module = _jit_main_q_norm_rope_module(q_input.dtype, head_dim, rope_dim) + module.forward(q_input, q_output, freqs_real, positions, eps) + + +def fused_q_indexer_rope_hadamard_quant( + q_input: torch.Tensor, + weight: torch.Tensor, + weight_scale: float, + freqs_cis: torch.Tensor, + positions: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_real = torch.view_as_real(freqs_cis).flatten(-2) + q_fp8 = torch.empty(q_input.shape, dtype=torch.float8_e4m3fn, device=q_input.device) + weights_out = torch.empty((*q_input.shape[:-1], 1), dtype=torch.float32, device=q_input.device) + if _is_hip: + torch.ops.sgl_kernel.dsv4_fused_q_indexer_rope_hadamard_quant( + q_input, + q_fp8, + weight, + weights_out, + float(weight_scale), + freqs_real, + positions, + ) + else: + module = _jit_main_q_indexer_rope_hadamard_quant_module(q_input.dtype) + module.forward( + q_input, + q_fp8, + weight, + weights_out, + float(weight_scale), + freqs_real, + positions, 
+ ) + return q_fp8, weights_out + + +def fused_q_indexer_rope_hadamard_fp4_quant( + q_input: torch.Tensor, + weight: torch.Tensor, + weight_scale: float, + freqs_cis: torch.Tensor, + positions: torch.Tensor, +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + if _is_hip: + raise RuntimeError("DeepSeek V4 FP4 indexer requires the CUDA fused Q path.") + freqs_real = torch.view_as_real(freqs_cis).flatten(-2) + q_fp4 = torch.empty( + (*q_input.shape[:-1], q_input.shape[-1] // 2), + dtype=torch.int8, + device=q_input.device, + ) + q_sf = torch.empty(q_input.shape[:-1], dtype=torch.int32, device=q_input.device) + weights_out = torch.empty((*q_input.shape[:-1], 1), dtype=torch.float32, device=q_input.device) + module = _jit_main_q_indexer_rope_hadamard_fp4_quant_module(q_input.dtype) + module.forward( + q_input, + q_fp4, + q_sf, + weight, + weights_out, + float(weight_scale), + freqs_real, + positions, + ) + return (q_fp4, q_sf), weights_out + + +def fused_k_norm_rope_flashmla( + kv: torch.Tensor, + kv_weight: torch.Tensor, + eps: float, + freqs_cis: torch.Tensor, + positions: torch.Tensor, + out_loc: torch.Tensor, + kvcache: torch.Tensor, + page_size: int, +) -> None: + freqs_real = torch.view_as_real(freqs_cis).flatten(-2) + head_dim = kv.shape[-1] + rope_dim = freqs_real.shape[-1] + module = _jit_main_k_norm_rope_flashmla_module(kv.dtype, head_dim, rope_dim, page_size) + module.forward(kv, kv_weight, freqs_real, positions, out_loc, kvcache, eps) diff --git a/lightllm/third_party/sglang_jit/dsv4/topk.py b/lightllm/third_party/sglang_jit/dsv4/topk.py new file mode 100644 index 0000000000..1bfce7cef3 --- /dev/null +++ b/lightllm/third_party/sglang_jit/dsv4/topk.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import Optional + +import torch + +from lightllm.third_party.sglang_jit.jit_utils import ( + cache_once, + is_arch_support_pdl, + is_hip_runtime, + load_jit, + make_cpp_args, +) + +from .utils import make_name + + +@cache_once +def 
_jit_topk_v1_module(topk: int): + args = make_cpp_args(is_arch_support_pdl()) + assert topk in (512, 1024), "Only support topk=512 or 1024" + return load_jit( + make_name(f"topk_v1_{topk}"), + *args, + cuda_files=["deepseek_v4/topk_v1.cuh"], + cuda_wrappers=[("topk_transform", f"TopKKernel<{args}>::transform")], + extra_cuda_cflags=[f"-DSGL_TOPK={topk}"], + ) + + +@cache_once +def _jit_topk_v2_module(topk: int): + return load_jit( + make_name(f"topk_v2_{topk}"), + cuda_files=["deepseek_v4/topk_v2.cuh"], + cuda_wrappers=[ + ("topk_transform", "CombinedTopKKernel::transform"), + ("topk_plan", "CombinedTopKKernel::plan"), + ], + extra_cuda_cflags=[f"-DSGL_TOPK={topk}"], + ) + + +def topk_transform_512( + scores: torch.Tensor, + seq_lens: torch.Tensor, + page_tables: torch.Tensor, + out_page_indices: torch.Tensor, + page_size: int, + out_raw_indices: Optional[torch.Tensor] = None, +) -> None: + if is_hip_runtime(): + torch.ops.sgl_kernel.deepseek_v4_topk_transform_512( + scores, seq_lens, page_tables, out_page_indices, page_size, out_raw_indices + ) + else: + module = _jit_topk_v1_module(out_page_indices.shape[1]) + module.topk_transform(scores, seq_lens, page_tables, out_page_indices, page_size, out_raw_indices) + + +_WORKSPACE_INTS_PER_BATCH = 2 + 1024 * 2 +_PLAN_METADATA_INTS_PER_BATCH = 4 + + +def plan_topk_v2(seq_lens: torch.Tensor, static_threshold: int = 0) -> torch.Tensor: + module = _jit_topk_v2_module(512) # does not matter + bs = seq_lens.shape[0] + metadata = seq_lens.new_empty(bs + 1, _PLAN_METADATA_INTS_PER_BATCH) + module.topk_plan(seq_lens, metadata, static_threshold) + return metadata + + +def topk_transform_512_v2( + scores: torch.Tensor, + seq_lens: torch.Tensor, + page_tables: torch.Tensor, + out_page_indices: torch.Tensor, + page_size: int, + metadata: torch.Tensor, +) -> None: + module = _jit_topk_v2_module(out_page_indices.shape[1]) + bs = scores.shape[0] + workspace = seq_lens.new_empty(bs, _WORKSPACE_INTS_PER_BATCH) + module.topk_transform( + 
scores, + seq_lens, + page_tables, + out_page_indices, + page_size, + workspace, + metadata, + ) diff --git a/lightllm/third_party/sglang_jit/dsv4/utils.py b/lightllm/third_party/sglang_jit/dsv4/utils.py new file mode 100644 index 0000000000..8085074f6c --- /dev/null +++ b/lightllm/third_party/sglang_jit/dsv4/utils.py @@ -0,0 +1,2 @@ +def make_name(name: str) -> str: + return f"dpsk_v4_{name}" diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/atomic.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/atomic.cuh new file mode 100644 index 0000000000..c9da765f4a --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/atomic.cuh @@ -0,0 +1,35 @@ +/// \file atomic.cuh +/// \brief Device-side atomic operations. + +#pragma once +#include + +namespace device::atomic { + +/** + * \brief Atomically computes the maximum of `*addr` and `value`, storing the + * result in `*addr`. + * \param addr Pointer to the value in global/shared memory to be updated. + * \param value The value to compare against. + * \return The old value at `*addr` before the update. + * \note On CUDA, this uses `atomicMax`/`atomicMin` on the reinterpreted + * integer representation. On ROCm, a CAS loop is used as a fallback. + */ +SGL_DEVICE float max(float* addr, float value) { +#ifndef USE_ROCM + float old; + old = (value >= 0) ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + return old; +#else + int* addr_as_i = (int*)addr; + int old = *addr_as_i, assumed; + do { + assumed = old; + old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +#endif +} + +} // namespace device::atomic diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/cta.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/cta.cuh new file mode 100644 index 0000000000..b47a4a27b2 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/cta.cuh @@ -0,0 +1,40 @@ +/// \file cta.cuh +/// \brief CTA (Cooperative Thread Array / thread-block) level primitives. + +#pragma once +#include +#include +#include + +namespace device::cta { + +/** + * \brief Compute the maximum of `value` across all threads in the CTA. + * + * Uses a two-level reduction: first within each warp via `warp::reduce_max`, + * then across warps using shared memory. The final result is stored in + * `smem[0]`. + * + * \tparam T Numeric type (must be supported by `warp::reduce_max`). + * \param value Per-thread input value. + * \param smem Shared memory buffer (must have at least `blockDim.x / 32` + * elements). + * \param min_value Identity element for max (default 0.0f). + * \note This function does NOT issue a trailing `__syncthreads()`. + * Callers must synchronize before reading `smem[0]`. + */ +template +SGL_DEVICE void reduce_max(T value, float* smem, float min_value = 0.0f) { + const uint32_t warp_id = threadIdx.x / kWarpThreads; + smem[warp_id] = warp::reduce_max(value); + __syncthreads(); + if (warp_id == 0) { + const auto tx = threadIdx.x; + const auto local_value = tx * kWarpThreads < blockDim.x ? 
smem[tx] : min_value; + const auto max_value = warp::reduce_max(local_value); + smem[0] = max_value; + } + // no extra sync; it is caller's responsibility to sync if needed +} + +} // namespace device::cta diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress.cuh new file mode 100644 index 0000000000..02b166d01c --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress.cuh @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include + +#include +#include + +#include + +namespace device::compress { + +struct alignas(16) PrefillPlan { + uint32_t ragged_id; + uint32_t batch_id; + uint32_t position; + uint32_t window_len; // must be in `[0, compress_ratio * (1 + is_overlap))` + + bool is_valid(const uint32_t ratio, const bool is_overlap) const { + const uint32_t max_window_len = ratio * (1 + is_overlap); + return window_len < max_window_len; + } +}; + +} // namespace device::compress + +namespace host::compress { + +using device::compress::PrefillPlan; +using PrefillPlanTensorDtype = uint8_t; +inline constexpr int64_t kPrefillPlanDim = 16; + +static_assert(alignof(PrefillPlan) == sizeof(PrefillPlan)); +static_assert(sizeof(PrefillPlan) == kPrefillPlanDim * sizeof(PrefillPlanTensorDtype)); + +} // namespace host::compress diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress_v2.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress_v2.cuh new file mode 100644 index 0000000000..3e87127c5f --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/compress_v2.cuh @@ -0,0 +1,99 @@ +#pragma once + +#include +#include + +#include + +#include +#include + +#include + +namespace device::compress { + +/// \brief Per-batch decode plan. Layout: 16 bytes. 
+struct alignas(16) DecodePlan { + uint32_t seq_len; + int32_t write_loc; + int32_t read_page_0; + int32_t read_page_1; +}; + +/// \brief Per-token compress plan (used by c4/c128 prefill). Layout: 16 bytes. +struct alignas(16) CompressPlan { + uint32_t seq_len; + uint16_t ragged_id; + uint16_t buffer_len; + int32_t read_page_0; + /// \brief Stage 0 (CPU): batch_id (used to look up page table). + /// \brief Stage 1 (GPU): final state-pool write location. + int32_t read_page_1; + + static SGL_DEVICE __host__ CompressPlan invalid() { + return CompressPlan{-1u, 0, 0, -1, -1}; + } + + SGL_DEVICE __host__ bool is_invalid() const { + return seq_len == -1u; + } +}; + +/// \brief Per-token write plan (used by c4/c128 prefill). Layout: 8 bytes. +struct alignas(8) WritePlan { + /// \brief Stage 0 (CPU): packed `(batch_id << 16) | ragged_id`. + /// \brief Stage 1 (GPU): just `ragged_id`. + uint32_t ragged_id; + /// \brief Stage 0 (CPU): position + 1 (used to look up state slot). + /// \brief Stage 1 (GPU): final state-pool write location. 
+ int32_t write_loc; + + static SGL_DEVICE __host__ WritePlan invalid() { + return WritePlan{-1u, -1}; + } + + SGL_DEVICE __host__ bool is_invalid() const { + return ragged_id == -1u; + } +}; + +} // namespace device::compress + +namespace host::compress { + +using device::compress::CompressPlan; +using device::compress::DecodePlan; +using device::compress::WritePlan; + +static_assert(alignof(DecodePlan) == sizeof(DecodePlan)); +static_assert(sizeof(DecodePlan) == 16); +static_assert(alignof(CompressPlan) == sizeof(CompressPlan)); +static_assert(sizeof(CompressPlan) == 16); +static_assert(alignof(WritePlan) == sizeof(WritePlan)); +static_assert(sizeof(WritePlan) == 8); + +inline auto verify_plan_d(tvm::ffi::TensorView t, SymbolicSize& N, SymbolicDevice& device) -> const DecodePlan* { + TensorMatcher({N, sizeof(DecodePlan)}) // + .with_dtype() + .with_device(device) + .verify(t); + return static_cast(t.data_ptr()); +} + +inline auto verify_plan_c(tvm::ffi::TensorView t, SymbolicSize& N, SymbolicDevice& device) -> const CompressPlan* { + TensorMatcher({N, sizeof(CompressPlan)}) // + .with_dtype() + .with_device(device) + .verify(t); + return static_cast(t.data_ptr()); +} + +inline auto verify_plan_w(tvm::ffi::TensorView t, SymbolicSize& N, SymbolicDevice& device) -> const WritePlan* { + TensorMatcher({N, sizeof(WritePlan)}) // + .with_dtype() + .with_device(device) + .verify(t); + return static_cast(t.data_ptr()); +} + +} // namespace host::compress diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/fp8_utils.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/fp8_utils.cuh new file mode 100644 index 0000000000..53a62755b4 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/fp8_utils.cuh @@ -0,0 +1,112 @@ +#pragma once + +#include +#include +#include + +#include +#ifndef USE_ROCM +#include +#endif + +// Small helpers shared by the DeepSeek-V4 FP8/UE8M0 quantization kernels +// 
// (silu_and_mul_masked_post_quant, store, mega_moe_pre_dispatch, ...).
// All functions are `SGL_DEVICE` (= `__forceinline__ __device__`) so
// including this header in multiple translation units is ODR-safe.

namespace deepseek_v4::fp8 {

// Round `x` UP to a power of two and return the raw 8-bit biased exponent
// of that power. The result is the fp32 exponent field of `x`, plus one
// whenever any mantissa bit is set (so an exact power of two maps to its
// own exponent). The sign bit of `x` is masked out and ignored.
// The actual fp32 scale is `2^(exp - 127)`
// (i.e. `__uint_as_float(exp << 23)`).
SGL_DEVICE int32_t cast_to_ue8m0(float x) {
  uint32_t u = __float_as_uint(x);
  int32_t exp = int32_t((u >> 23) & 0xFF);
  uint32_t mant = u & 0x7FFFFF;
  return exp + (mant != 0);
}

// 1 / 2^(exp - 127) as fp32. Equivalent to `1.0f / __uint_as_float(exp << 23)`.
// Built directly as a bit pattern: biased exponent `254 - exp` with a zero
// mantissa, i.e. `2^(127 - exp)`. For `exp == 254` the exponent field is 0,
// so the bit pattern is all zeros and the result is 0.0f.
SGL_DEVICE float inv_scale_ue8m0(int32_t exp) {
  return __uint_as_float((127 + 127 - exp) << 23);
}

// Clamp to [-FP8_E4M3_MAX, FP8_E4M3_MAX].
// Uses platform-specific max from type.cuh (448 for E4M3FN, 224 for E4M3FNUZ).
// The upper bound is applied first (fminf), then the lower bound (fmaxf).
SGL_DEVICE float fp8_e4m3_clip(float val) {
  return fmaxf(fminf(val, kFP8E4M3Max), -kFP8E4M3Max);
}

#ifndef USE_ROCM
// Pack two fp32 values into a single fp8x2_e4m3 with clamping.
// Both inputs go through fp8_e4m3_clip() before the conversion.
SGL_DEVICE fp8x2_e4m3_t pack_fp8(float x, float y) {
  return fp8x2_e4m3_t{fp32x2_t{fp8_e4m3_clip(x), fp8_e4m3_clip(y)}};
}
#else
// Software float -> FP8 E4M3 conversion for ROCm/HIP.
// Supports both E4M3FN (MI350X, gfx950) and E4M3FNUZ (MI300X, gfx942).
+SGL_DEVICE uint8_t cvt_float_to_fp8_e4m3(float val) { + val = fp8_e4m3_clip(val); + if (val == 0.0f) return 0; + + uint32_t f32 = __float_as_uint(val); + uint8_t sign = static_cast((f32 >> 31) << 7); + int32_t exp32 = static_cast((f32 >> 23) & 0xFF) - 127; + uint32_t mant23 = f32 & 0x7FFFFF; + +#if HIP_FP8_TYPE_FNUZ + // E4M3FNUZ: bias=8, max=240, no negative zero, NaN=0x80 + constexpr int32_t kBias = 8; + constexpr int32_t kMaxExp = 15; + constexpr int32_t kMinSubnormExp = -10; // min subnormal exponent + constexpr int32_t kMinNormExp = -7; // min normal exponent + constexpr uint8_t kSaturate = 0x7Fu; // max normal = 0_1111_111 = 240.0 +#else + // E4M3FN: bias=7, max=448, NaN=0x7F + constexpr int32_t kBias = 7; + constexpr int32_t kMaxExp = 15; + constexpr int32_t kMinSubnormExp = -9; + constexpr int32_t kMinNormExp = -6; + constexpr uint8_t kSaturate = 0x7Eu; // max normal = 0_1111_110 = 448.0 +#endif + + int32_t exp8; + uint8_t mant3; + + if (exp32 < kMinSubnormExp) { + return sign; + } else if (exp32 < kMinNormExp) { + // Subnormal range + int32_t shift = -(kBias - 1) - exp32; // 1..3 + uint32_t subnorm_mant = (0x800000 | mant23) >> (shift + 20); + uint32_t round_bit = ((0x800000 | mant23) >> (shift + 19)) & 1; + subnorm_mant += round_bit; + mant3 = static_cast(subnorm_mant & 0x07); + exp8 = 0; + if (subnorm_mant > 7) { + exp8 = 1; + mant3 = 0; + } + } else { + exp8 = exp32 + kBias; + mant3 = static_cast(mant23 >> 20); + uint32_t round_bit = (mant23 >> 19) & 1; + mant3 += round_bit; + if (mant3 > 7) { + mant3 = 0; + exp8++; + } + if (exp8 >= kMaxExp) return sign | kSaturate; + } + return sign | (static_cast(exp8) << 3) | mant3; +} + +// Pack two fp32 values into a single fp8x2_e4m3 (uint16_t on HIP). 
+SGL_DEVICE fp8x2_e4m3_t pack_fp8(float x, float y) { + uint8_t x8 = cvt_float_to_fp8_e4m3(x); + uint8_t y8 = cvt_float_to_fp8_e4m3(y); + return static_cast(x8) | (static_cast(y8) << 8); +} +#endif + +} // namespace deepseek_v4::fp8 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/kvcacheio.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/kvcacheio.cuh new file mode 100644 index 0000000000..0a3acc4773 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/kvcacheio.cuh @@ -0,0 +1,96 @@ +#include +#include + +#include + +#include + +namespace device::hisparse { + +/// NOTE: We call nope+rope as a "value" here. +/// GPU Cache layout: +/// VALUE 0, VALUE 1, ..., VALUE 63, +/// SCALE 0, SCALE 1, ..., SCALE 63, +/// [Padding to align to 576 bytes] +/// CPU Cache follow a trivial linear layout without any padding. +inline constexpr int64_t kGPUPageSize = 64; +inline constexpr int64_t kGPUPageBits = 6; // log2(kGPUPageSize) +inline constexpr int64_t kValueBytes = 576; +inline constexpr int64_t kScaleBytes = 8; +/// NOTE: FlashMLA requires each page to be aligned to 576 bytes +inline constexpr int64_t kCPUItemBytes = kValueBytes + kScaleBytes; +inline constexpr int64_t kGPUPageBytes = host::div_ceil(kCPUItemBytes * kGPUPageSize, 576) * 576; +inline constexpr int64_t kGPUScaleOffset = kValueBytes * kGPUPageSize; + +struct PointerInfo { + int64_t* value_ptr; + int64_t* scale_ptr; +}; + +SGL_DEVICE PointerInfo get_pointer_gpu(void* cache, int32_t index) { + using namespace device; + static_assert(1 << kGPUPageBits == kGPUPageSize); + const int32_t page_num = index >> kGPUPageBits; + const int32_t page_offset = index & (kGPUPageSize - 1); + const auto page_ptr = pointer::offset(cache, page_num * kGPUPageBytes); + const auto value_ptr = pointer::offset(page_ptr, page_offset * kValueBytes); + const auto scale_ptr = pointer::offset(page_ptr, kGPUScaleOffset + page_offset * kScaleBytes); + return 
{static_cast(value_ptr), static_cast(scale_ptr)}; +} + +SGL_DEVICE PointerInfo get_pointer_cpu(void* cache, int32_t index) { + using namespace device; + const auto value_ptr = pointer::offset(cache, index * kCPUItemBytes); + const auto scale_ptr = pointer::offset(value_ptr, kValueBytes); + return {static_cast(value_ptr), static_cast(scale_ptr)}; +} + +enum class TransferDirection { + DeviceToDevice = 0, + DeviceToHost = 1, + HostToDevice = 2, +}; + +template +SGL_DEVICE void transfer_item(void* dst_cache, void* src_cache, const int32_t dst_index, const int32_t src_index) { + constexpr bool is_dst_device = (direction != TransferDirection::DeviceToHost); + constexpr bool is_src_device = (direction != TransferDirection::HostToDevice); + constexpr auto dst_fn = is_dst_device ? get_pointer_gpu : get_pointer_cpu; + constexpr auto src_fn = is_src_device ? get_pointer_gpu : get_pointer_cpu; + + const auto [dst_value_ptr, dst_scale_ptr] = dst_fn(dst_cache, dst_index); + const auto [src_value_ptr, src_scale_ptr] = src_fn(src_cache, src_index); + + int64_t local_items[2]; + const int64_t* tail_src_ptr; + int64_t* tail_dst_ptr; + + const int32_t lane_id = threadIdx.x % 32; + + for (int i = 0; i < 2; ++i) { + const auto j = lane_id + i * 32; + local_items[i] = src_value_ptr[j]; + } + + if (lane_id < 8) { // handle the tail element safely + const auto last_id = 64 + lane_id; + tail_src_ptr = src_value_ptr + last_id; + tail_dst_ptr = dst_value_ptr + last_id; + } else { // broadcast load/store is safe + tail_src_ptr = src_scale_ptr; + tail_dst_ptr = dst_scale_ptr; + } + + const auto tail_item = *tail_src_ptr; + + // store first 512 bytes of value + for (int i = 0; i < 2; ++i) { + const auto j = lane_id + i * 32; + dst_value_ptr[j] = local_items[i]; + } + + // store the tail element + *tail_dst_ptr = tail_item; +} + +} // namespace device::hisparse diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/cluster.cuh 
b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/cluster.cuh new file mode 100644 index 0000000000..e58214c951 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/cluster.cuh @@ -0,0 +1,257 @@ +#pragma once +#include +#include +#include + +#include "common.cuh" +#include "ptx.cuh" +#include +#include + +namespace device::top512 { + +template +struct ClusterTopK { + static constexpr uint32_t kClusterSize = 8; + static constexpr uint32_t kHistBits = 10; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kRadixBins = 256; + static constexpr uint32_t kElemPerStage = 8; + static constexpr uint32_t kSizePerStage = kElemPerStage * kBlockSize; + static constexpr uint32_t kNumStages = 4; + static constexpr uint32_t kMaxLength = kClusterSize * kNumStages * kSizePerStage; + static constexpr uint32_t kStoreLane = kBlockSize - 1; + static constexpr uint32_t kAboveBits = 11; + + // --------------------------------------------------------------------------- + // Shared memory layouts + // --------------------------------------------------------------------------- + + struct Smem { + uint64_t barrier[kNumStages]; + uint32_t local_above_equal[kClusterSize]; + uint32_t prefix_above_equal; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + alignas(128) MatchBin match; + alignas(128) uint32_t warp_sum[kNumWarps]; + uint32_t histogram[kHistBins]; + alignas(128) float score_buffer[kNumStages][kSizePerStage]; + Tie tie_buffer[kMaxTies]; + }; + + struct alignas(16) Metadata { + uint32_t batch_id; + uint32_t seq_len; + bool has_next; + }; + + struct WorkSpace { + uint2 metadata; // {num_above, num_ties} + Tie ties[kMaxTies]; + }; + + static constexpr uint32_t kWorkspaceInts = sizeof(WorkSpace) / sizeof(uint32_t); + + // --------------------------------------------------------------------------- + // Stage 1: histogram + cluster reduce + find threshold + scatter + // 
--------------------------------------------------------------------------- + + SGL_DEVICE static void stage1_init(void* _smem) { + const auto tx = threadIdx.x; + __builtin_assume(tx < kBlockSize); + const auto smem = static_cast(_smem); + if (tx < kHistBins) smem->histogram[tx] = 0; + if (tx < kNumStages) ptx::mbarrier_init(&smem->barrier[tx], 1); + __syncthreads(); + } + + SGL_DEVICE static void stage1_prologue(const float* scores, uint32_t length, void* _smem) { + if (threadIdx.x == 0) { + const auto smem = static_cast(_smem); + const auto num_stages = (length + kSizePerStage - 1) / kSizePerStage; + const auto length_aligned = (length + 3u) & ~3u; // align to 4 for TMA +#pragma unroll + for (uint32_t stage = 0; stage < kNumStages; stage++) { + if (stage >= num_stages) break; + const auto offset = stage * kSizePerStage; + const auto size = min(kSizePerStage, length_aligned - offset); + const auto size_bytes = size * sizeof(float); + const auto bar = &smem->barrier[stage]; + ptx::tma_load(smem->score_buffer[stage], scores + offset, size_bytes, bar); + ptx::mbarrier_arrive_expect_tx(bar, size_bytes); + } + } + } + + SGL_DEVICE static void stage1(int32_t* indices, uint32_t length, void* _smem, bool reuse = false) { + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + __builtin_assume(tx < kBlockSize); + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // Initialize shared memory histogram, counters, and barriers +#pragma unroll + for (uint32_t stage = 0; stage < kNumStages; stage++) { + const auto offset = stage * kSizePerStage; + if (offset >= length) break; + const auto size = min(kSizePerStage, length - offset); + if (lane_id == 0) ptx::mbarrier_wait(&smem->barrier[stage], 0); + __syncwarp(); +#pragma unroll + for (uint32_t i = 0; i < kElemPerStage; ++i) { + const auto idx = tx + i * kBlockSize; + if (idx >= size) break; + const auto score = smem->score_buffer[stage][idx]; + const auto bin = 
extract_coarse_bin(score); + atomicAdd(&smem->histogram[bin], 1); + } + } + + static_assert(kHistBins <= kBlockSize); + + // 2-shot all-reduce + { + auto cluster = cooperative_groups::this_cluster(); + cluster.sync(); + const auto cluster_rank = blockIdx.y; + const auto kLocalSize = kHistBins / kClusterSize; + const auto offset = kLocalSize * cluster_rank; + + const auto src_tx = tx / kClusterSize; + const auto src_rank = tx % kClusterSize; + + if (tx < kHistBins) { + const auto addr = &smem->histogram[offset + src_tx]; + const auto src_addr = cluster.map_shared_rank(addr, src_rank); + *src_addr = warp::reduce_sum(*src_addr); + } + cluster.sync(); + } + + // now each block holds the whole histogram, find the threshold bin + { + const auto value = tx < kHistBins ? smem->histogram[tx] : 0; + const auto warp_inc = warp_inclusive_sum(lane_id, value); + if (lane_id == kWarpThreads - 1) { + smem->warp_sum[warp_id] = warp_inc; + } + + __syncthreads(); + const auto tmp = smem->warp_sum[lane_id]; + // total_length = sum of all bins in the globally-reduced histogram + // (problem.length is block-local; after cluster reduction we need the global total) + const auto total_length = warp::reduce_sum(tmp); + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? 
tmp : 0); + prefix_sum += warp_inc; + const auto above = total_length - prefix_sum; + if (tx < kHistBins && above < K && above + value >= K) { + smem->counter_gt = smem->counter_eq = 0; + smem->match = { + .bin = tx, + .above_count = above, + .equal_count = value, + }; + } + __syncthreads(); + } + + const auto [thr_bin, num_above, num_equal] = smem->match; + + // write above and equal results to global memory +#pragma unroll + for (uint32_t stage = 0; stage < kNumStages; stage++) { + const auto offset = stage * kSizePerStage; + if (offset >= length) break; +#pragma unroll + for (uint32_t i = 0; i < kElemPerStage; ++i) { + const auto buf_idx = tx + i * kBlockSize; + const auto global_idx = offset + buf_idx; + if (global_idx >= length) break; + const auto score = smem->score_buffer[stage][buf_idx]; + const auto bin = extract_coarse_bin(score); + if (bin > thr_bin) { + indices[atomicAdd(&smem->counter_gt, 1)] = global_idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (pos < kMaxTies) smem->tie_buffer[pos] = {global_idx, score}; + } + } + } + if (reuse) { + const auto num_stages = (length + kSizePerStage - 1) / kSizePerStage; + if (tx < kHistBins) smem->histogram[tx] = 0; + if (tx < num_stages) ptx::mbarrier_arrive(&smem->barrier[tx]); + } + __syncthreads(); + } + + // --------------------------------------------------------------------------- + // Stage 1 epilogue: cross-block prefix sum + page translate + tie store + // --------------------------------------------------------------------------- + + SGL_DEVICE static void stage1_epilogue(const TransformParams params, const uint32_t offset, void* _ws, void* _smem) { + auto cluster = cooperative_groups::this_cluster(); + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto local_above = smem->counter_gt; + const auto local_equal = smem->counter_eq; + const auto cluster_rank = blockIdx.y; + + constexpr uint32_t kAboveMask = (1 << kAboveBits) - 1; + 
static_assert(kAboveMask >= K); + + // Pack local counts -- NO alignment rounding (contiguous layout) + static_assert(kMaxTies <= kBlockSize); + const auto idx_above = tx < local_above ? params.indices_in[tx] : 0; + const auto tie_value = tx < local_equal ? smem->tie_buffer[tx] : Tie{0, 0.0f}; + + // push to remote shared memory, can reduce latency of reading remote + if (tx < kClusterSize) { + const auto value = (local_equal << kAboveBits) | local_above; + const auto dst_addr = cluster.map_shared_rank(smem->local_above_equal, tx); + dst_addr[cluster_rank] = value; + } + // after this last sync, only read local shared memory + // so that it is safe when peer rank has already exited the kernel + cluster.sync(); + if (tx < kClusterSize) { + const auto value = tx < cluster_rank ? smem->local_above_equal[tx] : 0; + const auto kActiveMask = (1u << kClusterSize) - 1; + smem->prefix_above_equal = warp::reduce_sum(value, kActiveMask); + } + __syncthreads(); + + const auto prefix_packed = smem->prefix_above_equal; + const auto prefix_above = prefix_packed & kAboveMask; + const auto prefix_equal = prefix_packed >> kAboveBits; + + // Page-translate above elements + if (tx < local_above) { + params.write(tx + prefix_above, idx_above + offset); + } + // Contiguous tie store via regular global writes (no TMA, no gaps) + const auto ws = static_cast(_ws); + if (tx < local_equal && tx + prefix_equal < kMaxTies) { + ws->ties[tx + prefix_equal] = {tie_value.idx + offset, tie_value.score}; + } + // Block 0 writes global metadata {num_above, num_ties} + if (cluster_rank == kClusterSize - 1 && tx == 0) { + const auto sum_above = prefix_above + local_above; + const auto sum_equal = prefix_equal + local_equal; + ws->metadata = make_uint2(sum_above, sum_equal); + } + } + + SGL_DEVICE static void transform(const TransformParams params, const void* _ws, void* _smem) { + const auto ws = static_cast(_ws); + const auto meta = &ws->metadata; + const auto [num_above, num_equal] = *meta; + if 
(num_above >= K || num_equal == 0) return; + const auto clamped_ties = min(num_equal, kMaxTies); + tie_handle_transform(ws->ties, clamped_ties, num_above, K, params, _smem); + } +}; + +} // namespace device::top512 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/common.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/common.cuh new file mode 100644 index 0000000000..d553032d79 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/common.cuh @@ -0,0 +1,176 @@ +#pragma once +#include +#include +#include +#include + +#include + +namespace device::top512 { + +inline constexpr uint32_t kMaxTopK = 1024; +inline constexpr uint32_t kBlockSize = 1024; +inline constexpr uint32_t kNumWarps = kBlockSize / kWarpThreads; +inline constexpr uint32_t kMaxTies = 1024; // == kBlockSize: 1 element per thread in stage2 +static constexpr uint32_t kRadixBins = 256; +static_assert(kMaxTopK <= kBlockSize && kMaxTies <= kBlockSize); + +// always use float4 to load from global memory +using Vec4 = AlignedVector; + +SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) { + const uint32_t mask = (1u << page_bits) - 1u; + return (page_table[i >> page_bits] << page_bits) | (i & mask); +} + +struct TransformParams { + const int32_t* __restrict__ page_table; + const int32_t* __restrict__ indices_in; + int32_t* __restrict__ indices_out; + uint32_t page_bits; + + SGL_DEVICE void transform(const uint32_t idx) const { + indices_out[idx] = page_to_indices(page_table, indices_in[idx], page_bits); + } + SGL_DEVICE void write(const uint32_t dst, const uint32_t src) const { + indices_out[dst] = page_to_indices(page_table, src, page_bits); + } +}; + +struct alignas(16) MatchBin { + uint32_t bin; + uint32_t above_count; + uint32_t equal_count; +}; + +struct alignas(8) Tie { + uint32_t idx; + float score; +}; + +struct TieHandleSmem { + alignas(128) uint32_t counter; 
// output position counter + alignas(128) MatchBin match; + uint32_t histogram[kRadixBins]; // 256-bin radix histogram + uint32_t warp_sum[kNumWarps]; // for 2-pass prefix sum +}; + +template +SGL_DEVICE uint32_t extract_coarse_bin(float x) { + static_assert(0 < kBits && kBits < 15); + const auto hx = cast(x); + const uint16_t bits = *reinterpret_cast(&hx); + const uint16_t key = (bits & 0x8000) ? ~bits : bits | 0x8000; + return key >> (16 - kBits); +} + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset); + if (lane_id >= offset) val += n; + } + return val; +} + +/// Order-preserving float32 -> uint32 for radix select +SGL_DEVICE uint32_t extract_exact_bin(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +SGL_DEVICE void trivial_transform(const TransformParams& params, uint32_t length, uint32_t K) { + if (const auto tx = threadIdx.x; tx < length) { + params.write(tx, tx); + } else if (tx < K) { + params.indices_out[tx] = -1; + } +} + +SGL_DEVICE void tie_handle_transform( + const Tie* __restrict__ ties, // + const uint32_t num_ties, + const uint32_t num_above, + const uint32_t K, + const TransformParams params, + void* _smem) { + auto* smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // Each thread loads one element (or becomes inactive) + const bool has_elem = tx < num_ties; + const auto tie = has_elem ? 
ties[tx] : Tie{0, 0.0f}; + const uint32_t key = extract_exact_bin(tie.score); + const uint32_t idx = tie.idx; + bool active = has_elem; + uint32_t topk_remain = K - num_above; + uint32_t write_pos = K; + + smem->counter = 0; + __syncthreads(); + + // Number of warps covering the 256-bin histogram (256/32 = 8) + constexpr uint32_t kRadixWarps = kRadixBins / kWarpThreads; + +#pragma unroll + for (int round = 0; round < 4; round++) { + const uint32_t shift = 24 - round * 8; + const uint32_t bin = (key >> shift) & 0xFFu; + + // 1. Build histogram + if (tx < kRadixBins) smem->histogram[tx] = 0; + __syncthreads(); + if (active) atomicAdd(&smem->histogram[bin], 1); + __syncthreads(); + + // 2. v2-style 2-pass prefix sum on 256 bins + // Only first 256 threads (8 warps) carry histogram bins. + // Other threads get hist_val=0 and harmless prefix results. + uint32_t hist_val = 0; + uint32_t warp_inc = 0; + if (tx < kRadixBins) { + hist_val = smem->histogram[tx]; + warp_inc = warp_inclusive_sum(lane_id, hist_val); + if (lane_id == kWarpThreads - 1) smem->warp_sum[warp_id] = warp_inc; + } + __syncthreads(); + if (tx < kRadixBins) { + // Inter-warp prefix (only first kHistWarps warp totals matter) + const auto tmp = (lane_id < kRadixWarps) ? smem->warp_sum[lane_id] : 0; + const auto total = warp::reduce_sum(tmp); + const auto inter = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + const auto prefix = inter + warp_inc; // inclusive prefix through this bin + const auto above = total - prefix; // elements in bins ABOVE this one + // 3. Find threshold bin + if (above < topk_remain && above + hist_val >= topk_remain) { + smem->match = {tx, above, topk_remain - above}; + } + } + __syncthreads(); + + const auto [thr, n_above, _] = smem->match; + + // 4. 
Scatter + if (active) { + if (bin > thr) { + write_pos = num_above + atomicAdd(&smem->counter, 1); + active = false; + } else if (bin < thr) { + active = false; + } else if (round == 3) { + write_pos = K - atomicAdd(&smem->match.equal_count, -1u); + } + // my_bin == thr && round < 3: stay active for next round + } + + topk_remain -= n_above; + if (topk_remain == 0) break; + } + + if (write_pos < K) params.write(write_pos, idx); +} + +} // namespace device::top512 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/ptx.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/ptx.cuh new file mode 100644 index 0000000000..73eef555f4 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/ptx.cuh @@ -0,0 +1,54 @@ +#pragma once +#include + +#include + +#include + +namespace device::top512 { + +namespace ptx { + +SGL_DEVICE void mbarrier_wait(uint64_t* addr, uint32_t phase) { + while (!cuda::ptx::mbarrier_try_wait_parity(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, addr, phase)) + ; +} + +SGL_DEVICE void mbarrier_init(uint64_t* addr, uint32_t arrives) { + cuda::ptx::mbarrier_init(addr, arrives); +} + +SGL_DEVICE void mbarrier_arrive_expect_tx(uint64_t* addr, uint32_t tx) { + cuda::ptx::mbarrier_arrive_expect_tx(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, cuda::ptx::space_shared, addr, tx); +} + +SGL_DEVICE void mbarrier_arrive(uint64_t* addr) { + cuda::ptx::mbarrier_arrive(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, cuda::ptx::space_shared, addr); +} + +SGL_DEVICE void tma_load(void* dst, const void* src, uint32_t num_bytes, uint64_t* mbar) { + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, cuda::ptx::space_global, dst, src, num_bytes, mbar); +} + +SGL_DEVICE uint32_t elect_sync() { + uint32_t pred = 0; + asm volatile( + "{\n\t" + ".reg .pred %%px;\n\t" + "elect.sync _|%%px, %1;\n\t" + "@%%px mov.s32 %0, 1;\n\t" + "}" + : "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +} + +SGL_DEVICE 
bool elect_sync_cta(uint32_t tx) { + const auto warp_id = tx / 32; + const auto uniform_warp_id = __shfl_sync(0xFFFFFFFF, warp_id, 0); + return (uniform_warp_id == 0 && elect_sync()); +} + +} // namespace ptx + +} // namespace device::top512 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/register.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/register.cuh new file mode 100644 index 0000000000..77d7361ee8 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/register.cuh @@ -0,0 +1,302 @@ +#pragma once + +#include +#include +#include + +#include "common.cuh" +#include "ptx.cuh" +#include +#include + +namespace device::top512 { + +/* Block-wide top-K over float scores held in registers: (1) coarse histogram with kHistBins bins, (2) prefix scan to find the threshold bin, (3) scatter of indices above / inside that bin, (4) ranking of ties inside the threshold bin. */ template +struct RegisterTopK { + static constexpr uint32_t kHistBits = 12; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kVecsPerThread = 4; + static constexpr uint32_t kMaxTolerance = 0; + /* each thread holds kVecsPerThread vectors of 4 floats */ static constexpr uint32_t kMax1PassLength = kVecsPerThread * 4 * kBlockSize; + static constexpr uint32_t kMaxExtraLength = kMax1PassLength; + static constexpr uint32_t kMax2PassLength = kMax1PassLength + kMaxExtraLength; + + struct Smem { + using HistVec = AlignedVector; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + uint64_t mbarrier; // for cp.async + MatchBin match; + uint32_t warp_sum[kNumWarps]; + /* the histogram is consumed in phase 2; phase 3 then reuses the same storage as tie_buffer */ union { + uint32_t histogram[kHistBins]; + HistVec histogram_vec[kBlockSize]; + Tie tie_buffer[kMaxTies]; + }; + alignas(16) float score_buffer[kMaxExtraLength]; + }; + + /* Selects top-K indices of scores[0, length) into `indices`; `_smem` must point at a Smem instance. */ template + SGL_DEVICE static void + run(const float* scores, // + int32_t* indices, + const uint32_t length, + void* _smem, + const bool use_pdl = false) { + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // Initialize shared memory histogram + { + typename Smem::HistVec hist_vec; + hist_vec.fill(0); 
+ smem->histogram_vec[tx] = hist_vec; + if (tx == 0) { + smem->counter_gt = smem->counter_eq = 0; + if constexpr (kIs2Pass) { + ptx::mbarrier_init(&smem->mbarrier, 1); + } + } + __syncthreads(); + } + + if (use_pdl) device::PDLWaitPrimary(); + + // Load scores into registers + Vec4 local[kVecsPerThread]; +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { + const uint32_t base = (tx + v * kBlockSize) * 4; + if (base >= length) break; + local[v].load(scores, tx + v * kBlockSize); + } + + // Fetch the next chunk of scores + if constexpr (kIs2Pass) { + if (ptx::elect_sync_cta(tx)) { + /* element count past the first kMax1PassLength, rounded up to a multiple of 4 */ const auto length_aligned = (length + 3u - kMax1PassLength) & ~3u; + const auto size_bytes = length_aligned * sizeof(float); + ptx::tma_load(smem->score_buffer, scores + kMax1PassLength, size_bytes, &smem->mbarrier); + ptx::mbarrier_arrive_expect_tx(&smem->mbarrier, size_bytes); + } + __syncwarp(); // reconverge the warp after the elected-lane branch above + } + + // Accumulate histogram via shared-memory atomics +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + if constexpr (!kIs2Pass) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e; + if (idx >= length) goto LABEL_ACC_FINISH; + } + atomicAdd(&smem->histogram[extract_coarse_bin(local[v][e])], 1); + } + } + if constexpr (kIs2Pass) { + // 16K ~ 32K. 
`i` is a float index into score_buffer (one element per iteration, not a float4) + if (lane_id == 0) ptx::mbarrier_wait(&smem->mbarrier, 0); + __syncwarp(); + for (uint32_t i = tx; i + kMax1PassLength < length; i += kBlockSize) { + const auto val = smem->score_buffer[i]; + atomicAdd(&smem->histogram[extract_coarse_bin(val)], 1); + } + } + [[maybe_unused]] LABEL_ACC_FINISH: + __syncthreads(); + + // Phase 2: Exclusive prefix scan -> find threshold bin + { + constexpr uint32_t kItems = kHistBins / kBlockSize; + uint32_t orig[kItems]; + const auto hist_vec = smem->histogram_vec[tx]; + uint32_t tmp_local_sum = 0; + +#pragma unroll + for (uint32_t i = 0; i < kItems; ++i) { + orig[i] = hist_vec[i]; + tmp_local_sum += orig[i]; + } + + const auto warp_inc = warp_inclusive_sum(lane_id, tmp_local_sum); + const auto warp_exc = warp_inc - tmp_local_sum; + if (lane_id == kWarpThreads - 1) { + smem->warp_sum[warp_id] = warp_inc; + } + + __syncthreads(); + + const auto tmp = smem->warp_sum[lane_id]; + // Exactly one bin satisfies: above < K && above + count >= K + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + prefix_sum += warp_exc; +#pragma unroll + for (uint32_t i = 0; i < kItems; ++i) { + prefix_sum += orig[i]; + /* prefix_sum now includes this bin, so `above` counts only the bins after it */ const auto above = length - prefix_sum; + if (above < K && above + orig[i] >= K) { + smem->match = { + .bin = tx * kItems + i, + .above_count = above, + .equal_count = orig[i], + }; + } + } + __syncthreads(); + } + + const auto [thr_bin, num_above, num_equal] = smem->match; + + // Phase 3: Scatter + // Elements strictly above threshold go directly to output. + // Tied elements: simple path admits first-come; tiebreak path collects into tie_buffer. 
+ const bool need_tiebreak = (num_equal + num_above > K + kMaxTolerance); + const auto topk_indices = indices; + const auto tie_buffer = smem->tie_buffer; + +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e; + if constexpr (!kIs2Pass) { + if (idx >= length) goto LABEL_SCATTER_DONE; + } + const uint32_t bin = extract_coarse_bin(local[v][e]); + if (bin > thr_bin) { + topk_indices[atomicAdd(&smem->counter_gt, 1)] = idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (need_tiebreak) { + /* ties beyond kMaxTies are dropped here */ if (pos < kMaxTies) { + tie_buffer[pos] = {.idx = idx, .score = local[v][e]}; + } + } else { + if (const auto which = pos + num_above; which < K) { + topk_indices[which] = idx; + } + } + } + } + // prefetch the next scores + if constexpr (kIs2Pass) { + local[v].load(smem->score_buffer, tx + v * kBlockSize); + } + } + + // 16K ~ 32K, already in registers: similar loop as above but read from smem->score_buffer + if constexpr (kIs2Pass) { +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e + kMax1PassLength; + if (idx >= length) goto LABEL_SCATTER_DONE; + const uint32_t bin = extract_coarse_bin(local[v][e]); + if (bin > thr_bin) { + topk_indices[atomicAdd(&smem->counter_gt, 1)] = idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (need_tiebreak) { + if (pos < kMaxTies) { + tie_buffer[pos] = {.idx = idx, .score = local[v][e]}; + } + } else { + if (const auto which = pos + num_above; which < K) { + topk_indices[which] = idx; + } + } + } + } + } + } + + [[maybe_unused]] LABEL_SCATTER_DONE: + if (!need_tiebreak) return; + + // Phase 4: Tie-breaking within the threshold bin. + // Assume num_ties <= kBlockSize (at most 1 block of ties). 
+ // Each thread takes one tied element, computes its rank (number of + // elements with strictly higher score, breaking exact float ties by + // original index), and writes to output if rank < topk_remain. + __syncthreads(); + static_assert(kMaxTies <= kBlockSize); + + const uint32_t num_ties = min(num_equal, kMaxTies); + const uint32_t topk_remain = K - num_above; + + /* strict ordering: higher score first, equal scores ordered by smaller index */ const auto is_greater = [](const Tie& a, const Tie& b) { + return (a.score > b.score) || (a.score == b.score && a.idx < b.idx); + }; + + if (num_ties <= kWarpThreads) { + static_assert(kWarpThreads <= kNumWarps); + if (lane_id >= num_ties || warp_id >= num_ties) return; // some threads are idle + /// NOTE: use long long to avoid mask overflow when num_ties == 32 + const uint32_t mask = (1ull << num_ties) - 1u; + const auto tie = tie_buffer[lane_id]; + const auto target_tie = tie_buffer[warp_id]; + const bool pred = is_greater(tie, target_tie); + const auto rank = static_cast(__popc(__ballot_sync(mask, pred))); + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target_tie.idx; + } + } else if (num_ties <= kWarpThreads * 2) { + // 64 x 64 topk implementation: each thread takes 2 elements + const auto lane_id_1 = lane_id + kWarpThreads; + const auto warp_id_1 = warp_id + kWarpThreads; + const auto invalid = Tie{.idx = 0xFFFFFFFF, .score = -FLT_MAX}; + const auto tie_0 = tie_buffer[lane_id]; + const auto tie_1 = lane_id_1 < num_ties ? 
tie_buffer[lane_id_1] : invalid; + if (true) { + const auto target = tie_buffer[warp_id]; + const bool pred_0 = is_greater(tie_0, target); + const bool pred_1 = is_greater(tie_1, target); + const auto rank_0 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_0))); + const auto rank_1 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_1))); + const auto rank = rank_0 + rank_1; + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target.idx; + } + } + if (warp_id_1 < num_ties) { + const auto target = tie_buffer[warp_id_1]; + const bool pred_0 = is_greater(tie_0, target); + const bool pred_1 = is_greater(tie_1, target); + const auto rank_0 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_0))); + const auto rank_1 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_1))); + const auto rank = rank_0 + rank_1; + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target.idx; + } + } + } else { + /// NOTE: Based on my observation, this path is very rarely reached + [[unlikely]]; + // Block-level: each thread reads from tie_buffer in shared memory + for (auto i = warp_id; i < num_ties; i += kNumWarps) { + const auto target_tie = tie_buffer[i]; + uint32_t local_rank = 0; + for (auto j = lane_id; j < num_ties; j += kWarpThreads) { + const auto tie = tie_buffer[j]; + if (is_greater(tie, target_tie)) local_rank++; + } + // sum the rank across the warp + const auto rank = warp::reduce_sum(local_rank); + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target_tie.idx; + } + } + } + } + + /* Block barrier, then the first K threads each call params.transform on their own slot. */ SGL_DEVICE static void transform(const TransformParams params) { + __syncthreads(); + if (const auto tx = threadIdx.x; tx < K) params.transform(tx); + } +}; + +} // namespace device::top512 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/streaming.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/streaming.cuh new file mode 100644 index 
0000000000..4462b89a19 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/deepseek_v4/topk/streaming.cuh @@ -0,0 +1,213 @@ +#pragma once + +#include +#include +#include + +#include "common.cuh" +#include "ptx.cuh" +#include +#include + +namespace device::top512 { + +/* Top-K that streams the scores through a kNumStages-deep shared-memory buffer twice: once to build the coarse histogram, once to scatter indices relative to the threshold bin. */ template +struct StreamingTopK { + static constexpr uint32_t kHistBits = 12; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kRadixBins = 256; + static constexpr uint32_t kElemPerStage = 8; + static constexpr uint32_t kSizePerStage = kElemPerStage * kBlockSize; + static constexpr uint32_t kNumStages = 2; // double buffer + + static constexpr uint32_t kHistItems = kHistBins / kBlockSize; // 4 + static_assert(kHistItems * kBlockSize == kHistBins); + using HistVec = AlignedVector; + + struct Smem { + /* first index is selected by kIsScatter, second by buffer slot */ uint64_t barrier[2][kNumStages]; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + alignas(128) MatchBin match; + alignas(128) uint32_t warp_sum[kNumWarps]; + union { + uint32_t histogram[kHistBins]; + HistVec histogram_vec[kBlockSize]; + Tie tie_buffer[kMaxTies]; + }; + union { + float score_buffer[kNumStages][kSizePerStage]; + TieHandleSmem stage2; // reuse smem for tie handling in phase D + }; + }; + + // --------------------------------------------------------------------------- + // Helpers + // --------------------------------------------------------------------------- + + /// NOTE: length must be 4-aligned since we load 4 floats/thread. Caller should round up. 
+ /* Issues the bulk load for `stage` into buffer (stage % kNumStages) and arms that buffer's barrier with the same byte count. */ template + SGL_DEVICE static void issue_tma(const float* scores, uint32_t stage, uint32_t length, Smem* smem) { + const auto buf_idx = stage % kNumStages; + const auto offset = stage * kSizePerStage; + /* the last stage may be shorter than a full stage */ const auto size = min(kSizePerStage, length - offset); + const auto size_bytes = size * sizeof(float); + const auto bar = &smem->barrier[kIsScatter][buf_idx]; + ptx::tma_load(smem->score_buffer[buf_idx], scores + offset, size_bytes, bar); + ptx::mbarrier_arrive_expect_tx(bar, size_bytes); + } + + // --------------------------------------------------------------------------- + // Unified streaming pass. Used for both phase A (kIsScatter=false) and + // phase C (kIsScatter=true). Each buffer is reused across iterations via the + // reuse-arrive trick (same pattern as ClusterTopKImpl::stage1). + // --------------------------------------------------------------------------- + + template + SGL_DEVICE static void stream_pass( + const float* scores, + const uint32_t length, + const uint32_t thr_bin, // ignored when !kIsScatter + int32_t* s_topk_indices, // ignored when !kIsScatter + Smem* smem) { + const auto tx = threadIdx.x; + const auto num_iters = (length + kSizePerStage - 1) / kSizePerStage; + const auto lane_id = tx % kWarpThreads; + + // Initial double-buffer TMA prologue. 
+ const auto length_aligned = (length + 3u) & ~3u; + if (tx == 0) { +#pragma unroll + for (uint32_t i = 0; i < kNumStages; i++) { + if (i >= num_iters) break; + issue_tma(scores, i, length_aligned, smem); + } + } + + for (uint32_t iter = 0; iter < num_iters; iter++) { + const auto buf_idx = iter % kNumStages; + const auto offset = iter * kSizePerStage; + const auto this_size = min(kSizePerStage, length - offset); + + /* one lane per warp polls the barrier; __syncwarp below holds the other lanes until it returns */ if (lane_id == 1) { + /* parity alternates each time the same buffer slot is reused */ const auto phase_bit = (iter / kNumStages) & 1; + ptx::mbarrier_wait(&smem->barrier[kIsScatter][buf_idx], phase_bit); + } + __syncwarp(); + +#pragma unroll + for (uint32_t i = 0; i < kElemPerStage; i++) { + const auto local_idx = tx + i * kBlockSize; + if (local_idx >= this_size) break; + const auto score = smem->score_buffer[buf_idx][local_idx]; + const auto bin = extract_coarse_bin(score); + if constexpr (kIsScatter) { + const auto global_idx = offset + local_idx; + if (bin > thr_bin) { + const auto pos = atomicAdd(&smem->counter_gt, 1); + if (pos < K) s_topk_indices[pos] = global_idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (pos < kMaxTies) smem->tie_buffer[pos] = {global_idx, score}; + } + } else { + atomicAdd(&smem->histogram[bin], 1); + } + } + + /* every thread must finish reading this buffer before thread 0 refills it */ __syncthreads(); + if (tx == 0) { + if (const auto next_iter = iter + kNumStages; next_iter < num_iters) { + issue_tma(scores, next_iter, length_aligned, smem); + } + } + } + } + + // --------------------------------------------------------------------------- + // Phase B: find the threshold bin via a warp-level prefix scan. + // Same structure as SmallTopKImpl's phase 2 (4 bins/thread, warp_sum relay). 
+ // --------------------------------------------------------------------------- + + SGL_DEVICE static void find_threshold(uint32_t length, Smem* smem) { + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + uint32_t orig[kHistItems]; + const auto hist_vec = smem->histogram_vec[tx]; + uint32_t local_sum = 0; +#pragma unroll + for (uint32_t i = 0; i < kHistItems; ++i) { + orig[i] = hist_vec[i]; + local_sum += orig[i]; + } + + const auto warp_inc = warp_inclusive_sum(lane_id, local_sum); + const auto warp_exc = warp_inc - local_sum; + if (lane_id == kWarpThreads - 1) smem->warp_sum[warp_id] = warp_inc; + __syncthreads(); + + const auto tmp = smem->warp_sum[lane_id]; + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + prefix_sum += warp_exc; +#pragma unroll + for (uint32_t i = 0; i < kHistItems; ++i) { + prefix_sum += orig[i]; + const auto above = length - prefix_sum; + if (above < K && above + orig[i] >= K) { + smem->match = { + .bin = tx * kHistItems + i, + .above_count = above, + .equal_count = orig[i], + }; + } + } + __syncthreads(); + } + + SGL_DEVICE static void run(const float* scores, const uint32_t length, int32_t* topk_indices, void* _smem) { + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + __builtin_assume(tx < kBlockSize); + + // Init histogram, barriers, counters. + { + HistVec zero; + zero.fill(0); + smem->histogram_vec[tx] = zero; + /* all 2 * kNumStages barriers are initialized once, viewed as a flat array */ if (tx < 2 * kNumStages) { + const auto base_barrier = &smem->barrier[0][0]; + ptx::mbarrier_init(&base_barrier[tx], 1); + } + if (tx == 0) { + smem->counter_gt = 0; + smem->counter_eq = 0; + } + __syncthreads(); + } + + // Phase A: histogram pass (pipelined TMA stream). + stream_pass(scores, length, 0, nullptr, smem); + + // Phase B: locate threshold bin (find_threshold itself does not touch the barriers; each pass indexes barrier[kIsScatter][*]) + find_threshold(length, smem); + + // Phase C: scatter pass. 
+ stream_pass(scores, length, smem->match.bin, topk_indices, smem); + } + + SGL_DEVICE static void transform(const TransformParams params, void* _smem) { + // Phase D: page-translate above entries, then refine ties. + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto num_above = smem->match.above_count; + if (tx < num_above) params.transform(tx); + const auto num_equal = smem->counter_eq; + /* nothing left to pick from the threshold bin */ if (num_above >= K || num_equal == 0) return; + /* the scatter pass stores at most kMaxTies entries in tie_buffer */ const auto clamped_ties = min(num_equal, kMaxTies); + tie_handle_transform(smem->tie_buffer, clamped_ties, num_above, K, params, &smem->stage2); + } +}; + +} // namespace device::top512 diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/common.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/common.cuh new file mode 100644 index 0000000000..e0ce2dc086 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/common.cuh @@ -0,0 +1,120 @@ +#pragma once +#include + +namespace device::distributed { + +inline constexpr uint32_t kMaxNumGPU = 8; + +/* Flag/counter pair, 128-byte aligned. m_flag is read and incremented through inline PTX; m_counter uses plain loads and stores. */ struct alignas(128) Semaphore { + public: + constexpr Semaphore() : m_flag(0), m_counter(0) {} + + /* Load m_flag: ld.acquire.sys when kFence, ld.volatile otherwise. */ template + SGL_DEVICE uint32_t get() const { + uint32_t val; + if constexpr (kFence) { + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(val) : "l"(&m_flag)); + } else { + asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(val) : "l"(&m_flag)); + } + return val; + } + + /* Atomically add `val` to m_flag (atom.release.sys when kFence) and return the previous value. */ template + SGL_DEVICE uint32_t add(uint32_t val) { + uint32_t old_val; + if constexpr (kFence) { + asm volatile("atom.release.sys.global.add.u32 %0, [%1], %2;" : "=r"(old_val) : "l"(&m_flag), "r"(val)); + } else { + asm volatile("atom.global.add.u32 %0, [%1], %2;" : "=r"(old_val) : "l"(&m_flag), "r"(val)); + } + return old_val; + } + + // Only called by the owning GPU - plain load is sufficient + SGL_DEVICE uint32_t get_counter() const { + return m_counter; + } + + // Only called by the owning GPU - plain store is 
sufficient + SGL_DEVICE void set_counter(uint32_t val) { + m_counter = val; + } + + private: + uint32_t m_flag; + uint32_t m_counter; +}; + +/* Holds one Semaphore array pointer per GPU and implements the block-level cross-GPU barrier in sync(). */ struct PullController { + public: + using SignalType = Semaphore; + + PullController(void** signals, uint32_t num_gpu) { + for (uint32_t i = 0; i < num_gpu; ++i) { + m_signals[i] = static_cast(signals[i]); + } + } + + /// Synchronize all GPUs. + /// When kFence is true, establishes happens-before across GPUs using + /// release/acquire semantics, ensuring prior writes are visible system-wide. + template + SGL_DEVICE void sync(uint32_t rank, uint32_t num_gpu) const { + // For fenced sync: ensure all threads in this block have completed their writes, + // so the signaling thread's release carries them transitively. + static_assert(!(kFence && kStart), "Start stage does not need to wait fence"); + if constexpr (kFence || !kStart) __syncthreads(); + constexpr auto kStage = kStart ? 1 : 2; + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + /* lane 0 of warp w signals GPU w's semaphore for this block index */ if (lane_id == 0 && warp_id < num_gpu) { + auto& signal = m_signals[warp_id][blockIdx.x]; + signal.add(1); + /* only the warp whose id equals `rank` waits, on this GPU's own semaphore */ if (warp_id == rank) { + const auto target = num_gpu * kStage; + /// NOTE: correctness here: + /// - base is only read/updated locally by the owning GPU + const auto base = signal.get_counter(); + while (signal.get() - base < target) + ; + /* the base is advanced only on the non-start stage, after both stages' arrivals were counted */ if constexpr (!kStart) { + signal.set_counter(base + target); + } + } + } + if constexpr (kStart) __syncthreads(); + } + + private: + Semaphore* __restrict__ m_signals[kMaxNumGPU]; +}; + +/* Wraps a local array of per-block uint32_t signals; exit() advances the calling block's signal modulo kNumStages. */ struct PushController { + public: + using SignalType = uint32_t; + static constexpr int64_t kNumStages = 2; + + PushController(void* ptr) : m_local_signal(static_cast(ptr)) {} + + /* Current signal value of the calling block. */ SGL_DEVICE SignalType epoch() const { + return m_local_signal[blockIdx.x]; + } + + /* Block barrier first, then thread 0 alone advances the signal. */ SGL_DEVICE void exit() const { + __syncthreads(); + if (threadIdx.x == 0) { + this->exit_unsafe(blockIdx.x); + } + } + + /* No synchronization here: plain read-modify-write of signal `which`. */ SGL_DEVICE void exit_unsafe(uint32_t 
which) const { + auto& signal = m_local_signal[which]; + signal = (signal + 1) % kNumStages; + } + + private: + SignalType* m_local_signal; +}; + +} // namespace device::distributed diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/custom_all_reduce.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/custom_all_reduce.cuh new file mode 100644 index 0000000000..239fac71a1 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/distributed/custom_all_reduce.cuh @@ -0,0 +1,354 @@ +#pragma once +#include + +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace host::distributed { + +using device::distributed::PullController, device::distributed::PushController; + +/* One input pointer slot per GPU (kMaxNumGPU entries). */ struct AllReduceData { + constexpr AllReduceData() {} + void* __restrict__ input[device::distributed::kMaxNumGPU]; +}; + +using ExternHandle = tvm::ffi::Array; + +/* Gets the CUDA IPC handle of `ptr` and returns it as an array with one element per byte of handle.reserved. */ inline ExternHandle to_extern_handle(void* ptr) { + ExternHandle array; + cudaIpcMemHandle_t handle; + RuntimeDeviceCheck(cudaIpcGetMemHandle(&handle, ptr)); + for (size_t i = 0; i < sizeof(handle); ++i) { + array.push_back(handle.reserved[i]); + } + return array; +} + +/* Rebuilds a cudaIpcMemHandle_t from the byte array (size is checked) and opens it with cudaIpcOpenMemHandle. */ inline void* from_extern_handle(const ExternHandle& array) { + cudaIpcMemHandle_t handle; + RuntimeCheck(array.size() == sizeof(handle), "Invalid IPC handle size: ", array.size()); + for (size_t i = 0; i < sizeof(handle); ++i) { + handle.reserved[i] = array[i]; + } + void* ptr; + RuntimeDeviceCheck(cudaIpcOpenMemHandle(&ptr, handle, cudaIpcMemLazyEnablePeerAccess)); + return ptr; +} + +/* Hashes the raw bytes of handle.reserved. */ struct HandleHash { + std::size_t operator()(const cudaIpcMemHandle_t& handle) const { + return std::hash{}({handle.reserved, sizeof(handle.reserved)}); + } +}; + +/* Byte-wise equality of handle.reserved. */ struct HandleEqual { + bool operator()(const cudaIpcMemHandle_t& a, const cudaIpcMemHandle_t& b) const { + return std::memcmp(a.reserved, b.reserved, sizeof(a.reserved)) 
== 0; + } +}; + +/** + * \brief The control plane of the custom all-reduce implementation. + * It manages the internal state and synchronization of the participating GPUs. + */ +struct CustomAllReduceBase : public tvm::ffi::Object { + public: + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("sgl.CustomAllReduce", CustomAllReduceBase, tvm::ffi::Object); + + static constexpr bool _type_mutable = true; + using InputPair = tvm::ffi::Tuple; // (offset, ipc handle) + + /* Validates the sizes (128-byte alignment, uint32 range, rank < num_gpu) and allocates storage_bytes() of device memory with cudaMalloc. */ CustomAllReduceBase( + uint32_t rank, + uint32_t num_gpu, + uint32_t max_num_cta_pull, + uint32_t max_num_cta_push, + int64_t pull_buffer_size, + int64_t push_buffer_size, + int64_t graph_buffer_count) + : m_pull_buffer_bytes(pull_buffer_size), + m_push_buffer_bytes(push_buffer_size), + m_graph_buffer_count(graph_buffer_count), + m_rank(rank), + m_num_gpu(num_gpu), + m_max_num_cta_pull(max_num_cta_pull), + m_max_num_cta_push(max_num_cta_push), + // default config for pull kernel, can be updated by `configure()` + m_num_cta(max_num_cta_pull), + m_cta_size(256) { + RuntimeCheck(pull_buffer_size % 128 == 0, "Pull buffer size should be aligned to 128 bytes"); + RuntimeCheck(push_buffer_size % 128 == 0, "Push buffer size should be aligned to 128 bytes"); + RuntimeCheck(rank < num_gpu, "Invalid rank: ", rank); + const int64_t kU32Max = static_cast(std::numeric_limits::max()); + const int64_t push_buffer_size_all = push_all_ranks_bytes(); + RuntimeCheck(pull_buffer_size <= kU32Max, "Pull buffer size is too large: ", pull_buffer_size); + RuntimeCheck(push_buffer_size_all <= kU32Max, "Push buffer size is too large: ", push_buffer_size_all); + RuntimeDeviceCheck(cudaMalloc(&m_storage, storage_bytes())); + } + + /* IPC handle of this rank's storage allocation. */ ExternHandle share_storage() { + return to_extern_handle(m_storage); + } + + /* Returns one (offset, IPC handle of the allocation base) pair for every capture input at index >= m_cum_registered_count. */ tvm::ffi::Array share_graph_inputs() { + tvm::ffi::Array result; + const auto new_inputs_count = registered_count() - m_cum_registered_count; + RuntimeCheck(new_inputs_count >= 0, "Invalid new count: ", new_inputs_count); + 
result.reserve(new_inputs_count); + /* per-call cache so that each base allocation is exported once */ std::unordered_map ipc_cache; + const auto get_handle = [&](void* ptr) -> ExternHandle { + const auto it = ipc_cache.find(ptr); + if (it != ipc_cache.end()) return it->second; + const auto handle = to_extern_handle(ptr); + ipc_cache.try_emplace(ptr, handle); + return handle; + }; + for (const auto ptr : std::span(m_graph_capture_inputs).subspan(m_cum_registered_count)) { + // note: must share the base address of each allocation, or we get wrong address + void* base_ptr; + const auto cu_result = cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr); + RuntimeCheck(cu_result == CUDA_SUCCESS, "failed to get pointer attr"); + const auto offset = reinterpret_cast(ptr) - reinterpret_cast(base_ptr); + result.push_back(InputPair{offset, get_handle(base_ptr)}); + } + return result; + } + + /* Opens every peer's storage handle, zeroes the local pull/push signal regions and the push buffer, and uploads the default AllReduceData that points at each peer's pull buffer. */ void post_init(tvm::ffi::Array ipc_storages) { + RuntimeCheck(ipc_storages.size() == m_num_gpu, "Invalid array size: ", ipc_storages.size()); + m_peer_storage.resize(m_num_gpu); + for (const auto i : irange(m_num_gpu)) { + if (i == m_rank) { + m_peer_storage[i] = m_storage; + } else { + m_peer_storage[i] = from_extern_handle(ipc_storages[i]); + } + } + + // set signal buffer to zero + const auto pull_signal = get_pull_signal(m_storage); + RuntimeDeviceCheck(cudaMemset(pull_signal, 0, pull_signal_bytes())); + + // update the pull controller and data pointer + RuntimeCheck(!m_pull_ctrl.has_value(), "Controller is already initialized"); + m_pull_ctrl.emplace(m_peer_storage.data(), m_num_gpu); + AllReduceData data; + for (const auto i : irange(m_num_gpu)) { + data.input[i] = get_pull_buffer(m_peer_storage[i]); + } + const auto default_data_ptr = get_data_ptr(); + RuntimeDeviceCheck(cudaMemcpy(default_data_ptr, &data, sizeof(AllReduceData), cudaMemcpyHostToDevice)); + + // update the push controller and data pointer + RuntimeCheck(!m_push_ctrl.has_value(), "Controller is already initialized"); + const auto push_signal = 
get_push_signal(m_storage); + RuntimeDeviceCheck(cudaMemset(push_signal, 0, push_signal_bytes())); + m_push_ctrl.emplace(push_signal); + const auto push_buffer = get_push_buffer(m_storage); + RuntimeDeviceCheck(cudaMemset(push_buffer, 0, push_all_ranks_bytes())); + } + + /* Resolves the newly shared graph inputs of every rank into device pointers and uploads them into the graph parameter slots starting at m_cum_registered_count. */ void register_inputs(tvm::ffi::Array> ipc_graph_inputs) { + RuntimeCheck(ipc_graph_inputs.size() == m_num_gpu); + const auto new_registered_count = registered_count() - m_cum_registered_count; + RuntimeCheck(new_registered_count >= 0, "Invalid registered count: ", new_registered_count); + if (new_registered_count == 0) return; // avoid `get_data_ptr()` out-of-bounds + std::vector data; + data.resize(new_registered_count); + /* opens each distinct IPC handle once; results are kept in m_ipc_cache until free_ipc_handles() */ const auto open_cached = [&](const ExternHandle& h) -> void* { + RuntimeCheck(h.size() == sizeof(cudaIpcMemHandle_t), "Invalid IPC handle size: ", h.size()); + cudaIpcMemHandle_t handle; + for (size_t i = 0; i < sizeof(handle); ++i) + handle.reserved[i] = h[i]; + const auto [it, success] = m_ipc_cache.try_emplace(handle, nullptr); + if (success) { + void* ptr; + RuntimeDeviceCheck(cudaIpcOpenMemHandle(&ptr, handle, cudaIpcMemLazyEnablePeerAccess)); + it->second = ptr; + } + return it->second; + }; + for (const auto i : irange(ipc_graph_inputs.size())) { + const auto& array = ipc_graph_inputs[i]; + RuntimeCheck(int64_t(array.size()) == new_registered_count); + /* the local rank uses its own raw pointers; peers are resolved through their IPC handles */ if (i == m_rank) { + for (const auto j : irange(new_registered_count)) { + data[j].input[i] = m_graph_capture_inputs[m_cum_registered_count + j]; + } + } else { + for (const auto j : irange(new_registered_count)) { + /// NOTE: structured binding will cause internal compiler error... 
+ const auto elem = array[j]; + const auto offset = elem.get<0>(); + const auto ipc_handle = elem.get<1>(); + data[j].input[i] = pointer::offset(open_cached(ipc_handle), offset); + } + } + } + + const auto new_registered_bytes = sizeof(AllReduceData) * new_registered_count; + const auto dst_ptr = get_data_ptr(m_cum_registered_count); + m_cum_registered_count += new_registered_count; + RuntimeDeviceCheck(cudaMemcpy(dst_ptr, data.data(), new_registered_bytes, cudaMemcpyHostToDevice)); + } + + void set_cuda_graph_capture(bool enabled) { + m_is_graph_capturing = enabled; + } + + /* Closes every handle opened by register_inputs and clears the cache. */ void free_ipc_handles() { + for (const auto& pair : m_ipc_cache) { + host::RuntimeDeviceCheck(cudaIpcCloseMemHandle(pair.second)); + } + m_ipc_cache.clear(); + } + + void free_storage() { + host::RuntimeDeviceCheck(cudaFree(m_storage)); + m_storage = nullptr; + } + + /* Sets the pull-kernel launch configuration and returns the previous (num_cta, cta_size). */ tvm::ffi::Tuple configure_pull(uint32_t num_cta, uint32_t cta_size) { + using host::RuntimeCheck; + const auto min_cta_size = m_num_gpu * device::kWarpThreads; + RuntimeCheck(num_cta > 0 && num_cta <= m_max_num_cta_pull, "Invalid number of CTAs: ", num_cta); + RuntimeCheck(cta_size >= min_cta_size, "Block size must be at least ", min_cta_size); + const auto old_num_cta = m_num_cta; + const auto old_block_size = m_cta_size; + m_num_cta = num_cta; + m_cta_size = cta_size; + return tvm::ffi::Tuple{old_num_cta, old_block_size}; + } + + protected: + /* Records `data_ptr` as a new capture input and returns the device slot reserved for it. */ AllReduceData* allocate_graph_capture_input(void* data_ptr) { + const auto count = registered_count(); + RuntimeCheck(count < m_graph_buffer_count, "Graph buffer overflow, increase `graph_buffer_count`!"); + m_graph_capture_inputs.push_back(data_ptr); + return get_data_ptr(count); + } + /* Slot 0 of the parameter region is the default entry (which == -1); capture input `which` lives at slot 1 + which. */ AllReduceData* get_data_ptr(int64_t which = -1) { + const auto count = registered_count(); + RuntimeCheck(which >= -1 && which < count, "Invalid graph buffer index: ", which, ", count: ", count); + const auto start = get_pull_params(m_storage); + return static_cast(start) + (1 + which); + } + int64_t 
registered_count() const { + return static_cast(m_graph_capture_inputs.size()); + } + int64_t pull_signal_bytes() const { + return _align_bytes(sizeof(PullController::SignalType) * m_max_num_cta_pull); + } + int64_t push_signal_bytes() const { + return _align_bytes(sizeof(PushController::SignalType) * m_max_num_cta_push); + } + int64_t graph_param_bytes() const { + return _align_bytes(sizeof(AllReduceData) * (1 + m_graph_buffer_count)); // 1 for default + } + int64_t push_all_ranks_bytes() const { + return _align_bytes(PushController::kNumStages * m_num_gpu * m_push_buffer_bytes); + } + /* total size = offset just past the last of the five regions */ int64_t storage_bytes() const { + return _get_offset_impl(5); + } + void* get_pull_signal(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(0)); + } + void* get_push_signal(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(1)); + } + void* get_pull_params(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(2)); + } + void* get_pull_buffer(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(3)); + } + void* get_push_buffer(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(4)); + } + /* Byte offset of region `which`: the sum of the sizes of all regions before it (which == 5 gives the total). */ int64_t _get_offset_impl(int64_t which) const { + // | SignalArray (pull + push) | GraphBuffers (pull params) | Buffers (pull + push) | + const int64_t offset_map[5] = { + /*[0]=*/pull_signal_bytes(), + /*[1]=*/push_signal_bytes(), + /*[2]=*/graph_param_bytes(), + /*[3]=*/m_pull_buffer_bytes, + /*[4]=*/push_all_ranks_bytes(), + }; + RuntimeCheck(which >= 0 && which <= 5, "Invalid offset index: ", which); + return std::accumulate(offset_map, offset_map + which, int64_t(0)); + } + /* Round `size` up to a multiple of 128. */ static int64_t _align_bytes(int64_t size) { + return div_ceil(size, 128) * 128; + } + + const int64_t m_pull_buffer_bytes; + const int64_t m_push_buffer_bytes; + const int64_t m_graph_buffer_count; + const uint32_t m_rank; + const uint32_t m_num_gpu; + const uint32_t m_max_num_cta_pull; + const uint32_t m_max_num_cta_push; + // these 2 config should only 
affect pull kernel + uint32_t m_num_cta; + uint32_t m_cta_size; + // other states + bool m_is_graph_capturing = false; + /* number of capture inputs already uploaded by register_inputs */ int64_t m_cum_registered_count = 0; + std::optional m_pull_ctrl; + std::optional m_push_ctrl; + void* m_storage = nullptr; + std::vector m_graph_capture_inputs; + std::vector m_peer_storage; + std::unordered_map m_ipc_cache; +}; + +struct CustomAllReduceRef : public tvm::ffi::ObjectRef { + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(CustomAllReduceRef, tvm::ffi::ObjectRef, CustomAllReduceBase); +}; + +} // namespace host::distributed + +namespace device::distributed { + +/* Element-wise sum of the M input vectors: each of the N packed elements is accumulated in an fp32x2_t and cast back for the result. */ template +SGL_DEVICE auto reduce_impl(AlignedVector (&storage)[M]) -> AlignedVector { + fp32x2_t acc[N] = {}; +#pragma unroll // unroll num gpu + for (uint32_t i = 0; i < M; ++i) { +#pragma unroll // unroll vec + for (uint32_t j = 0; j < N; ++j) { + const auto [x, y] = cast(storage[i][j]); + auto& [x_acc, y_acc] = acc[j]; + x_acc += x; + y_acc += y; + } + } + + AlignedVector result; +#pragma unroll + for (uint32_t j = 0; j < N; ++j) { + result[j] = cast(acc[j]); + } + + return result; +} + +} // namespace device::distributed diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/ffi.h b/lightllm/third_party/sglang_jit/include/sgl_kernel/ffi.h new file mode 100644 index 0000000000..17d9048d4c --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/ffi.h @@ -0,0 +1,104 @@ +#pragma once +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace host::ffi { + +using tvm::ffi::Tensor, tvm::ffi::TensorView, tvm::ffi::ShapeView; + +/* Allocate an uninitialized tensor through the TVM FFI environment allocator. */ inline Tensor empty(ShapeView shape, DLDataType dtype, DLDevice device) { + return Tensor::FromEnvAlloc(::TVMFFIEnvTensorAlloc, shape, dtype, device); +} + +/* Same shape, dtype and device as `tensor`. */ inline Tensor empty_like(TensorView tensor) { + return empty(tensor.shape(), tensor.dtype(), tensor.device()); +} + +/* No-op deleter. */ struct _dummy_deleter { + void operator()(void*) const {} +}; + +// template + +template +struct 
FromBlobContext { + [[no_unique_address]] Fn deleter; + int64_t dimension; + int64_t* get_shape() { + return reinterpret_cast(this + 1); + } + int64_t* get_stride() { + return this->get_shape() + dimension; + } +}; + +template +inline Tensor from_blob( + void* data, + ShapeView shape, + DLDataType dtype, + DLDevice device, + Fn&& deleter = {}, + std::optional stride = {}, + uint64_t byte_offset = 0) { + using Context = FromBlobContext>; + const auto ndim = shape.size(); + const auto ctx = [&] { + auto ptr = std::malloc(sizeof(Context) + sizeof(int64_t) * ndim * 2); + auto ctx = static_cast(ptr); + std::construct_at(ctx, std::forward(deleter), static_cast(ndim)); + stdr::copy_n(shape.data(), ndim, ctx->get_shape()); + if (stride.has_value()) { + RuntimeCheck(stride->size() == ndim, "Stride ndim mismatch!"); + stdr::copy_n(stride->data(), ndim, ctx->get_stride()); + } else { + int64_t stride_val = 1; + for (const auto i : irange(ndim)) { + const auto j = ndim - 1 - i; + ctx->get_stride()[j] = stride_val; + stride_val *= shape[j]; + } + } + return ctx; + }(); + const auto tensor = DLTensor{ + .data = data, + .device = device, + .ndim = static_cast(ndim), + .dtype = dtype, + .shape = ctx->get_shape(), + .strides = ctx->get_stride(), + .byte_offset = byte_offset, + }; + const auto blob_deleter = [](DLManagedTensor* self) { + auto ctx = static_cast(self->manager_ctx); + ctx->deleter(self->dl_tensor.data); + std::destroy_at(ctx); + std::free(ctx); + }; + auto managed_tensor = DLManagedTensor{tensor, ctx, blob_deleter}; + return Tensor::FromDLPack(&managed_tensor); +} + +template +inline Tensor from_blob_like( + void* data, + TensorView t, + Fn&& deleter = {}, + bool is_contiguous = false, // if override to true, the stride will be ignored + uint64_t byte_offset = 0) { + const auto stride = is_contiguous ? 
std::nullopt : std::optional{t.strides()}; + return from_blob(data, t.shape(), t.dtype(), t.device(), std::forward(deleter), stride, byte_offset); +} + +} // namespace host::ffi diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/impl/norm.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/impl/norm.cuh new file mode 100644 index 0000000000..cd024acd46 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/impl/norm.cuh @@ -0,0 +1,168 @@ +#pragma once +#include +#include +#include +#include +#include + +#include +#include + +namespace host::norm { + +/** + * \brief Check if the given configuration is supported. + * \tparam T Element type (only fp16_t/bf16_t is supported) + * \tparam kDim Dimension size (usually hidden size) + */ +template +inline constexpr bool is_config_supported() { + if (!std::is_same_v && !std::is_same_v) return false; + if (kDim <= 256) { + return (kDim == 64 || kDim == 128 || kDim == 256); + } else { + return (kDim % 256 == 0 && kDim <= 8192); + } +} + +/** + * \brief Determine whether to use cta norm based on dimension size. + * TL;DR: use warp norm for dim <= 256, cta norm otherwise. + * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size (usually hidden size) + * \note This function assumes that the configuration is supported. + * \see `is_config_supported` + */ +template +inline constexpr bool should_use_cta() { + static_assert(is_config_supported(), "Unsupported norm configuration"); + return kDim > 256; +} + +/** + * \brief Get the number of threads per CTA for cta norm. 
+ * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size (usually hidden size) + * \return Number of threads per CTA + */ +template +inline constexpr uint32_t get_cta_threads() { + static_assert(should_use_cta()); + return (kDim / 256) * device::kWarpThreads; +} + +} // namespace host::norm + +namespace device::norm { + +namespace details { + +template +SGL_DEVICE AlignedVector apply_norm_impl( + const AlignedVector input, + const AlignedVector weight, + const float eps, + [[maybe_unused]] float* smem_buffer, + [[maybe_unused]] uint32_t num_warps) { + float sum_of_squares = 0.0f; + +#pragma unroll + for (auto i = 0u; i < N; ++i) { + const auto fp32_input = cast(input[i]); + sum_of_squares += fp32_input.x * fp32_input.x; + sum_of_squares += fp32_input.y * fp32_input.y; + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + float norm_factor; + if constexpr (kUseCTA) { + // need to synchronize across the cta + const auto warp_id = threadIdx.x / kWarpThreads; + smem_buffer[warp_id] = sum_of_squares; + __syncthreads(); + // use the first warp to reduce + if (warp_id == 0) { + const auto tx = threadIdx.x; + const auto local_sum = tx < num_warps ? smem_buffer[tx] : 0.0f; + sum_of_squares = warp::reduce_sum(local_sum); + smem_buffer[32] = math::rsqrt(sum_of_squares / kDim + eps); + } + __syncthreads(); + norm_factor = smem_buffer[32]; + } else { + norm_factor = math::rsqrt(sum_of_squares / kDim + eps); + } + + AlignedVector output; + +#pragma unroll + for (auto i = 0u; i < N; ++i) { + const auto fp32_input = cast(input[i]); + const auto fp32_weight = cast(weight[i]); + output[i] = cast({ + fp32_input.x * norm_factor * fp32_weight.x, + fp32_input.y * norm_factor * fp32_weight.y, + }); + } + + return output; +} + +} // namespace details + +/** + * \brief Apply norm using warp-level implementation. 
+ * \tparam kDim Dimension size + * \tparam T Element type (fp16_t or bf16_t) + * \param input Input vector + * \param weight Weight vector + * \param eps Epsilon value for numerical stability + * \return Normalized output vector + */ +template +SGL_DEVICE T apply_norm_warp(const T& input, const T& weight, float eps) { + static_assert(kDim <= 256, "Warp norm only supports dim <= 256"); + return details::apply_norm_impl(input, weight, eps, nullptr, 0); +} + +/** + * \brief Apply norm using CTA-level implementation. + * \tparam kDim Dimension size + * \tparam T Element type (fp16_t or bf16_t) + * \param input Input vector + * \param weight Weight vector + * \param eps Epsilon value for numerical stability + * \param smem Shared memory buffer + * \param num_warps Number of warps in the CTA + * \return Normalized output vector + */ +template +SGL_DEVICE T apply_norm_cta( + const T& input, const T& weight, float eps, float* smem, uint32_t num_warps = blockDim.x / kWarpThreads) { + static_assert(kDim > 256, "CTA norm only supports dim > 256"); + return details::apply_norm_impl(input, weight, eps, smem, num_warps); +} + +/** + * \brief Storage type for norm operation. + * For warp norm, the storage size depends on kDim. + * For cta norm, the storage size is fixed to 16B. + * We will also pack the input 16-bit floats into 32-bit types + * for faster CUDA core operations. + * + * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size + */ +template +using StorageType = std::conditional_t< // storage type + (kDim > 256), // whether to use cta norm + AlignedVector, 4>, // cta norm storage, fixed to 16B + AlignedVector, kDim / (2 * kWarpThreads)> // warp norm storage + >; + +/** + * \brief Minimum shared memory size (in float elements, not bytes) required for cta norm. 
+ */ +inline constexpr uint32_t kSmemBufferSize = 33; + +} // namespace device::norm diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/math.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/math.cuh new file mode 100644 index 0000000000..4f9ac48141 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/math.cuh @@ -0,0 +1,71 @@ +/// \file math.cuh +/// \brief Device-side math helper functions and constants. +/// +/// Provides type-generic wrappers around CUDA math intrinsics by +/// dispatching through `dtype_trait`. All functions are forced-inline +/// device functions. + +#pragma once +#include + +#include + +namespace device::math { + +/// \brief Constant: log2(e) +inline constexpr float log2e = 1.44269504088896340736f; +/// \brief Constant: ln(2) +inline constexpr float loge2 = 0.693147180559945309417f; +/// \brief Maximum representable value for FP8 E4M3 format. +inline constexpr float FP8_E4M3_MAX = 448.0f; +static_assert(log2e * loge2 == 1.0f, "log2e * loge2 must be 1"); + +/// \brief Returns the larger of `a` and `b`. +template +SGL_DEVICE T max(T a, T b) { + return dtype_trait::max(a, b); +} + +/// \brief Returns the smaller of `a` and `b`. +template +SGL_DEVICE T min(T a, T b) { + return dtype_trait::min(a, b); +} + +/// \brief Returns the absolute value of `a`. +template +SGL_DEVICE T abs(T a) { + return dtype_trait::abs(a); +} + +/// \brief Returns the square root of `a`. +template +SGL_DEVICE T sqrt(T a) { + return dtype_trait::sqrt(a); +} + +/// \brief Returns the reciprocal square root of `a` (i.e. 1 / sqrt(a)). +template +SGL_DEVICE T rsqrt(T a) { + return dtype_trait::rsqrt(a); +} + +/// \brief Returns e^a. +template +SGL_DEVICE T exp(T a) { + return dtype_trait::exp(a); +} + +/// \brief Returns sin(a). +template +SGL_DEVICE T sin(T a) { + return dtype_trait::sin(a); +} + +/// \brief Returns cos(a). 
+template +SGL_DEVICE T cos(T a) { + return dtype_trait::cos(a); +} + +} // namespace device::math diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/runtime.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/runtime.cuh new file mode 100644 index 0000000000..4ea722a3fe --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/runtime.cuh @@ -0,0 +1,86 @@ +/// \file runtime.cuh +/// \brief Host-side CUDA runtime query helpers. +/// +/// Thin wrappers around CUDA occupancy and device-property APIs with +/// automatic error checking via `RuntimeDeviceCheck`. + +#pragma once + +#include + +#include +#include +#ifndef USE_ROCM +#include +#else +#include +#ifndef cudaOccupancyMaxActiveBlocksPerMultiprocessor +#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor +#endif +#ifndef cudaDeviceGetAttribute +#define cudaDeviceGetAttribute hipDeviceGetAttribute +#endif +#ifndef cudaDevAttrMultiProcessorCount +#define cudaDevAttrMultiProcessorCount hipDeviceAttributeMultiprocessorCount +#endif +#ifndef cudaDevAttrComputeCapabilityMajor +#define cudaDevAttrComputeCapabilityMajor hipDeviceAttributeComputeCapabilityMajor +#endif +#ifndef cudaRuntimeGetVersion +#define cudaRuntimeGetVersion hipRuntimeGetVersion +#endif +#ifndef cudaOccupancyAvailableDynamicSMemPerBlock +inline hipError_t +cudaOccupancyAvailableDynamicSMemPerBlock(std::size_t* smem, const void* func, int num_blocks, int block_size) { + // HIP does not expose this directly; return the device's max shared memory per block as an upper-bound estimate (func, num_blocks and block_size are ignored) + hipDeviceProp_t prop; + int device; + hipGetDevice(&device); + hipGetDeviceProperties(&prop, device); + *smem = prop.sharedMemPerBlock; + return hipSuccess; +} +#endif +#endif + +namespace host::runtime { + +// Return the maximum number of active blocks per SM for the given kernel +template +inline auto get_blocks_per_sm(T&& kernel, int32_t block_dim, std::size_t dynamic_smem = 0) -> uint32_t { + int num_blocks_per_sm = 0; 
+ RuntimeDeviceCheck( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, block_dim, dynamic_smem)); + return static_cast(num_blocks_per_sm); +} + +// Return the number of SMs for the given device +inline auto get_sm_count(int device_id) -> uint32_t { + int sm_count; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id)); + return static_cast(sm_count); +} + +// Return the Major compute capability for the given device +inline auto get_cc_major(int device_id) -> int { + int cc_major; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device_id)); + return cc_major; +} + +// Return the runtime version +inline auto get_runtime_version() -> int { + int runtime_version; + RuntimeDeviceCheck(cudaRuntimeGetVersion(&runtime_version)); + return runtime_version; +} + +// Return the maximum dynamic shared memory per block for the given kernel +template +inline auto get_available_dynamic_smem_per_block(T&& kernel, int num_blocks, int block_size) -> std::size_t { + std::size_t smem_size; + RuntimeDeviceCheck(cudaOccupancyAvailableDynamicSMemPerBlock(&smem_size, kernel, num_blocks, block_size)); + return smem_size; +} + +} // namespace host::runtime diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/scalar_type.hpp b/lightllm/third_party/sglang_jit/include/sgl_kernel/scalar_type.hpp new file mode 100644 index 0000000000..d229d3a975 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/scalar_type.hpp @@ -0,0 +1,334 @@ +#pragma once + +#include +#include +#ifndef __CUDACC__ +#include +#endif + +namespace host { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). 
+// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. +// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType( + uint8_t exponent, + uint8_t mantissa, + bool signed_, + int32_t bias, + bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { + assert(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { + assert(nan_repr < NAN_REPR_ID_MAX); + assert(mantissa > 0 && exponent > 0); + assert(nan_repr != NAN_IEEE_754); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. 
has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { return acc + member_id_field_width(); }, 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits}; + }; + return reduce_members(or_and_advance, 
std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { + return signed_; + } + constexpr bool is_integer() const { + return exponent == 0; + } + constexpr bool is_floating_point() const { + return exponent > 0; + } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { + return bias != 0; + } + +#ifndef __CUDACC__ + private: + double _floating_point_max() const { + assert(mantissa <= 52 && exponent <= 11); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + assert(exponent < 11); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume 
the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + assert(size_bits() < 64 || (size_bits() == 64 && is_signed())); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + assert(is_signed()); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + assert(!is_signed() || size_bits() <= 64); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); + } +#endif // __CUDACC__ + + public: + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = + "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE2M1f = 
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +} // namespace host diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/source_location.h b/lightllm/third_party/sglang_jit/include/sgl_kernel/source_location.h new file mode 100644 index 0000000000..7c9fd52131 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/source_location.h @@ -0,0 +1,40 @@ +/// \file source_location.h +/// \brief Portable 
`source_location` wrapper. +/// +/// Uses `std::source_location` when available (C++20), otherwise falls +/// back to a minimal stub that returns empty/zero values. + +#pragma once +#include + +/// NOTE: fallback to a minimal source_location implementation +#if defined(__cpp_lib_source_location) +#include + +using source_location_t = std::source_location; + +#else + +struct source_location_fallback { + public: + static constexpr source_location_fallback current() noexcept { + return source_location_fallback{}; + } + constexpr source_location_fallback() noexcept = default; + constexpr unsigned line() const noexcept { + return 0; + } + constexpr unsigned column() const noexcept { + return 0; + } + constexpr const char* file_name() const noexcept { + return ""; + } + constexpr const char* function_name() const noexcept { + return ""; + } +}; + +using source_location_t = source_location_fallback; + +#endif diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/tensor.h b/lightllm/third_party/sglang_jit/include/sgl_kernel/tensor.h new file mode 100644 index 0000000000..1ae9233a61 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/tensor.h @@ -0,0 +1,605 @@ +/// \file tensor.h +/// \brief Tensor validation and symbolic matching utilities. +/// +/// Provides the `TensorMatcher` fluent API for validating tensor shapes, +/// strides, dtypes, and devices at kernel entry points, along with +/// `SymbolicSize`, `SymbolicDType`, and `SymbolicDevice` for capturing +/// and cross-checking tensor metadata across multiple tensors. +/// +/// See the "Tensor Checking" section in the JIT kernel dev guide for +/// usage examples. 
+ +#pragma once +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include +#elif defined(__HIPCC__) +#include +#endif + +namespace host { + +namespace details { + +inline constexpr auto kAnyDeviceID = -1; +inline constexpr auto kAnySize = static_cast(-1); +inline constexpr auto kNullSize = static_cast(-1); +inline constexpr auto kNullDType = static_cast(18u); +inline constexpr auto kNullDevice = static_cast(-1); + +struct SizeRef; +struct DTypeRef; +struct DeviceRef; + +template +struct _dtype_trait {}; + +template +struct _dtype_trait { + inline static constexpr DLDataType value = { + .code = std::is_signed_v ? DLDataTypeCode::kDLInt : DLDataTypeCode::kDLUInt, + .bits = static_cast(sizeof(T) * 8), + .lanes = 1}; +}; + +template +struct _dtype_trait { + inline static constexpr DLDataType value = { + .code = DLDataTypeCode::kDLFloat, .bits = static_cast(sizeof(T) * 8), .lanes = 1}; +}; + +#ifdef __CUDACC__ +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat, .bits = 16, .lanes = 1}; +}; +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLBfloat, .bits = 16, .lanes = 1}; +}; +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat8_e4m3fn, .bits = 8, .lanes = 1}; +}; +#elif defined(__HIPCC__) +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat, .bits = 16, .lanes = 1}; +}; +template <> +struct _dtype_trait { + inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLBfloat, .bits = 16, .lanes = 1}; +}; +#endif + +template +struct _device_trait { + inline static constexpr DLDevice value = {.device_type = Code, .device_id = kAnyDeviceID}; +}; + +template +inline 
constexpr auto kDTypeList = std::array{_dtype_trait::value...}; + +template +inline constexpr auto kDeviceList = std::array{_device_trait::value...}; + +template +struct PrintAbleSpan { + explicit PrintAbleSpan(std::span data) : data(data) {} + std::span data; +}; + +// define DLDataType comparison and printing in root namespace +inline constexpr auto kDeviceStringMap = [] { + constexpr auto map = std::array, 16>{ + std::pair{DLDeviceType::kDLCPU, "cpu"}, + std::pair{DLDeviceType::kDLCUDA, "cuda"}, + std::pair{DLDeviceType::kDLCUDAHost, "cuda_host"}, + std::pair{DLDeviceType::kDLOpenCL, "opencl"}, + std::pair{DLDeviceType::kDLVulkan, "vulkan"}, + std::pair{DLDeviceType::kDLMetal, "metal"}, + std::pair{DLDeviceType::kDLVPI, "vpi"}, + std::pair{DLDeviceType::kDLROCM, "rocm"}, + std::pair{DLDeviceType::kDLROCMHost, "rocm_host"}, + std::pair{DLDeviceType::kDLExtDev, "ext_dev"}, + std::pair{DLDeviceType::kDLCUDAManaged, "cuda_managed"}, + std::pair{DLDeviceType::kDLOneAPI, "oneapi"}, + std::pair{DLDeviceType::kDLWebGPU, "webgpu"}, + std::pair{DLDeviceType::kDLHexagon, "hexagon"}, + std::pair{DLDeviceType::kDLMAIA, "maia"}, + std::pair{DLDeviceType::kDLTrn, "trn"}, + }; + constexpr auto max_type = stdr::max(map | stdv::keys); + auto result = std::array{}; + for (const auto& [code, name] : map) { + result[static_cast(code)] = name; + } + return result; +}(); + +struct PrintableDevice { + DLDevice device; +}; + +inline auto& operator<<(std::ostream& os, DLDevice device) { + const auto& mapping = kDeviceStringMap; + const auto entry = static_cast(device.device_type); + RuntimeCheck(entry < mapping.size()); + const auto name = mapping[entry]; + RuntimeCheck(!name.empty(), "Unknown device: ", int(device.device_type)); + os << name; + if (device.device_id != kAnyDeviceID && device.device_type != DLDeviceType::kDLCPU) { + os << ":" << device.device_id; + } + return os; +} + +inline auto& operator<<(std::ostream& os, PrintableDevice pd) { + return os << pd.device; +} + +template 
+inline auto& operator<<(std::ostream& os, PrintAbleSpan span) { + os << "["; + for (const auto i : irange(span.data.size())) { + if (i > 0) { + os << ", "; + } + os << span.data[i]; + } + os << "]"; + return os; +} + +} // namespace details + +/// \brief Check whether `dtype` matches the DLDataType for C++ type `T`. +template +inline bool is_type(DLDataType dtype) { + return dtype == details::_dtype_trait::value; +} + +/** + * \brief A symbolic dimension size that can be bound once and + * verified across multiple tensors. + * + * Create with an optional annotation string for error messages: + * \code + * auto N = SymbolicSize{"num_tokens"}; + * \endcode + * + * Call `verify()` during tensor matching to either bind the first + * observed value or check subsequent values match. Call `unwrap()` + * to retrieve the bound value (panics if unset). + */ +struct SymbolicSize { + public: + SymbolicSize(std::string_view annotation = {}) : m_value(details::kNullSize), m_annotation(annotation) {} + SymbolicSize(const SymbolicSize&) = delete; + SymbolicSize& operator=(const SymbolicSize&) = delete; + + auto get_name() const -> std::string_view { + return m_annotation; + } + + auto set_value(int64_t value) -> void { + RuntimeCheck(!this->has_value(), "Size value already set"); + m_value = value; + } + + auto has_value() const -> bool { + return m_value != details::kNullSize; + } + + auto get_value() const -> std::optional { + return this->has_value() ? 
std::optional{m_value} : std::nullopt; + } + + auto unwrap(DebugInfo info = {}) const -> int64_t { + RuntimeCheck(info, this->has_value(), "Size value is not set"); + return m_value; + } + + auto verify(int64_t value, const char* prefix, int64_t dim) -> void { + if (this->has_value()) { + if (m_value != value) { + [[unlikely]]; + Panic("Size mismatch for ", m_name_str(prefix, dim), ": expected ", m_value, " but got ", value); + } + } else { + this->set_value(value); + } + } + + auto value_or_name(const char* prefix, int64_t dim) const -> std::string { + if (const auto value = this->get_value()) { + return std::to_string(*value); + } else { + return m_name_str(prefix, dim); + } + } + + private: + auto m_name_str(const char* prefix, int64_t dim) const -> std::string { + std::ostringstream os; + os << prefix << '#' << dim; + if (!m_annotation.empty()) os << "('" << m_annotation << "')"; + return std::move(os).str(); + } + + std::int64_t m_value; + std::string_view m_annotation; +}; + +inline auto operator==(DLDevice lhs, DLDevice rhs) -> bool { + return lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id; +} + +/** + * \brief A symbolic data type that can be constrained and verified. + * + * Optionally restrict allowed types via `set_options()`. + * Use `verify()` to bind/check the dtype, and `unwrap()` to retrieve it. + */ +struct SymbolicDType { + public: + SymbolicDType() : m_value({details::kNullDType, 0, 0}) {} + SymbolicDType(const SymbolicDType&) = delete; + SymbolicDType& operator=(const SymbolicDType&) = delete; + + auto set_value(DLDataType value) -> void { + RuntimeCheck(!this->has_value(), "Dtype value already set"); + RuntimeCheck( + m_check(value), "Dtype value [", value, "] not in the allowed options: ", details::PrintAbleSpan{m_options}); + m_value = value; + } + + auto has_value() const -> bool { + return m_value.code != details::kNullDType; + } + + auto get_value() const -> std::optional { + return this->has_value() ? 
std::optional{m_value} : std::nullopt; + } + + auto unwrap(DebugInfo info = {}) const -> DLDataType { + RuntimeCheck(info, this->has_value(), "Dtype value is not set"); + return m_value; + } + + auto set_options(std::span options) -> void { + m_options = options; + } + + template + auto set_options() -> void { + m_options = details::kDTypeList; + } + + auto verify(DLDataType dtype) -> void { + if (this->has_value()) { + RuntimeCheck(m_value == dtype, "DType mismatch: expected ", m_value, " but got ", dtype); + } else { + this->set_value(dtype); + } + } + + template + auto is_type() const -> bool { + return ::host::is_type(m_value); + } + + private: + auto m_check(DLDataType value) const -> bool { + return stdr::empty(m_options) || (stdr::find(m_options, value) != stdr::end(m_options)); + } + + std::span m_options; + DLDataType m_value; +}; + +/** + * \brief A symbolic device that can be constrained and verified. + * + * Optionally restrict allowed device types via + * `set_options()`. The device id can be wildcarded. + */ +struct SymbolicDevice { + public: + SymbolicDevice() : m_value({details::kNullDevice, details::kAnyDeviceID}) {} + SymbolicDevice(const SymbolicDevice&) = delete; + SymbolicDevice& operator=(const SymbolicDevice&) = delete; + + auto set_value(DLDevice value) -> void { + RuntimeCheck(!this->has_value(), "Device value already set"); + RuntimeCheck( + m_check(value), + "Device value [", + details::PrintableDevice{value}, + "] not in the allowed options: ", + details::PrintAbleSpan{m_options}); + m_value = value; + } + + auto has_value() const -> bool { + return m_value.device_type != details::kNullDevice; + } + + auto get_value() const -> std::optional { + return this->has_value() ? 
std::optional{m_value} : std::nullopt; + } + + auto unwrap(DebugInfo info = {}) const -> DLDevice { + RuntimeCheck(info, this->has_value(), "Device value is not set"); + return m_value; + } + + auto set_options(std::span options) -> void { + m_options = options; + } + + template + auto set_options() -> void { + m_options = details::kDeviceList; + } + + auto verify(DLDevice device) -> void { + if (this->has_value()) { + RuntimeCheck( + m_value == device, + "Device mismatch: expected ", + details::PrintableDevice{m_value}, + " but got ", + details::PrintableDevice{device}); + } else { + this->set_value(device); + } + } + + private: + auto m_check(DLDevice value) const -> bool { + return stdr::empty(m_options) || (stdr::any_of(m_options, [value](const DLDevice& opt) { + // device type must exactly match + if (opt.device_type != value.device_type) return false; + // device id can be wildcarded + return opt.device_id == details::kAnyDeviceID || opt.device_id == value.device_id; + })); + } + + std::span m_options; + DLDevice m_value; +}; + +namespace details { + +template +struct BaseRef { + public: + BaseRef(const BaseRef&) = delete; + BaseRef& operator=(const BaseRef&) = delete; + + auto operator->() const -> T* { + return m_ref; + } + auto operator*() const -> T& { + return *m_ref; + } + auto rebind(T& other) -> void { + m_ref = &other; + } + + explicit BaseRef() : m_ref(&m_cache), m_cache() {} + BaseRef(T& size) : m_ref(&size), m_cache() {} + + private: + T* m_ref; + T m_cache; +}; + +struct SizeRef : BaseRef { + using BaseRef::BaseRef; + SizeRef(int64_t value) { + if (value != kAnySize) { + (**this).set_value(value); + } else { + // otherwise, we can match any size + } + } +}; + +struct DTypeRef : BaseRef { + using BaseRef::BaseRef; + DTypeRef(DLDataType options) { + (**this).set_value(options); + } + DTypeRef(std::initializer_list options) { + (**this).set_options(options); + } + DTypeRef(std::span options) { + (**this).set_options(options); + } +}; + +struct 
DeviceRef : BaseRef { + using BaseRef::BaseRef; + DeviceRef(DLDevice options) { + (**this).set_value(options); + } + DeviceRef(std::initializer_list options) { + (**this).set_options(options); + } + DeviceRef(std::span options) { + (**this).set_options(options); + } +}; + +} // namespace details + +/** + * \brief Fluent API for validating tensor shape, strides, dtype, and device. + * + * Construct with the expected shape (using `SymbolicSize` or literal + * integers), chain `.with_strides()`, `.with_dtype<...>()`, and + * `.with_device<...>()`, then call `.verify(tensor)`. + * + * Example: + * \code + * auto N = SymbolicSize{"N"}; + * TensorMatcher({N, 128}) + * .with_dtype() + * .with_device() + * .verify(input_tensor); + * \endcode + * + * \note `TensorMatcher` is a move-only temporary. Do not store in a variable. + */ +struct TensorMatcher { + private: + using SizeRef = details::SizeRef; + using DTypeRef = details::DTypeRef; + using DeviceRef = details::DeviceRef; + + public: + TensorMatcher(const TensorMatcher&) = delete; + TensorMatcher& operator=(const TensorMatcher&) = delete; + + explicit TensorMatcher(std::initializer_list shape) : m_shape(shape), m_strides(), m_dtype() {} + + auto with_strides(std::initializer_list strides) && -> TensorMatcher&& { + // no partial update allowed + RuntimeCheck(m_strides.size() == 0, "Strides already specified"); + RuntimeCheck(m_shape.size() == strides.size(), "Strides size must match shape size"); + m_strides = strides; + return std::move(*this); + } + + template + auto with_dtype(DTypeRef&& dtype) && -> TensorMatcher&& { + m_init_dtype(); + m_dtype.rebind(*dtype); + m_dtype->set_options(); + return std::move(*this); + } + + template + auto with_dtype() && -> TensorMatcher&& { + static_assert(sizeof...(Ts) > 0, "At least one dtype option must be specified"); + m_init_dtype(); + m_dtype->set_options(); + return std::move(*this); + } + + template + auto with_device(DeviceRef&& device) && -> TensorMatcher&& { + 
m_init_device(); + m_device.rebind(*device); + m_device->set_options(); + return std::move(*this); + } + + template + auto with_device() && -> TensorMatcher&& { + static_assert(sizeof...(Codes) > 0, "At least one device option must be specified"); + m_init_device(); + m_device->set_options(); + return std::move(*this); + } + + // once we start verification, we cannot modify anymore + auto verify(tvm::ffi::TensorView view, DebugInfo info = {}) const&& -> const TensorMatcher&& { + try { + m_verify_impl(view); + } catch (PanicError& e) { + auto oss = std::ostringstream{}; + oss << "Tensor match failed for "; + s_print_tensor(oss, view); + oss << " at " << info.file_name() << ":" << info.line() << "\n- Root cause: " << e.root_cause(); + throw PanicError(std::move(oss).str()); + } + return std::move(*this); + } + + private: + static auto s_print_tensor(std::ostringstream& oss, tvm::ffi::TensorView view) -> void { + oss << "Tensor<"; + int64_t dim = 0; + for (const auto& size : view.shape()) { + if (dim++ > 0) oss << ", "; + oss << size; + } + oss << ">[strides=<"; + dim = 0; + for (const auto& stride : view.strides()) { + if (dim++ > 0) { + oss << ", "; + } + oss << stride; + } + oss << ">, dtype=" << view.dtype(); + oss << ", device=" << details::PrintableDevice{view.device()} << "]"; + } + + auto m_verify_impl(tvm::ffi::TensorView view) const -> void { + const auto dim = static_cast(view.dim()); + RuntimeCheck(dim == m_shape.size(), "Tensor dimension mismatch: expected ", m_shape.size(), " but got ", dim); + for (const auto i : irange(dim)) { + m_shape[i]->verify(view.size(i), "shape", i); + } + if (m_has_strides()) { + for (const auto i : irange(dim)) { + if (view.size(i) != 1 || !m_strides[i]->has_value()) { + // skip stride check for size 1 dimension + m_strides[i]->verify(view.stride(i), "stride", i); + } + } + } else { + RuntimeCheck(view.is_contiguous(), "Tensor is not contiguous as expected"); + } + // since we may double verify, we will force to check + 
m_dtype->verify(view.dtype()); + m_device->verify(view.device()); + } + + auto m_init_dtype() -> void { + RuntimeCheck(!m_has_dtype, "DType already specified"); + m_has_dtype = true; + } + + auto m_init_device() -> void { + RuntimeCheck(!m_has_device, "Device already specified"); + m_has_device = true; + } + + auto m_has_strides() const -> bool { + return !m_strides.empty(); + } + + std::span m_shape; + std::span m_strides; + DTypeRef m_dtype; + DeviceRef m_device; + bool m_has_dtype = false; + bool m_has_device = false; +}; + +} // namespace host diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/tile.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/tile.cuh new file mode 100644 index 0000000000..1adc821706 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/tile.cuh @@ -0,0 +1,62 @@ +/// \file tile.cuh +/// \brief Tiled memory access helpers for coalesced global memory I/O. +/// +/// `tile::Memory` represents a contiguous memory region where multiple +/// threads cooperatively load/store elements. The three factory methods +/// determine the thread group: +/// - `thread()` - single thread (no tiling). +/// - `warp()` - all threads in a warp cooperate. +/// - `cta()` - all threads in the CTA cooperate. + +#pragma once +#include + +#include + +namespace device::tile { + +/** + * \brief Represents a contiguous memory region for cooperative tiled access. + * + * Each instance is parameterized by an element type `T` and bound to a + * specific thread id (`tid`) within a group of `tsize` threads. + * + * \tparam T The storage element type (e.g. `AlignedVector, 4>`). + */ +template +struct Memory { + public: + SGL_DEVICE constexpr Memory(uint32_t tid, uint32_t tsize) : tid(tid), tsize(tsize) {} + /// \brief Create a Memory accessor for a single thread (no cooperation). + SGL_DEVICE static constexpr Memory thread() { + return Memory{0, 1}; + } + /// \brief Create a Memory accessor distributed across warp threads. 
+ SGL_DEVICE static Memory warp(int warp_threads = kWarpThreads) { + return Memory{static_cast(threadIdx.x % warp_threads), static_cast(warp_threads)}; + } + /// \brief Create a Memory accessor distributed across all CTA threads. + SGL_DEVICE static Memory cta(int cta_threads = blockDim.x) { + return Memory{static_cast(threadIdx.x), static_cast(cta_threads)}; + } + /// \brief Load one element from `ptr` at the position assigned to this thread. + /// \param ptr Base pointer (cast to `const T*`). + /// \param offset Optional tile offset (multiplied by `tsize`). + SGL_DEVICE T load(const void* ptr, int64_t offset = 0) const { + return static_cast(ptr)[tid + offset * tsize]; + } + /// \brief Store one element to `ptr` at the position assigned to this thread. + SGL_DEVICE void store(void* ptr, T val, int64_t offset = 0) const { + static_cast(ptr)[tid + offset * tsize] = val; + } + /// \brief Check whether this thread's element index is within bounds. + SGL_DEVICE bool in_bound(int64_t element_count, int64_t offset = 0) const { + return tid + offset * tsize < element_count; + } + + private: + uint32_t tid; + uint32_t tsize; +}; + +} // namespace device::tile diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/type.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/type.cuh new file mode 100644 index 0000000000..a7a5346196 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/type.cuh @@ -0,0 +1,120 @@ +/// \file type.cuh +/// \brief Dtype trait system for CUDA scalar/packed types. +/// +/// `dtype_trait` provides per-type metadata: packed type alias, +/// conversion functions (`from`), and unary/binary math operations. +/// Use `device::cast(from_value)` for type conversion on device. +/// +/// Registered types: +/// | Scalar | Packed (x2) | Notes | +/// |-----------|-------------|-------------------------------| +/// | `fp32_t` | `fp32x2_t` | Full math ops (abs,sqrt,...) 
| +/// | `fp16_t` | `fp16x2_t` | Conversion only | +/// | `bf16_t` | `bf16x2_t` | Conversion only | +/// | `fp32x2_t`| `fp32x4_t` | Packed float2 <-> half2/bf162 | + +#pragma once +#include + +template +struct dtype_trait {}; + +#define SGL_REGISTER_DTYPE_TRAIT(TYPE, PACK2, ...) \ + template <> \ + struct dtype_trait { \ + using self_t = TYPE; \ + using packed_t = PACK2; \ + template \ + SGL_DEVICE static self_t from(const S& value) { \ + return static_cast(value); \ + } \ + __VA_ARGS__ \ + } + +#define SGL_REGISTER_TYPE_END static_assert(true) + +#define SGL_REGISTER_FROM_FUNCTION(FROM, FN) \ + SGL_DEVICE static self_t from(const FROM& x) { \ + return FN(x); \ + } \ + static_assert(true) + +#define SGL_REGISTER_UNARY_FUNCTION(NAME, FN) \ + SGL_DEVICE static self_t NAME(const self_t& x) { \ + return FN(x); \ + } \ + static_assert(true) + +#define SGL_REGISTER_BINARY_FUNCTION(NAME, FN) \ + SGL_DEVICE static self_t NAME(const self_t& x, const self_t& y) { \ + return FN(x, y); \ + } \ + static_assert(true) + +SGL_REGISTER_DTYPE_TRAIT( + fp32_t, fp32x2_t, SGL_REGISTER_TYPE_END; // + SGL_REGISTER_FROM_FUNCTION(fp16_t, __half2float); + SGL_REGISTER_FROM_FUNCTION(bf16_t, __bfloat162float); + SGL_REGISTER_UNARY_FUNCTION(abs, fabsf); + SGL_REGISTER_UNARY_FUNCTION(sqrt, sqrtf); + SGL_REGISTER_UNARY_FUNCTION(rsqrt, rsqrtf); + SGL_REGISTER_UNARY_FUNCTION(exp, expf); + SGL_REGISTER_UNARY_FUNCTION(sin, sinf); + SGL_REGISTER_UNARY_FUNCTION(cos, cosf); + SGL_REGISTER_BINARY_FUNCTION(max, fmaxf); + SGL_REGISTER_BINARY_FUNCTION(min, fminf);); +SGL_REGISTER_DTYPE_TRAIT(fp16_t, fp16x2_t); +SGL_REGISTER_DTYPE_TRAIT(bf16_t, bf16x2_t); + +/// TODO: Add ROCM implementation +SGL_REGISTER_DTYPE_TRAIT( + fp32x2_t, fp32x4_t, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp16x2_t, __half22float2); + SGL_REGISTER_FROM_FUNCTION(bf16x2_t, __bfloat1622float2);); + +SGL_REGISTER_DTYPE_TRAIT( + fp16x2_t, void, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp32x2_t, __float22half2_rn);); + 
+SGL_REGISTER_DTYPE_TRAIT( + bf16x2_t, void, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp32x2_t, __float22bfloat162_rn);); + +#ifndef USE_ROCM +SGL_REGISTER_DTYPE_TRAIT(fp8_e4m3_t, fp8x2_e4m3_t); +#endif + +#undef SGL_REGISTER_DTYPE_TRAIT +#undef SGL_REGISTER_FROM_FUNCTION + +/// \brief Alias: the packed (x2) type for `T`. +template +using packed_t = typename dtype_trait::packed_t; + +namespace device { + +/** + * \brief Cast a value from type `From` to type `To` on device. + * + * Dispatches through `dtype_trait::from()`, which uses the appropriate + * CUDA intrinsic (e.g. `__half2float`, `__float22half2_rn`). + */ +template +SGL_DEVICE To cast(const From& value) { + return dtype_trait::from(value); +} + +} // namespace device + +// --------------------------------------------------------------------------- +// FP8 max clamp value — platform-dependent +// CUDA (e4m3fn): 448.0f +// AMD FNUZ (e4m3fnuz): 224.0f +// AMD E4M3 (e4m3fn): 448.0f +// --------------------------------------------------------------------------- +#ifndef USE_ROCM +constexpr float kFP8E4M3Max = 448.0f; +#else // USE_ROCM +#if HIP_FP8_TYPE_FNUZ +constexpr float kFP8E4M3Max = 224.0f; +#else // HIP_FP8_TYPE_E4M3 +constexpr float kFP8E4M3Max = 448.0f; +#endif // HIP_FP8_TYPE_FNUZ +#endif // USE_ROCM diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.cuh new file mode 100644 index 0000000000..2dd6f3dc93 --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.cuh @@ -0,0 +1,333 @@ +/// \file utils.cuh +/// \brief Core CUDA/device utilities: type aliases, PDL helpers, +/// typed pointer access, kernel launch wrapper, and error checking. +/// +/// This header is included (directly or transitively) by nearly every +/// JIT kernel. It provides: +/// - Scalar/packed type aliases (`fp16_t`, `bf16_t`, `fp8_e4m3_t`, ...). +/// - `SGL_DEVICE` macro (forced-inline device function qualifier). 
+/// - `kWarpThreads` constant (32). +/// - PDL (Programmatic Dependent Launch) helpers for Hopper (sm_90+). +/// - Typed `load_as` / `store_as` for void-pointer access. +/// - `pointer::offset` for safe void-pointer arithmetic. +/// - `host::LaunchKernel` - kernel launcher with optional PDL. +/// - `host::RuntimeDeviceCheck` - CUDA error checking. + +#pragma once + +#include + +#include +#include + +#include +#include +#include +#ifndef USE_ROCM +#include +#include +#include +#include +#else +#include +#include +#include +#ifndef __grid_constant__ +#define __grid_constant__ +#endif +using cudaError_t = hipError_t; +using cudaStream_t = hipStream_t; +using cudaLaunchConfig_t = hipLaunchConfig_t; +using cudaLaunchAttribute = hipLaunchAttribute; +inline constexpr auto cudaSuccess = hipSuccess; +#define cudaStreamPerThread hipStreamPerThread +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaLaunchKernel hipLaunchKernel +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#endif + +#ifndef USE_ROCM +using fp32_t = float; +using fp16_t = __half; +using bf16_t = __nv_bfloat16; +using fp8_e4m3_t = __nv_fp8_e4m3; +using fp8_e5m2_t = __nv_fp8_e5m2; + +using fp32x2_t = float2; +using fp16x2_t = __half2; +using bf16x2_t = __nv_bfloat162; +using fp8x2_e4m3_t = __nv_fp8x2_e4m3; +using fp8x2_e5m2_t = __nv_fp8x2_e5m2; + +using fp32x4_t = float4; +#else +using fp32_t = float; +using fp16_t = __half; +using bf16_t = __hip_bfloat16; +using fp8_e4m3_t = uint8_t; +using fp8_e5m2_t = uint8_t; +using fp32x2_t = float2; +using fp16x2_t = half2; +using bf16x2_t = __hip_bfloat162; +using fp8x2_e4m3_t = uint16_t; +using fp8x2_e5m2_t = uint16_t; +using fp32x4_t = float4; +#endif + +/* + * LDG Support + */ +#ifndef USE_ROCM +#define SGLANG_LDG(arg) __ldg(arg) +#else +#define SGLANG_LDG(arg) *(arg) +#endif + +// DLPack device type for the current 
platform +#ifndef USE_ROCM +inline constexpr auto kDLGPU = kDLCUDA; +#else +inline constexpr auto kDLGPU = kDLROCM; +#endif + +namespace device { + +/// \brief Macro: forced-inline device function qualifier. +#define SGL_DEVICE __forceinline__ __device__ + +// Architecture detection: SGL_CUDA_ARCH is injected by load_jit() and is +// available in both host and device compilation passes, whereas __CUDA_ARCH__ +// is only defined by nvcc during the device pass. +#if !defined(USE_ROCM) +#if !defined(SGL_CUDA_ARCH) +#error "SGL_CUDA_ARCH is not defined. JIT compilation must inject -DSGL_CUDA_ARCH via load_jit()." +#endif +#if defined(__CUDA_ARCH__) +static_assert( + __CUDA_ARCH__ == SGL_CUDA_ARCH, "SGL_CUDA_ARCH mismatch: injected arch flag does not match device target"); +#endif +#define SGL_ARCH_HOPPER_OR_GREATER (SGL_CUDA_ARCH >= 900) +#define SGL_ARCH_BLACKWELL_OR_GREATER ((SGL_CUDA_ARCH >= 1000) && (CUDA_VERSION >= 12090)) +#else // USE_ROCM +#define SGL_ARCH_HOPPER_OR_GREATER 0 +#define SGL_ARCH_BLACKWELL_OR_GREATER 0 +#endif + +// Maximum vector size in bytes supported by current architecture. +// Pre-Blackwell / AMD: 128-bit (16 bytes) +// Blackwell or greater: 256-bit (32 bytes) +inline constexpr std::size_t kMaxVecBytes = SGL_ARCH_BLACKWELL_OR_GREATER ? 32 : 16; + +/// \brief Number of threads per warp (always 32 on NVIDIA/AMD GPUs). +inline constexpr auto kWarpThreads = 32u; +/// \brief Full warp active mask (all 32 lanes). +#ifndef USE_ROCM +inline constexpr auto kFullMask = 0xffffffffu; +#else +inline constexpr auto kFullMask = 0xffffffffffffffffULL; +#endif + +/** + * \brief PDL (Programmatic Dependent Launch): wait for the primary kernel. + * + * On Hopper (sm_90+), inserts a `griddepcontrol.wait` instruction to + * synchronize with a preceding kernel in the same stream. On older + * architectures or ROCm this is a no-op. 
+ */ +template +SGL_DEVICE void PDLWaitPrimary() { +#if SGL_ARCH_HOPPER_OR_GREATER + if constexpr (kUsePDL) { + asm volatile("griddepcontrol.wait;" ::: "memory"); + } +#endif +} + +/** + * \brief PDL: trigger dependent (secondary) kernel launch. + * + * On Hopper (sm_90+), inserts a `griddepcontrol.launch_dependents` + * instruction. On older architectures or ROCm this is a no-op. + */ +template +SGL_DEVICE void PDLTriggerSecondary() { +#if SGL_ARCH_HOPPER_OR_GREATER + if constexpr (kUsePDL) { + asm volatile("griddepcontrol.launch_dependents;" :::); + } +#endif +} + +template +SGL_DEVICE constexpr auto div_ceil(T a, U b) { + return (a + b - 1) / b; +} + +/** + * \brief Load data with the specified type and offset from a void pointer. + * \tparam T The type to load. + * \param ptr The base pointer. + * \param offset The offset in number of elements of type T. + */ +template +SGL_DEVICE T load_as(const void* ptr, int64_t offset = 0) { + return static_cast(ptr)[offset]; +} + +/** + * \brief Store data with the specified type and offset to a void pointer. + * \tparam T The type to store. + * \param ptr The base pointer. + * \param val The value to store. + * \param offset The offset in number of elements of type T. + * \note we use type_identity_t to force the caller to explicitly specify + * the template parameter `T`, which can avoid accidentally using the wrong type. + */ +template +SGL_DEVICE void store_as(void* ptr, std::type_identity_t val, int64_t offset = 0) { + static_cast(ptr)[offset] = val; +} + +/// \brief Safe void-pointer arithmetic (byte-level by default). +namespace pointer { + +// we only allow void * pointer arithmetic for safety + +template +SGL_DEVICE auto offset(void* ptr, U... offset) -> void* { + return static_cast(ptr) + (... + offset); +} + +template +SGL_DEVICE auto offset(const void* ptr, U... offset) -> const void* { + return static_cast(ptr) + (... 
+ offset); +} + +} // namespace pointer + +} // namespace device + +namespace host { + +/** + * \brief Check the CUDA error code and panic with location info on failure. + */ +inline void RuntimeDeviceCheck(::cudaError_t error, DebugInfo location = {}) { + if (error != ::cudaSuccess) { + [[unlikely]]; + ::host::panic(location, "CUDA error: ", ::cudaGetErrorString(error)); + } +} + +/// \brief Check the last CUDA error (calls `cudaGetLastError`). +inline void RuntimeDeviceCheck(DebugInfo location = {}) { + return RuntimeDeviceCheck(::cudaGetLastError(), location); +} + +/** + * \brief Kernel launcher with automatic stream resolution and PDL support. + * + * Usage: + * \code + * host::LaunchKernel(grid, block, device) + * .enable_pdl(true) + * (my_kernel, arg1, arg2); + * \endcode + * + * The constructor resolves the CUDA stream from a `DLDevice` (via + * `TVMFFIEnvGetStream`) or accepts a raw `cudaStream_t`. The call + * operator launches the kernel and checks for errors. + */ +struct LaunchKernel { + public: + explicit LaunchKernel( + dim3 grid_dim, + dim3 block_dim, + DLDevice device, + std::size_t dynamic_shared_mem_bytes = 0, + DebugInfo location = {}) noexcept + : m_config(s_make_config(grid_dim, block_dim, resolve_device(device), dynamic_shared_mem_bytes)), + m_location(location) {} + + explicit LaunchKernel( + dim3 grid_dim, + dim3 block_dim, + cudaStream_t stream, + std::size_t dynamic_shared_mem_bytes = 0, + DebugInfo location = {}) noexcept + : m_config(s_make_config(grid_dim, block_dim, stream, dynamic_shared_mem_bytes)), m_location(location) {} + + LaunchKernel(const LaunchKernel&) = delete; + LaunchKernel& operator=(const LaunchKernel&) = delete; + + static auto resolve_device(DLDevice device) -> cudaStream_t { + return static_cast(::TVMFFIEnvGetStream(device.device_type, device.device_id)); + } + + auto enable_pdl(bool enabled = true) -> LaunchKernel& { +#ifdef USE_ROCM + (void)enabled; + m_config.numAttrs = 0; +#else + if (enabled) { + auto& attr = 
m_attrs[m_config.numAttrs++]; + attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr.val.programmaticStreamSerializationAllowed = true; + m_config.attrs = m_attrs; + } +#endif + return *this; + } + + auto enable_cluster(dim3 cluster_dim) -> LaunchKernel& { +#ifdef USE_ROCM + (void)cluster_dim; +#else + auto& attr = m_attrs[m_config.numAttrs++]; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {cluster_dim.x, cluster_dim.y, cluster_dim.z}; + m_config.attrs = m_attrs; +#endif + return *this; + } + + template + auto operator()(T&& kernel, Args&&... args) const -> void { +#ifdef USE_ROCM + hipLaunchKernelGGL( + std::forward(kernel), + m_config.gridDim, + m_config.blockDim, + m_config.dynamicSmemBytes, + m_config.stream, + std::forward(args)...); + RuntimeDeviceCheck(m_location); +#else + RuntimeDeviceCheck(::cudaLaunchKernelEx(&m_config, kernel, std::forward(args)...), m_location); +#endif + } + + private: + static auto s_make_config( // Make a config for kernel launch + dim3 grid_dim, + dim3 block_dim, + cudaStream_t stream, + std::size_t smem) -> cudaLaunchConfig_t { + auto config = ::cudaLaunchConfig_t{}; + config.gridDim = grid_dim; + config.blockDim = block_dim; + config.dynamicSmemBytes = smem; + config.stream = stream; + config.numAttrs = 0; + return config; + } + + cudaLaunchConfig_t m_config; + const DebugInfo m_location; + cudaLaunchAttribute m_attrs[2]; +}; + +} // namespace host diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.h b/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.h new file mode 100644 index 0000000000..3226f79ddc --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/utils.h @@ -0,0 +1,186 @@ +/// \file utils.h +/// \brief Host-side C++ utilities used by JIT kernel wrappers. +/// +/// Provides: +/// - `DebugInfo` - wraps `std::source_location` for error reporting. +/// - `RuntimeCheck` - runtime assertion with formatted error messages. 
+/// - `Panic` - unconditional abort with formatted error messages. +/// - `pointer::offset` - safe void-pointer arithmetic (host side). +/// - `div_ceil` - integer ceiling division. +/// - `dtype_bytes` - byte width of a `DLDataType`. +/// - `irange` - Python-style integer range for range-for loops. + +#pragma once + +// ref: https://forums.developer.nvidia.com/t/c-20s-source-location-compilation-error-when-using-nvcc-12-1/258026/3 +#ifdef __CUDACC__ +#include +#if CUDA_VERSION <= 12010 + +#pragma push_macro("__cpp_consteval") +#pragma push_macro("_NODISCARD") +#pragma push_macro("__builtin_LINE") + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wbuiltin-macro-redefined" +#define __cpp_consteval 201811L +#pragma clang diagnostic pop + +#ifdef _NODISCARD +#undef _NODISCARD +#define _NODISCARD +#endif + +#define consteval constexpr + +#include "source_location.h" + +#undef consteval +#pragma pop_macro("__cpp_consteval") +#pragma pop_macro("_NODISCARD") +#else // __CUDACC__ && CUDA_VERSION > 12010 +#include "source_location.h" +#endif +#else // no __CUDACC__ +#include "source_location.h" +#endif + +#include + +#include +#include +#include +#include +#include +#include + +namespace host { + +template +inline constexpr bool dependent_false_v = false; + +/// \brief Source-location wrapper for debug/error messages. +struct DebugInfo : public source_location_t { + DebugInfo(source_location_t loc = source_location_t::current()) : source_location_t(loc) {} +}; + +/// \brief Exception type thrown by `RuntimeCheck` and `Panic`. +struct PanicError : public std::runtime_error { + public: + explicit PanicError(std::string msg) : runtime_error(msg), m_message(std::move(msg)) {} + auto root_cause() const -> std::string_view { + const auto str = std::string_view{m_message}; + const auto pos = str.find(": "); + return pos == std::string_view::npos ? 
str : str.substr(pos + 2); + } + + private: + std::string m_message; +}; + +/// \brief Unconditionally abort with a formatted error message. +template +[[noreturn]] +inline auto panic(DebugInfo location, Args&&... args) -> void { + std::ostringstream os; + os << "Runtime check failed at " << location.file_name() << ":" << location.line(); + if constexpr (sizeof...(args) > 0) { + os << ": "; + (os << ... << std::forward(args)); + } else { + os << " in " << location.function_name(); + } + throw PanicError(std::move(os).str()); +} + +/** + * \brief Runtime assertion: panics with a formatted message when `condition` + * is false. Extra `args` are streamed to the error message. + * + * Example: + * \code + * RuntimeCheck(n > 0, "n must be positive, got ", n); + * \endcode + */ +template +struct RuntimeCheck { + template + explicit RuntimeCheck(Cond&& condition, Args&&... args, DebugInfo location = {}) { + if (condition) return; + [[unlikely]] ::host::panic(location, std::forward(args)...); + } + template + explicit RuntimeCheck(DebugInfo location, Cond&& condition, Args&&... args) { + if (condition) return; + [[unlikely]] ::host::panic(location, std::forward(args)...); + } +}; + +template +struct Panic { + explicit Panic(Args&&... args, DebugInfo location = {}) { + ::host::panic(location, std::forward(args)...); + } + explicit Panic(DebugInfo location, Args&&... args) { + ::host::panic(location, std::forward(args)...); + } + [[noreturn]] ~Panic() { + std::terminate(); + } +}; + +template +explicit RuntimeCheck(Cond&&, Args&&...) -> RuntimeCheck; + +template +explicit RuntimeCheck(DebugInfo, Cond&&, Args&&...) -> RuntimeCheck; + +template +explicit Panic(Args&&...) -> Panic; + +template +explicit Panic(DebugInfo, Args&&...) -> Panic; + +namespace pointer { + +// we only allow void * pointer arithmetic for safety + +template +inline auto offset(void* ptr, U... offset) -> void* { + return static_cast(ptr) + (... 
+ offset); +} + +template +inline auto offset(const void* ptr, U... offset) -> const void* { + return static_cast(ptr) + (... + offset); +} + +} // namespace pointer + +/// \brief Integer ceiling division: ceil(a / b). +template +inline constexpr auto div_ceil(T a, U b) { + return (a + b - 1) / b; +} + +/// \brief Returns the byte width of a DLPack data type. +inline auto dtype_bytes(DLDataType dtype) -> std::size_t { + return static_cast(dtype.bits / 8); +} + +namespace stdr = std::ranges; +namespace stdv = stdr::views; + +/// \brief Python-style integer range: `irange(n)` -> `[0, n)`. +template +inline auto irange(T end) { + return stdv::iota(static_cast(0), end); +} + +/// \brief Python-style integer range: `irange(start, end)` -> `[start, end)`. +template +inline auto irange(T start, T end) { + return stdv::iota(start, end); +} + +} // namespace host diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/vec.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/vec.cuh new file mode 100644 index 0000000000..67f388679f --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/vec.cuh @@ -0,0 +1,118 @@ +/// \file vec.cuh +/// \brief Aligned vector types for coalesced global memory access. +/// +/// `AlignedVector` wraps `N` elements of type `T` in a naturally +/// aligned struct so that the compiler emits wide (vectorized) load/store +/// instructions (e.g. `LDG.128`). The maximum supported vector width is +/// 256 bits (32 bytes), matching CUDA's widest vector load. + +#pragma once +#include + +#include +#include + +namespace device { + +namespace details { + +/// \brief Maps byte-width to the corresponding unsigned integer type. 
+template +struct uint_trait {}; + +template <> +struct uint_trait<1> { + using type = uint8_t; +}; + +template <> +struct uint_trait<2> { + using type = uint16_t; +}; + +template <> +struct uint_trait<4> { + using type = uint32_t; +}; + +template <> +struct uint_trait<8> { + using type = uint64_t; +}; + +/// \brief Alias: maps `sizeof(T)` to matching unsigned int type. +template +using sized_int = typename uint_trait::type; + +} // namespace details + +/// \brief Raw aligned storage for `N` elements of type `T`. +template +struct alignas(sizeof(T) * N) AlignedStorage { + T data[N]; +}; + +/** + * \brief Aligned vector for vectorized memory access on GPU. + * + * Stores `N` elements of type `T` with natural alignment so that a single + * `load`/`store` call compiles to a wide memory transaction. + * + * \tparam T Element type (e.g. `fp16_t`, `bf16_t`, `float`). + * \tparam N Number of elements. Must be a power of two and + * `sizeof(T) * N <= 32` (256 bits). + * + * Example: + * \code + * AlignedVector vec; // 16 bytes, 128-bit aligned + * vec.load(input_ptr, tid); // vectorized load + * vec[0] = vec[0] + 1; + * vec.store(output_ptr, tid); // vectorized store + * \endcode + */ +template +struct AlignedVector { + private: + static_assert( + (N > 0 && (N & (N - 1)) == 0) && sizeof(T) * N <= kMaxVecBytes, + "CUDA vector size exceeds arch limit: max 16 bytes on pre-Blackwell/AMD, " + "32 bytes on Blackwell or greater"); + using element_t = typename details::sized_int; + using storage_t = AlignedStorage; + + public: + /// \brief Vectorized load from `ptr` at the given element `offset`. + SGL_DEVICE void load(const void* ptr, int64_t offset = 0) { + m_storage = reinterpret_cast(ptr)[offset]; + } + /// \brief Vectorized store to `ptr` at the given element `offset`. + SGL_DEVICE void store(void* ptr, int64_t offset = 0) const { + reinterpret_cast(ptr)[offset] = m_storage; + } + /// \brief Fill all N elements with the same `value`. 
+ SGL_DEVICE void fill(T value) { + const auto store_value = *reinterpret_cast(&value); +#pragma unroll + for (std::size_t i = 0; i < N; ++i) { + m_storage.data[i] = store_value; + } + } + + SGL_DEVICE auto operator[](std::size_t idx) -> T& { + return reinterpret_cast(&m_storage)[idx]; + } + SGL_DEVICE auto operator[](std::size_t idx) const -> T { + return reinterpret_cast(&m_storage)[idx]; + } + SGL_DEVICE auto data() -> T* { + return reinterpret_cast(&m_storage); + } + SGL_DEVICE auto data() const -> const T* { + return reinterpret_cast(&m_storage); + } + + private: + storage_t m_storage; +}; + +} // namespace device diff --git a/lightllm/third_party/sglang_jit/include/sgl_kernel/warp.cuh b/lightllm/third_party/sglang_jit/include/sgl_kernel/warp.cuh new file mode 100644 index 0000000000..9d82efae1e --- /dev/null +++ b/lightllm/third_party/sglang_jit/include/sgl_kernel/warp.cuh @@ -0,0 +1,56 @@ +/// \file warp.cuh +/// \brief Warp-level reduction primitives. + +#pragma once +#include +#include + +namespace device::warp { + +/// \brief Full warp active mask. +#ifndef USE_ROCM +static constexpr uint32_t kFullMask = 0xffffffffu; +using mask_t = uint32_t; +#else +static constexpr uint64_t kFullMask = 0xffffffffffffffffULL; +using mask_t = uint64_t; +#endif + +/** + * \brief Warp-level sum reduction. + * + * On CUDA: uses __shfl_xor_sync with width=32. + * On HIP: uses __shfl_xor with explicit width parameter (supports wave64 sub-groups). + */ +template +SGL_DEVICE T reduce_sum(T value, mask_t active_mask = kFullMask) { + static_assert(kNumThreads >= 1 && kNumThreads <= kWarpThreads); + static_assert(std::has_single_bit(kNumThreads), "must be pow of 2"); +#pragma unroll + for (int mask = kNumThreads / 2; mask > 0; mask >>= 1) +#ifndef USE_ROCM + value = value + __shfl_xor_sync(active_mask, value, mask, 32); +#else + value = value + __shfl_xor(value, mask, kNumThreads); +#endif + return value; +} + +/** + * \brief Warp-level max reduction. 
+ */ +template +SGL_DEVICE T reduce_max(T value, mask_t active_mask = kFullMask) { + static_assert(kNumThreads >= 1 && kNumThreads <= kWarpThreads); + static_assert(std::has_single_bit(kNumThreads), "must be pow of 2"); +#pragma unroll + for (int mask = kNumThreads / 2; mask > 0; mask >>= 1) +#ifndef USE_ROCM + value = math::max(value, __shfl_xor_sync(active_mask, value, mask, 32)); +#else + value = math::max(value, __shfl_xor(value, mask, kNumThreads)); +#endif + return value; +} + +} // namespace device::warp diff --git a/lightllm/third_party/sglang_jit/jit_utils.py b/lightllm/third_party/sglang_jit/jit_utils.py new file mode 100644 index 0000000000..4096c16bb4 --- /dev/null +++ b/lightllm/third_party/sglang_jit/jit_utils.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +import functools +import importlib.util +import logging +import os +import pathlib +from contextlib import contextmanager +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + TypeAlias, + TypeVar, + Union, +) + +import torch + +if TYPE_CHECKING: + from tvm_ffi import Module + +F = TypeVar("F", bound=Callable[..., Any]) +_FULL_TEST_ENV_VAR = "SGLANG_JIT_KERNEL_RUN_FULL_TESTS" + +logger = logging.getLogger(__name__) + + +def is_in_ci() -> bool: + return os.getenv("SGLANG_IS_IN_CI", "").lower() in ("1", "true", "yes", "y") + + +def should_run_full_tests() -> bool: + return os.getenv(_FULL_TEST_ENV_VAR, "false").lower() == "true" + + +def get_ci_test_range(full_range: List[Any], ci_range: List[Any]) -> List[Any]: + if should_run_full_tests(): + return full_range + return ci_range if is_in_ci() else full_range + + +def cache_once(fn: F) -> F: + """ + NOTE: `functools.lru_cache` is not compatible with `torch.compile` + So we manually implement a simple cache_once decorator to replace it. 
+ """ + result_map = {} + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + key = (args, tuple(sorted(kwargs.items()))) + if key not in result_map: + result_map[key] = fn(*args, **kwargs) + return result_map[key] + + return wrapper # type: ignore + + +def _make_wrapper(tup: Tuple[str, str]) -> str: + export_name, kernel_name = tup + return f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({export_name}, ({kernel_name}));" + + +@cache_once +def _resolve_kernel_path() -> pathlib.Path: + cur_dir = pathlib.Path(__file__).parent.resolve() + + # first, try this directory structure + def _environment_install(): + candidate = cur_dir.resolve() + if (candidate / "include").exists() and (candidate / "csrc").exists(): + return candidate + return None + + def _package_install(): + # TODO: support find path by package + return None + + path = _environment_install() or _package_install() + if path is None: + raise RuntimeError("Cannot find sglang.jit_kernel path") + return path + + +KERNEL_PATH = _resolve_kernel_path() +DEFAULT_INCLUDE = [str(KERNEL_PATH / "include")] +DEFAULT_CFLAGS = ["-std=c++20", "-O3"] +DEFAULT_LDFLAGS = [] +CPP_TEMPLATE_TYPE: TypeAlias = Union[int, float, str, bool, torch.dtype] + + +class CPPArgList(list[str]): + def __str__(self) -> str: + return ", ".join(self) + + +CPP_DTYPE_MAP = { + torch.float: "fp32_t", + torch.float16: "fp16_t", + torch.float8_e4m3fn: "fp8_e4m3_t", + torch.bfloat16: "bf16_t", + torch.int8: "int8_t", + torch.int32: "int32_t", + torch.int64: "int64_t", +} + + +# AMD/ROCm note: +@cache_once +def is_hip_runtime() -> bool: + return bool(torch.version.hip) + + +# MThreads/MUSA note: +@cache_once +def is_musa_runtime() -> bool: + return hasattr(torch.version, "musa") and torch.version.musa is not None + + +def make_cpp_args(*args: CPP_TEMPLATE_TYPE) -> CPPArgList: + def _convert(arg: CPP_TEMPLATE_TYPE) -> str: + if isinstance(arg, bool): + return "true" if arg else "false" + if isinstance(arg, (int, str, float)): + return str(arg) + if 
isinstance(arg, torch.dtype): + return CPP_DTYPE_MAP[arg] + raise TypeError(f"Unsupported argument type for cpp template: {type(arg)}") + + return CPPArgList(_convert(arg) for arg in args) + + +def load_jit( + *args: str, + cpp_files: List[str] | None = None, + cuda_files: List[str] | None = None, + cpp_wrappers: List[Tuple[str, str]] | None = None, + cuda_wrappers: List[Tuple[str, str]] | None = None, + extra_cflags: List[str] | None = None, + extra_cuda_cflags: List[str] | None = None, + extra_ldflags: List[str] | None = None, + extra_include_paths: List[str] | None = None, + extra_dependencies: List[str] | None = None, + build_directory: str | None = None, + header_only: bool = True, +) -> Module: + """ + Loading a JIT module from C++/CUDA source files. + We define a wrapper as a tuple of (export_name, kernel_name), + where `export_name` is the name used to called from Python, + and `kernel_name` is the name of the kernel class in C++/CUDA source. + + :param args: Unique marker of the JIT module. Must be distinct for different kernels. + :type args: str + :param cpp_files: A list of C++ source files. + :type cpp_files: List[str] | None + :param cuda_files: A list of CUDA source files. + :type cuda_files: List[str] | None + :param cpp_wrappers: A list of C++ wrappers, defining the export name and kernel name. + :type cpp_wrappers: List[Tuple[str, str]] | None + :param cuda_wrappers: A list of CUDA wrappers, defining the export name and kernel name. + :type cuda_wrappers: List[Tuple[str, str]] | None + :param extra_cflags: Extra C++ compiler flags. + :type extra_cflags: List[str] | None + :param extra_cuda_cflags: Extra CUDA compiler flags. + :type extra_cuda_cflags: List[str] | None + :param extra_ldflags: Extra linker flags. + :type extra_ldflags: List[str] | None + :param extra_include_paths: Extra include paths. + :type extra_include_paths: List[str] | None + :param extra_dependencies: Extra dependencies for the JIT module, e.g., cutlass. 
+ :type extra_dependencies: List[str] | None + :param build_directory: The build directory for JIT compilation. + :type build_directory: str | None + :param header_only: Whether the module is header-only. + If true, apply the wrappers to export given class/functions. + Otherwise, we must export from C++/CUDA side. + :return: A just-in-time(JIT) compiled module. + :rtype: Module + """ + + from tvm_ffi.cpp import load, load_inline + + cpp_files = cpp_files or [] + cuda_files = cuda_files or [] + extra_cflags = extra_cflags or [] + extra_cuda_cflags = extra_cuda_cflags or [] + extra_ldflags = extra_ldflags or [] + extra_include_paths = extra_include_paths or [] + + cpp_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cpp_files] + cuda_files = [str((KERNEL_PATH / "csrc" / f).resolve()) for f in cuda_files] + + for dep in set(extra_dependencies or []): + if dep not in _REGISTERED_DEPENDENCIES: + raise ValueError(f"Dependency {dep} is not registered.") + extra_include_paths += _REGISTERED_DEPENDENCIES[dep]() + + module_name = "sgl_kernel_jit_" + "_".join(str(arg) for arg in args) + if header_only: + cpp_wrappers = cpp_wrappers or [] + cuda_wrappers = cuda_wrappers or [] + cpp_sources = [f'#include "{path}"' for path in cpp_files] + cpp_sources += [_make_wrapper(tup) for tup in cpp_wrappers] + + # include cuda files + cuda_sources = [f'#include "{path}"' for path in cuda_files] + cuda_sources += [_make_wrapper(tup) for tup in cuda_wrappers] + with _jit_compile_context(): + return load_inline( + module_name, + cpp_sources=cpp_sources, + cuda_sources=cuda_sources, + extra_cflags=DEFAULT_CFLAGS + extra_cflags, + extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags, + extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags, + extra_include_paths=DEFAULT_INCLUDE + extra_include_paths, + build_directory=build_directory, + ) + else: + assert cpp_wrappers is None and cuda_wrappers is None + with _jit_compile_context(): + return load( + module_name, + 
cpp_files=cpp_files, + cuda_files=cuda_files, + extra_cflags=DEFAULT_CFLAGS + extra_cflags, + extra_cuda_cflags=_get_default_target_flags() + extra_cuda_cflags, + extra_ldflags=DEFAULT_LDFLAGS + extra_ldflags, + extra_include_paths=DEFAULT_INCLUDE + extra_include_paths, + build_directory=build_directory, + ) + + +@dataclass +class ArchInfo: + major: int + minor: int + suffix: str + + @property + def target_name(self) -> str: + return f"{self.major}.{self.minor}{self.suffix}" + + @property + def jit_flag(self) -> str: + return f"-DSGL_CUDA_ARCH={self.major * 100 + self.minor * 10}" + + +@cache_once +def _init_jit_cuda_arch_once(): + global _CUDA_ARCH + try: + device = torch.cuda.current_device() + major, minor = torch.cuda.get_device_capability(device) + except Exception: + logger.warning("Cannot detect CUDA architecture.") + major, minor = 0, 0 # invalid value to trigger compile error if used + _CUDA_ARCH = ArchInfo(major, minor, "") + + +@contextmanager +def _jit_compile_context(): + if is_hip_runtime(): + yield # TODO: support ROCm `TVM_FFI_ROCM_ARCH_LIST` if needed + return + env_key = "TVM_FFI_CUDA_ARCH_LIST" + old_value = os.environ.get(env_key, None) + os.environ[env_key] = get_jit_cuda_arch().target_name + try: + yield + finally: + if old_value is None: + os.environ.pop(env_key, None) + else: + os.environ[env_key] = old_value + + +# NOTE: this might also be used in __main__.py for compile flags export +def _get_default_target_flags() -> List[str]: + if is_hip_runtime(): + flags = ["-DUSE_ROCM", "-std=c++20", "-O3"] + # Detect FP8 type based on GPU architecture + try: + device = torch.cuda.current_device() + gcn_arch = torch.cuda.get_device_properties(device).gcnArchName + if "gfx942" in gcn_arch: + flags.append("-DHIP_FP8_TYPE_FNUZ=1") + else: + flags.append("-DHIP_FP8_TYPE_E4M3=1") + except Exception: + flags.append("-DHIP_FP8_TYPE_E4M3=1") + return flags + else: + return [ + get_jit_cuda_arch().jit_flag, + "-std=c++20", + "-O3", + 
"--expt-relaxed-constexpr", + ] + + +@contextmanager +def override_jit_cuda_arch(major: int, minor: int, suffix: str = ""): + """A context manager to temporarily override CUDA architecture.""" + global _CUDA_ARCH + old_value = get_jit_cuda_arch() + _CUDA_ARCH = ArchInfo(major, minor, suffix) + try: + yield + finally: + _CUDA_ARCH = old_value + + +def get_jit_cuda_arch() -> ArchInfo: + """Get the current CUDA architecture info.""" + _init_jit_cuda_arch_once() + return _CUDA_ARCH + + +@cache_once +def is_arch_support_pdl() -> bool: + if is_hip_runtime() or is_musa_runtime(): + return False + return get_jit_cuda_arch().major >= 9 + + +def _find_package_root(package: str) -> Optional[pathlib.Path]: + spec = importlib.util.find_spec(package) + if spec is None or spec.origin is None: + return None + return pathlib.Path(spec.origin).resolve().parent + + +# NOTE: this might also be used in __main__.py for compile flags export +_REGISTERED_DEPENDENCIES: Dict[str, Callable[[], List[str]]] = {} + + +def register_dependency(name: str): + def decorator(f: Callable[[], List[str]]) -> Callable[[], List[str]]: + if name in _REGISTERED_DEPENDENCIES: + raise ValueError(f"Dependency {name} already registered") + _REGISTERED_DEPENDENCIES[name] = f + return f + + return decorator + + +@register_dependency("flashinfer") +def get_flashinfer_include_paths() -> List[str]: + include_paths: List[str] = [] + flashinfer_root = _find_package_root("flashinfer") + if flashinfer_root is None: + raise RuntimeError( + "Cannot find flashinfer package. Please install flashinfer to get" + "the required headers for JIT compilation." 
+ ) + + flashinfer_data = flashinfer_root / "data" + candidates = [ + flashinfer_data / "include", + flashinfer_data / "csrc", + flashinfer_data / "cutlass" / "include", + flashinfer_data / "cutlass" / "tools" / "util" / "include", + flashinfer_data / "spdlog" / "include", + ] + + for path in candidates: + if not path.exists(): + raise RuntimeError( + f"Required header path {path} for flashinfer dependency not found." + " Please check your flashinfer installation." + ) + include_paths.append(str(path)) + return include_paths + + +@register_dependency("cutlass") +def get_cutlass_include_paths() -> List[str]: + include_paths: List[str] = [] + + flashinfer_root = _find_package_root("flashinfer") + if flashinfer_root is not None: + candidates = [ + flashinfer_root / "data" / "cutlass" / "include", + flashinfer_root / "data" / "cutlass" / "tools" / "util" / "include", + ] + for path in candidates: + if path.exists(): + include_paths.append(str(path)) + + deep_gemm_root = _find_package_root("deep_gemm") + if deep_gemm_root is not None: + candidate = deep_gemm_root / "include" + if candidate.exists(): + include_paths.append(str(candidate)) + + # De-duplicate while preserving order. + unique_paths = [] + seen = set() + for path in include_paths: + if path in seen: + continue + seen.add(path) + unique_paths.append(path) + + if not unique_paths: + raise RuntimeError( + "Cannot find CUTLASS headers required for JIT compilation. " + "Please install flashinfer or deep_gemm with CUTLASS headers." 
+ ) + return unique_paths + + +__all__ = [ + "should_run_full_tests", + "get_ci_test_range", + "cache_once", + "is_hip_runtime", + "make_cpp_args", + "load_jit", + "override_jit_cuda_arch", + "get_jit_cuda_arch", + "is_arch_support_pdl", + "register_dependency", +] diff --git a/lightllm/third_party/sglang_jit/runtime_utils.py b/lightllm/third_party/sglang_jit/runtime_utils.py new file mode 100644 index 0000000000..d322498ca4 --- /dev/null +++ b/lightllm/third_party/sglang_jit/runtime_utils.py @@ -0,0 +1,5 @@ +import torch + + +def is_hip() -> bool: + return torch.version.hip is not None diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index c8d7373d54..c892d4b0bc 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -464,6 +464,10 @@ def get_tool_call_parser_for_model(model_path: str) -> Optional[str]: if model_type == "deepseek_v32": return "deepseekv32" + # DeepSeek V4 + if model_type == "deepseek_v4": + return "deepseekv4" + return None @@ -488,8 +492,8 @@ def get_reasoning_parser_for_model(model_path: str) -> Optional[str]: ]: return "qwen3" - # DeepSeek V3 - if model_type in ["deepseek_v3", "deepseek_v31", "deepseek_v32"]: + # DeepSeek V3 / V4 (share the ... reasoning format, request-gated) + if model_type in ["deepseek_v3", "deepseek_v31", "deepseek_v32", "deepseek_v4"]: return "deepseek-v3" # DeepSeek R1