Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 172 additions & 2 deletions lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,60 @@
from lightllm.common.basemodel.infer_struct import InferStateInfo


# This flash_mla extra-cache fork only instantiates kernels for h_q in {64, 128}. TP-split q heads
# are zero-padded up to the nearest supported count; the padded heads are sliced off the output.
FLASHMLA_SUPPORTED_HEADS = (64, 128)


def _pad_q_heads(q_4d: torch.Tensor, attn_sink: torch.Tensor):
    """Zero-pad the head axis of ``q_4d`` and ``attn_sink`` up to a supported head count.

    The head count is read from dim 2 of ``q_4d``. When it is already listed in
    ``FLASHMLA_SUPPORTED_HEADS`` both inputs are returned untouched; otherwise ``q_4d`` is padded
    at the end of dim 2 and ``attn_sink`` at the end of its last dim.

    Returns:
        ``(q, sink, real_heads)`` where ``real_heads`` is the head count before padding, so the
        caller can slice the padded heads off the kernel output.

    Raises:
        AssertionError: the head count is larger than every supported entry.
    """
    real_heads = q_4d.shape[2]
    if real_heads not in FLASHMLA_SUPPORTED_HEADS:
        # First entry that can hold all real heads (the tuple is listed in ascending order).
        padded_heads = None
        for candidate in FLASHMLA_SUPPORTED_HEADS:
            if candidate >= real_heads:
                padded_heads = candidate
                break
        assert (
            padded_heads is not None
        ), f"num q heads {real_heads} exceeds flash_mla support {FLASHMLA_SUPPORTED_HEADS}"
        missing = padded_heads - real_heads
        zero_pad = torch.nn.functional.pad
        # F.pad takes (left, right) pairs starting from the LAST dim: leave dim -1 alone and
        # append `missing` zeros at the end of dim -2 (heads) for q; append to dim -1 for the sink.
        return zero_pad(q_4d, (0, 0, 0, missing)), zero_pad(attn_sink, (0, missing)), real_heads
    return q_4d, attn_sink, real_heads


def _view_dsv4_flashmla_cache(layer_buffer: torch.Tensor, page_size: int) -> torch.Tensor:
    """Reinterpret a cache buffer as ``(rows, page_size, 1, DSV4_MLA_BYTES_PER_TOKEN)``.

    Only the first ``page_size * DSV4_MLA_BYTES_PER_TOKEN`` columns of every row take part in
    the view; any columns past that point are left out.
    """
    from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import (
        DSV4_MLA_BYTES_PER_TOKEN as bytes_per_token,
    )

    num_rows = layer_buffer.shape[0]
    token_region = layer_buffer[:, : page_size * bytes_per_token]
    return token_region.view(num_rows, page_size, 1, bytes_per_token)


@dataclasses.dataclass
class _Dsv4Metadata:
    """Index, length and extra-cache tensors handed to ``flash_mla_with_kvcache`` in this file."""

    # Passed to the kernel as ``indices`` and ``topk_length`` respectively.
    swa_indices: torch.Tensor
    swa_lengths: torch.Tensor
    # Passed to the kernel as ``extra_k_cache``, ``extra_indices_in_kvcache`` and
    # ``extra_topk_length``. All three default to None.
    extra_cache: torch.Tensor = None
    extra_indices: torch.Tensor = None
    extra_lengths: torch.Tensor = None


