Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lightllm/common/kv_cache_mem_manager/operator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ def copy_mem_to_mem(self, src_mem_index: torch.Tensor, dst_mem_index: torch.Tens

# cpu cache 的相关操作接口
def load_cpu_cache_to_gpu(
self, mem_indexes: torch.Tensor, page_indexes: torch.Tensor, cpu_cache_client, req: "InferReq"
self,
move_token_num: int,
mem_indexes: torch.Tensor,
page_indexes: torch.Tensor,
cpu_cache_client,
req: "InferReq",
):
raise NotImplementedError()

Expand Down
6 changes: 5 additions & 1 deletion lightllm/common/kv_cache_mem_manager/operator/linear_att.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ def __init__(self, mem_manager):

def load_cpu_cache_to_gpu(
self,
move_token_num: int,
mem_indexes: torch.Tensor,
page_indexes: torch.Tensor,
cpu_cache_client: "CpuKvCacheClient",
req: "InferReq",
):
# mem_indexes 中包含pad的部分,所以真实需要的move_token_num 是小于 len(mem_indexes) 的。
assert move_token_num <= len(mem_indexes)
assert mem_indexes.is_cuda and page_indexes.is_cuda
args = get_env_start_args()
assert triton.cdiv(len(mem_indexes), args.cpu_cache_token_page_size) == len(page_indexes)
Expand All @@ -41,7 +44,8 @@ def load_cpu_cache_to_gpu(
mem_manager: Qwen3NextMemManager = self.mem_manager

big_page_num = len(mem_indexes) // args.cpu_cache_token_page_size
max_kv_len = (req.cur_kv_len // args.cpu_cache_token_page_size) * args.cpu_cache_token_page_size
total_kv_len = req.cur_kv_len + move_token_num
max_kv_len = (total_kv_len // args.cpu_cache_token_page_size) * args.cpu_cache_token_page_size
assert max_kv_len % args.cpu_cache_token_page_size == 0

big_page_buffer_ids_cpu = []
Expand Down
2 changes: 2 additions & 0 deletions lightllm/common/kv_cache_mem_manager/operator/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv:

def load_cpu_cache_to_gpu(
self,
move_token_num: int,
mem_indexes: torch.Tensor,
page_indexes: torch.Tensor,
cpu_cache_client: "CpuKvCacheClient",
req: "InferReq",
):
assert move_token_num <= len(mem_indexes)
assert mem_indexes.is_cuda and page_indexes.is_cuda
args = get_env_start_args()
assert len(mem_indexes) % args.cpu_cache_token_page_size == 0
Expand Down
2 changes: 2 additions & 0 deletions lightllm/common/kv_cache_mem_manager/operator/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ class QuantScaleMemOperator(BaseMemManagerOperator):

def load_cpu_cache_to_gpu(
self,
move_token_num: int,
mem_indexes: torch.Tensor,
page_indexes: torch.Tensor,
cpu_cache_client: "CpuKvCacheClient",
req: "InferReq",
):
assert move_token_num <= len(mem_indexes)
assert mem_indexes.is_cuda and page_indexes.is_cuda
args = get_env_start_args()
assert len(mem_indexes) % args.cpu_cache_token_page_size == 0
Expand Down
25 changes: 19 additions & 6 deletions lightllm/server/router/model_infer/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class InferenceContext:
cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None

overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。
cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream
cpu_kv_cache_load_stream: torch.cuda.Stream = None # 用于将 kv cache 从 cpu 加载到 gpu 的 stream
cpu_kv_cache_offload_stream: torch.cuda.Stream = None # stream used to offload kv cache from gpu to cpu
is_linear_att_mixed_model: bool = False # 标记模型是否是full att 混合 linear att 的混合模型。

def register(
Expand Down Expand Up @@ -77,10 +78,15 @@ def get_overlap_stream(self) -> torch.cuda.Stream:
self.overlap_stream = torch.cuda.Stream()
return self.overlap_stream

def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream:
    """Return the stream stored in `cpu_kv_cache_stream`, creating it on first use.

    A new `torch.cuda.Stream()` is constructed only when the attribute is
    still None; afterwards the same object is returned on every call.
    """
    stream = self.cpu_kv_cache_stream
    if stream is None:
        stream = torch.cuda.Stream()
        self.cpu_kv_cache_stream = stream
    return stream
def get_cpu_kv_cache_offload_stream(self) -> torch.cuda.Stream:
    """Return the stream stored in `cpu_kv_cache_offload_stream`, creating it on first use.

    A new `torch.cuda.Stream()` is constructed only when the attribute is
    still None; afterwards the same object is returned on every call.
    """
    stream = self.cpu_kv_cache_offload_stream
    if stream is None:
        stream = torch.cuda.Stream()
        self.cpu_kv_cache_offload_stream = stream
    return stream

def get_cpu_kv_cache_load_stream(self) -> torch.cuda.Stream:
    """Return the stream stored in `cpu_kv_cache_load_stream`, creating it on first use.

    A new `torch.cuda.Stream()` is constructed only when the attribute is
    still None; afterwards the same object is returned on every call.
    """
    stream = self.cpu_kv_cache_load_stream
    if stream is None:
        stream = torch.cuda.Stream()
        self.cpu_kv_cache_load_stream = stream
    return stream

def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]:
req_objs = []
Expand Down Expand Up @@ -548,7 +554,14 @@ def __init__(

# 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache
# 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态
self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED
if self.args.enable_cpu_cache:
self.cpu_cache_load_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED
self.cpu_cache_offload_task_status: "InferReq._CpuCacheTaskStatus" = (
InferReq._CpuCacheTaskStatus.NOT_STARTED
)
else:
self.cpu_cache_load_task_status = None
self.cpu_cache_offload_task_status = None

# mtp_step 用来记录一个请求 draft模型每步需要生成的token数量
# 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量
Expand Down
24 changes: 23 additions & 1 deletion lightllm/server/router/model_infer/mode_backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,24 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]:
"""
return [g_infer_context.requests_mapping[request_id] for request_id in req_ids]

def _filter_cpu_cache_task_not_ready_req_ids(self, req_ids: List[int]) -> List[int]:
    """
    Return the request ids that are not held back by a cpu cache task.

    When `self.args.enable_cpu_cache` is falsy, `req_ids` is returned
    unchanged. Otherwise a request id is kept only if its cpu cache load
    task status reports finished and its offload task status reports
    either not-started or finished; every other id is dropped from the
    result.
    """
    if not self.args.enable_cpu_cache:
        return req_ids

    passed_ids = []
    for rid in req_ids:
        infer_req = g_infer_context.requests_mapping[rid]
        # A request whose load task has not finished is never ready.
        if not infer_req.cpu_cache_load_task_status.is_finished():
            continue
        offload_status = infer_req.cpu_cache_offload_task_status
        # Offload must be either untouched or completed, i.e. not in flight.
        if offload_status.is_not_started() or offload_status.is_finished():
            passed_ids.append(rid)

    return passed_ids

def _timer_merge_radix_tree(self):
self._radix_tree_merge_counter += 1
if (
Expand Down Expand Up @@ -569,11 +587,15 @@ def _get_classed_reqs(
self._timer_merge_radix_tree()

if self.args.enable_cpu_cache and len(g_infer_context.infer_req_ids) > 0:
self.multi_level_cache_module.update_cpu_cache_task_states()
self.multi_level_cache_module.update_cpu_cache_offload_task_states()
self.multi_level_cache_module.update_cpu_cache_load_task_states()

if req_ids is None:
req_ids = g_infer_context.infer_req_ids

# Hold back requests that still have a cpu cache task in progress; only the remaining ids continue.
req_ids = self._filter_cpu_cache_task_not_ready_req_ids(req_ids)

if len(req_ids) == 0:
return [], []

Expand Down
Loading
Loading