def _metadata_from_dict(infer_state, nsa_dict: dict) -> "_Dsv4Metadata":
    """Assemble a ``_Dsv4Metadata`` from the model-supplied ``nsa_dict``.

    The FINAL index/length tensors arrive ready-made in ``nsa_dict`` (built on the model side by
    DeepseekV4IndexInfer) and are passed through as-is. The fp8 extra-cache byte view is the only
    piece created here: it is a fixed, data-independent slice of the per-layer buffer and a
    flash_mla ABI detail, so it stays on the backend side and only the index/length tensors cross
    the att_control boundary.

    A falsy ``nsa_dict["compress_ratio"]`` means no extra cache (``extra_cache`` is None). A ratio
    of 4 selects the C4 page size; every other truthy ratio selects the C128 page size.
    """
    from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_C128_PAGE_SIZE, DSV4_C4_PAGE_SIZE

    compress_ratio = nsa_dict["compress_ratio"]
    if compress_ratio:
        page_size = DSV4_C128_PAGE_SIZE
        if compress_ratio == 4:
            page_size = DSV4_C4_PAGE_SIZE
        compressed_buffer = infer_state.mem_manager.get_compressed_kv_buffer(nsa_dict["layer_index"])
        cache_view = _view_dsv4_flashmla_cache(compressed_buffer, page_size)
    else:
        cache_view = None

    # swa_* keys are mandatory; the extra_* index tensors are optional and default to None.
    return _Dsv4Metadata(
        nsa_dict["swa_indices"],
        nsa_dict["swa_lengths"],
        cache_view,
        nsa_dict.get("extra_indices"),
        nsa_dict.get("extra_lengths"),
    )


class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend):
def __init__(self, model):
super().__init__(model=model)
Expand Down Expand Up @@ -62,6 +116,12 @@ def prefill_att(
) -> torch.Tensor:
assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention"
assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required"
if att_control.nsa_prefill_dict.get("flashmla_kvcache"):
return self._flashmla_kvcache_prefill_att(
q=q,
packed_kv=k,
nsa_dict=att_control.nsa_prefill_dict,
)
return self._nsa_prefill_att(q=q, packed_kv=k, att_control=att_control)

def _nsa_prefill_att(
Expand All @@ -78,6 +138,8 @@ def _nsa_prefill_att(
kv_lora_rank = nsa_dict["kv_lora_rank"]
topk_mem_indices = nsa_dict["topk_mem_indices"]
prefill_cache_kv = nsa_dict["prefill_cache_kv"]
attn_sink = nsa_dict.get("attn_sink")
topk_length = nsa_dict.get("topk_length")

if self.infer_state.prefix_total_token_num > 0:
# 当前推理生成的token kv部分从 prefill_cache_kv 中获取,历史
Expand All @@ -101,9 +163,51 @@ def _nsa_prefill_att(
indices=topk_indices,
sm_scale=softmax_scale,
d_v=kv_lora_rank,
attn_sink=attn_sink,
topk_length=topk_length,
)
return mla_out

def _flashmla_kvcache_prefill_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor:
    """Prefill entry for the flash_mla kv-cache path.

    Casts ``nsa_dict["attn_sink"]`` to contiguous float32, builds the metadata bundle from
    ``nsa_dict`` and hands everything to ``_flashmla_kvcache_att``.
    """
    sink_fp32 = nsa_dict["attn_sink"].to(torch.float32).contiguous()
    return self._flashmla_kvcache_att(
        q, packed_kv, _metadata_from_dict(self.infer_state, nsa_dict), sink_fp32, nsa_dict
    )

def _flashmla_kvcache_att(
    self,
    q: torch.Tensor,
    packed_kv: torch.Tensor,
    metadata: _Dsv4Metadata,
    attn_sink: torch.Tensor,
    nsa_dict: dict,
) -> torch.Tensor:
    """Run ``flash_mla.flash_mla_with_kvcache`` for prefill, with a fresh scheduler-metadata object.

    ``q`` gets a singleton axis inserted at dim 1 and its heads padded to a supported count;
    ``packed_kv`` is viewed with the SWA page size. The kernel output has that singleton axis and
    the padded heads removed before it is returned.
    """
    import flash_mla
    from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE

    padded_q, padded_sink, real_heads = _pad_q_heads(q.unsqueeze(1).contiguous(), attn_sink)
    swa_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE)
    # A new scheduler-metadata object is requested for every prefill call.
    fresh_sched_meta, _ = flash_mla.get_mla_metadata()

    kernel_kwargs = {
        "q": padded_q,
        "k_cache": swa_cache,
        "block_table": None,
        "cache_seqlens": None,
        "head_dim_v": nsa_dict["head_dim_v"],
        "tile_scheduler_metadata": fresh_sched_meta,
        "num_splits": None,
        "softmax_scale": nsa_dict["softmax_scale"],
        "causal": False,
        "is_fp8_kvcache": True,
        "indices": metadata.swa_indices,
        "attn_sink": padded_sink,
        "topk_length": metadata.swa_lengths,
        "extra_k_cache": metadata.extra_cache,
        "extra_indices_in_kvcache": metadata.extra_indices,
        "extra_topk_length": metadata.extra_lengths,
    }
    attn_out, _ = flash_mla.flash_mla_with_kvcache(**kernel_kwargs)
    # Drop the axis inserted at dim 1 and the zero-padded heads.
    return attn_out[:, 0, :real_heads].contiguous()


@dataclasses.dataclass
class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState):
Expand Down Expand Up @@ -143,7 +247,18 @@ def init_state(self):
)
import flash_mla

self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata()
# one sched_meta per layer type: the lazy config locks extra-cache geometry (page size,
# presence) on first invocation, so swa-only/c4/c128 layers must not share one object.
self.flashmla_sched_meta = {ratio: flash_mla.get_mla_metadata()[0] for ratio in (0, 4, 128)}
return

def reset_sched_meta_for_capture(self):
    """Replace ``self.flashmla_sched_meta`` with fresh objects, keyed by 0, 4 and 128.

    cuda-graph capture hook: the warmup pass has already locked/stored scheduler metadata on
    this (shared) state object. Resetting makes the capture pass re-plan INSIDE the graph, so
    every replay re-plans from the live tensors instead of binding warmup leftovers.
    """
    import flash_mla

    fresh_meta = {}
    for compress_ratio in (0, 4, 128):
        fresh_meta[compress_ratio] = flash_mla.get_mla_metadata()[0]
    self.flashmla_sched_meta = fresh_meta

def decode_att(
Expand All @@ -156,6 +271,12 @@ def decode_att(
) -> torch.Tensor:
assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention"
assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required"
if att_control.nsa_decode_dict.get("flashmla_kvcache"):
return self._flashmla_kvcache_decode_att(
q=q,
packed_kv=k,
nsa_dict=att_control.nsa_decode_dict,
)
return self._nsa_decode_att(q=q, packed_kv=k, att_control=att_control)

def _nsa_decode_att(
Expand All @@ -170,6 +291,11 @@ def _nsa_decode_att(
topk_mem_indices = nsa_dict["topk_mem_indices"]
softmax_scale = nsa_dict["softmax_scale"]
kv_lora_rank = nsa_dict["kv_lora_rank"]
attn_sink = nsa_dict.get("attn_sink")
topk_length = nsa_dict.get("topk_length")
extra_k_cache = nsa_dict.get("extra_k_cache")
extra_indices = nsa_dict.get("extra_indices_in_kvcache")
extra_topk_length = nsa_dict.get("extra_topk_length")

if topk_mem_indices.ndim == 2:
topk_mem_indices = topk_mem_indices.unsqueeze(1)
Expand All @@ -189,10 +315,54 @@ def _nsa_decode_att(
block_table=None,
cache_seqlens=None,
head_dim_v=kv_lora_rank,
tile_scheduler_metadata=self.flashmla_sched_meta,
tile_scheduler_metadata=self.flashmla_sched_meta[0],
softmax_scale=softmax_scale,
causal=False,
is_fp8_kvcache=True,
indices=topk_mem_indices,
attn_sink=attn_sink,
topk_length=topk_length,
extra_k_cache=extra_k_cache,
extra_indices_in_kvcache=extra_indices,
extra_topk_length=extra_topk_length,
)
return o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d]

def _flashmla_kvcache_decode_att(self, q: torch.Tensor, packed_kv: torch.Tensor, nsa_dict: dict) -> torch.Tensor:
    """Decode entry for the flash_mla kv-cache path.

    Casts ``nsa_dict["attn_sink"]`` to contiguous float32, builds the metadata bundle from
    ``nsa_dict`` and hands everything to ``_flashmla_kvcache_att``.
    """
    sink_fp32 = nsa_dict["attn_sink"].to(torch.float32).contiguous()
    return self._flashmla_kvcache_att(
        q, packed_kv, _metadata_from_dict(self.infer_state, nsa_dict), sink_fp32, nsa_dict
    )

def _flashmla_kvcache_att(
    self,
    q: torch.Tensor,
    packed_kv: torch.Tensor,
    metadata: _Dsv4Metadata,
    attn_sink: torch.Tensor,
    nsa_dict: dict,
) -> torch.Tensor:
    """Run ``flash_mla.flash_mla_with_kvcache`` for decode, reusing this state's scheduler metadata.

    ``q`` gets a singleton axis inserted at dim 1 and its heads padded to a supported count;
    ``packed_kv`` is viewed with the SWA page size. The scheduler-metadata object is looked up in
    ``self.flashmla_sched_meta`` by ``nsa_dict["compress_ratio"]``. The kernel output has the
    singleton axis and the padded heads removed before it is returned.
    """
    import flash_mla
    from lightllm.common.kv_cache_mem_manager.deepseek4_mem_manager import DSV4_SWA_PAGE_SIZE

    padded_q, padded_sink, real_heads = _pad_q_heads(q.unsqueeze(1).contiguous(), attn_sink)
    swa_cache = _view_dsv4_flashmla_cache(packed_kv, DSV4_SWA_PAGE_SIZE)

    kernel_kwargs = {
        "q": padded_q,
        "k_cache": swa_cache,
        "block_table": None,
        "cache_seqlens": None,
        "head_dim_v": nsa_dict["head_dim_v"],
        # One stored object per compress ratio; pick the one matching this layer.
        "tile_scheduler_metadata": self.flashmla_sched_meta[nsa_dict["compress_ratio"]],
        "num_splits": None,
        "softmax_scale": nsa_dict["softmax_scale"],
        "causal": False,
        "is_fp8_kvcache": True,
        "indices": metadata.swa_indices,
        "attn_sink": padded_sink,
        "topk_length": metadata.swa_lengths,
        "extra_k_cache": metadata.extra_cache,
        "extra_indices_in_kvcache": metadata.extra_indices,
        "extra_topk_length": metadata.extra_lengths,
    }
    attn_out, _ = flash_mla.flash_mla_with_kvcache(**kernel_kwargs)
    # Drop the axis inserted at dim 1 and the zero-padded heads.
    return attn_out[:, 0, :real_heads].contiguous()
45 changes: 45 additions & 0 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@ def _init_custom(self):

@torch.no_grad()
def forward(self, model_input: ModelInput):
# decode 槽位 prep: 放在 to_cuda 前 (b_req_idx/b_seq_len/mem_indexes_cpu 还是原生 CPU 张量),
# 且此刻已在 forward 的 CUDA stream 上 -> 与后续 attention 同流, 无跨流竞态、无 D2H。
# mem_indexes_cpu is None 时跳过: cudagraph warmup 的输入全在 CUDA 且 b_req_idx 全为 HOLD, prep 本就是 no-op。
if not model_input.is_prefill and model_input.mem_indexes_cpu is not None:
self.req_manager.prepare_decode(model_input.b_req_idx, model_input.b_seq_len, model_input.mem_indexes_cpu)
model_input.to_cuda()
assert model_input.mem_indexes.is_cuda

Expand Down Expand Up @@ -521,6 +526,18 @@ def _prefill(
alloc_mem_index=infer_state.mem_index,
max_q_seq_len=infer_state.max_q_seq_len,
)
if hasattr(self.req_manager, "prepare_prefill_swa"):
self.req_manager.prepare_prefill_swa(
b_req_idx=infer_state.b_req_idx,
b_ready_cache_len=infer_state.b_ready_cache_len,
b_seq_len=infer_state.b_seq_len,
)
if hasattr(self.req_manager, "prepare_prefill_compress_slots"):
self.req_manager.prepare_prefill_compress_slots(
b_req_idx=infer_state.b_req_idx,
b_ready_cache_len=infer_state.b_ready_cache_len,
b_seq_len=infer_state.b_seq_len,
)
prefill_mem_indexes_ready_event = torch.cuda.Event()
prefill_mem_indexes_ready_event.record()

Expand Down Expand Up @@ -741,6 +758,18 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
alloc_mem_index=infer_state0.mem_index,
max_q_seq_len=infer_state0.max_q_seq_len,
)
if hasattr(self.req_manager, "prepare_prefill_swa"):
self.req_manager.prepare_prefill_swa(
b_req_idx=infer_state0.b_req_idx,
b_ready_cache_len=infer_state0.b_ready_cache_len,
b_seq_len=infer_state0.b_seq_len,
)
if hasattr(self.req_manager, "prepare_prefill_compress_slots"):
self.req_manager.prepare_prefill_compress_slots(
b_req_idx=infer_state0.b_req_idx,
b_ready_cache_len=infer_state0.b_ready_cache_len,
b_seq_len=infer_state0.b_seq_len,
)
infer_state0.init_some_extra_state(self)
infer_state0.init_att_state()

Expand All @@ -754,6 +783,18 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
alloc_mem_index=infer_state1.mem_index,
max_q_seq_len=infer_state1.max_q_seq_len,
)
if hasattr(self.req_manager, "prepare_prefill_swa"):
self.req_manager.prepare_prefill_swa(
b_req_idx=infer_state1.b_req_idx,
b_ready_cache_len=infer_state1.b_ready_cache_len,
b_seq_len=infer_state1.b_seq_len,
)
if hasattr(self.req_manager, "prepare_prefill_compress_slots"):
self.req_manager.prepare_prefill_compress_slots(
b_req_idx=infer_state1.b_req_idx,
b_ready_cache_len=infer_state1.b_ready_cache_len,
b_seq_len=infer_state1.b_seq_len,
)
infer_state1.init_some_extra_state(self)
infer_state1.init_att_state()

Expand Down Expand Up @@ -781,6 +822,10 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod

@torch.no_grad()
def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
# decode 槽位 prep: 在 to_cuda 前 (原生 CPU 张量)、且已在 forward 的 CUDA stream 上 (见 forward 注释)。
for mi in (model_input0, model_input1):
if mi.mem_indexes_cpu is not None:
self.req_manager.prepare_decode(mi.b_req_idx, mi.b_seq_len, mi.mem_indexes_cpu)
model_input0.to_cuda()
model_input1.to_cuda()
assert self.args.enable_tpsp_mix_mode
Expand Down
19 changes: 19 additions & 0 deletions lightllm/common/basemodel/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@
logger = init_logger(__name__)


def _reset_att_state_sched_meta(infer_state: InferStateInfo):
    # Called right before capture. The warmup pass shares decode_att_state through a copy.copy
    # shallow copy, so its lazily initialised scheduling objects (e.g. FlashMLASchedMeta, which
    # plans from the data seen on the first kernel call and writes the plan back) end up locked
    # to the warmup's dummy workload. Without a reset the capture pass would bind scheduling
    # tensors planned for that dummy workload and every replay would use the wrong tile schedule
    # (measured on DSV4: gsm8k 0.96 -> 0.74). After the reset, planning happens inside the
    # captured region and is recomputed on each replay.
    candidates = (infer_state.decode_att_state, infer_state.decode_att_state1)
    for state in candidates:
        # The hook is optional: states that are None or do not define it are left untouched.
        hook = None if state is None else getattr(state, "reset_sched_meta_for_capture", None)
        if hook is not None:
            hook()


class CudaGraph:
# CudaGraph forward pass for the decoding stage.

Expand Down Expand Up @@ -94,6 +108,8 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo):
if param_name not in pure_para_set:
delattr(infer_state, param_name)

_reset_att_state_sched_meta(infer_state)

with torch.cuda.graph(graph_obj, pool=self.mempool):
model_output = decode_func(infer_state)
self.graph[batch_size] = (graph_obj, infer_state, model_output)
Expand Down Expand Up @@ -128,6 +144,9 @@ def _capture_decode_overlap(
if para_name not in pure_para_set1:
delattr(infer_state1, para_name)

_reset_att_state_sched_meta(infer_state)
_reset_att_state_sched_meta(infer_state1)

with torch.cuda.graph(graph_obj, pool=self.mempool):
model_output, model_output1 = decode_func(infer_state, infer_state1)
self.graph[batch_size] = (
Expand Down
Loading
Loading