From 32103531a5ecca18971bce4bcb13d05420a183a2 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 25 May 2026 09:39:49 +0000 Subject: [PATCH 01/10] update cpu cache load stream async --- .../server/router/model_infer/infer_batch.py | 16 +- .../model_infer/mode_backend/base_backend.py | 3 +- .../mode_backend/multi_level_kv_cache.py | 235 ++++++++++++------ 3 files changed, 174 insertions(+), 80 deletions(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index f0ec69b2c..9d6d82102 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -484,7 +484,7 @@ def has_constraint_setting(self) -> bool: class InferReq: - class _CpuCacheTaskStatus(enum.Enum): + class _CpuCacheOffloadTaskStatus(enum.Enum): NOT_STARTED = 0 RUNNING = 1 FINISHED = 2 @@ -498,6 +498,9 @@ def is_running(self): def is_finished(self): return self == self.FINISHED + class _CpuCacheLoadTaskStatus(_CpuCacheOffloadTaskStatus): + pass + def __init__( self, req_id: int, @@ -548,7 +551,16 @@ def __init__( # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 - self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED + if self.args.enable_cpu_cache: + self.cpu_cache_load_task_status: "InferReq._CpuCacheLoadTaskStatus" = ( + InferReq._CpuCacheLoadTaskStatus.NOT_STARTED + ) + self.cpu_cache_offload_task_status: "InferReq._CpuCacheOffloadTaskStatus" = ( + InferReq._CpuCacheOffloadTaskStatus.NOT_STARTED + ) + else: + self.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.FINISHED + self.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED # mtp_step 用来记录一个请求 draft模型每步需要生成的token数量 # 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量 diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index e47717747..c9b3c3405 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -569,7 +569,8 @@ def _get_classed_reqs( self._timer_merge_radix_tree() if self.args.enable_cpu_cache and len(g_infer_context.infer_req_ids) > 0: - self.multi_level_cache_module.update_cpu_cache_task_states() + self.multi_level_cache_module.update_cpu_cache_offload_task_states() + self.multi_level_cache_module.update_cpu_cache_load_task_states() if req_ids is None: req_ids = g_infer_context.infer_req_ids diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index d0025a03c..9eb2d4df9 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -35,7 +35,8 @@ def __init__(self, backend): self.page_index_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.int32, device="cuda") self.page_ready_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.bool, device="cuda") - self.cpu_cache_handle_queue: Deque[TransTask] = deque() + self.cpu_cache_load_task_handle_queue: Deque[LoadTransTask] = deque() + self.cpu_cache_offload_task_handle_queue: Deque[OffloadTransTask] = deque() self.cpu_cache_client = CpuKvCacheClient(only_create_meta_data=False, init_shm_data=False) @lru_cache() @@ -60,7 +61,7 @@ def need_sync_compute_stream(self) -> bool: def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): idle_token_num = g_infer_context.get_can_alloc_token_num() - all_page_list = [] + need_free_page_list = [] is_master_in_dp = self.backend.is_master_in_dp for req in reqs: page_list = req.shm_req.cpu_cache_match_page_indexes.get_all() @@ -81,74 +82,100 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): need_token_num = match_tokens - req.cur_kv_len # 多匹配了一定数量的token同时请求长度大于一定的长度,才进行复制操作,不然操作效率不高,代价过高 - if need_token_num >= 128 and req.shm_req.input_len >= 256: - if need_token_num <= idle_token_num: - if self.backend.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num) - - # 计算需要加载的页面(只加载未匹配的部分) - ready_page_num = bisect.bisect_right(page_len_list, req.cur_kv_len) - assert ready_page_num <= len(page_list) - need_pages = page_list[ready_page_num:] # 只取需要的页面 - - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num) - - if self.need_sync_compute_stream(): - # TODO fa3 现在必须使用同步模式, 未来需要移除 - torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream()) - # g_infer_context.get_overlap_stream().synchronize() - - mem_manager = self.backend.model.mem_manager - req_manager = self.backend.model.req_manager - - mem_indexes_cuda = mem_indexes.cuda(non_blocking=True) - page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda( - non_blocking=True - ) - # 因为在支持 linear att 以后,所有的页面加载必须要按照 page页面的整数倍来做, - # 不然可能导致页面数据不完整,导致无法从kv中恢复完整的 linear att状态,所以 - # 这里需要进行pad操作,使操作的页面是完整的。 - _start = page_len_start_list[ready_page_num] - - _end = req.cur_kv_len - assert 0 <= _start <= _end, f"invalid pad range [{_start}, {_end}]" - mem_indexes_cuda = torch.cat( - [req_manager.req_to_token_indexs[req.req_idx, _start:_end], mem_indexes_cuda] - ) - - assert ( - len(mem_indexes_cuda) == page_len_list[len(page_list) - 1] - page_len_start_list[ready_page_num] - ) - - # 更新 req 状态。 - idle_token_num -= need_token_num - g_infer_context.req_manager.req_to_token_indexs[ - req.req_idx, req.cur_kv_len : (req.cur_kv_len + need_token_num) - ] = mem_indexes - req.cur_kv_len = req.cur_kv_len + need_token_num - - mem_manager.operator.load_cpu_cache_to_gpu( - mem_indexes=mem_indexes_cuda, - page_indexes=page_indexes_cuda, - cpu_cache_client=self.cpu_cache_client, - req=req, - ) - - torch.cuda.current_stream().synchronize() + ok_to_start_load_task = ( + need_token_num >= 128 and req.shm_req.input_len >= 256 and need_token_num <= idle_token_num + ) - if self.backend.is_master_in_dp: - req.shm_req.shm_cur_kv_len = req.cur_kv_len + if not ok_to_start_load_task: + req.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.FINISHED + need_free_page_list.extend(page_list) + continue + else: + assert req.cpu_cache_load_task_status.is_not_started() + trans_task = self._start_kv_cache_load_task( + req=req, + need_token_num=need_token_num, + page_list=page_list, + page_len_list=page_len_list, + page_len_start_list=page_len_start_list, + cpu_kv_cache_stream=g_infer_context.get_cpu_kv_cache_stream(), + ) + assert trans_task is not None + req.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.RUNNING - all_page_list.extend(page_list) + self.cpu_cache_load_task_handle_queue.append(trans_task) + idle_token_num -= need_token_num dist.barrier(group=self.init_sync_group) if self.backend.is_master_in_dp: self.cpu_cache_client.lock.acquire_sleep1ms() - self.cpu_cache_client.deref_pages(page_list=all_page_list) + self.cpu_cache_client.deref_pages(page_list=need_free_page_list) self.cpu_cache_client.lock.release() return + def _start_kv_cache_load_task( + self, + req: InferReq, + need_token_num: int, + page_list: List[int], + page_len_list: List[int], + page_len_start_list: List[int], + cpu_kv_cache_stream: torch.cuda.Stream, + ) -> Optional["LoadTransTask"]: + if self.backend.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num) + + # 计算需要加载的页面(只加载未匹配的部分) + ready_page_num = bisect.bisect_right(page_len_list, req.cur_kv_len) + assert ready_page_num <= len(page_list) + need_pages = page_list[ready_page_num:] # 只取需要的页面 + + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num) + + if self.need_sync_compute_stream(): + # TODO fa3 现在必须使用同步模式, 未来需要移除 + torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream()) + # g_infer_context.get_overlap_stream().synchronize() + + mem_manager = self.backend.model.mem_manager + req_manager = self.backend.model.req_manager + + mem_indexes_cuda = mem_indexes.cuda(non_blocking=True) + page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True) + # 因为在支持 linear att 以后,所有的页面加载必须要按照 page页面的整数倍来做, + # 不然可能导致页面数据不完整,导致无法从kv中恢复完整的 linear att状态,所以 + # 这里需要进行pad操作,使操作的页面是完整的。 + _start = page_len_start_list[ready_page_num] + + _end = req.cur_kv_len + assert 0 <= _start <= _end, f"invalid pad range [{_start}, {_end}]" + mem_indexes_cuda = torch.cat([req_manager.req_to_token_indexs[req.req_idx, _start:_end], mem_indexes_cuda]) + + assert len(mem_indexes_cuda) == page_len_list[len(page_list) - 1] - page_len_start_list[ready_page_num] + + # 这里需要先更新 cur_kv_len 再进行 load_cpu_cache_to_gpu 操作, + # 因为 load_cpu_cache_to_gpu 操作会使用到 cur_kv_len 的值,主要是linear att 会用到。 + req.cur_kv_len = req.cur_kv_len + need_token_num + + mem_manager.operator.load_cpu_cache_to_gpu( + mem_indexes=mem_indexes_cuda, + page_indexes=page_indexes_cuda, + cpu_cache_client=self.cpu_cache_client, + req=req, + ) + + sync_event = torch.cuda.Event() + sync_event.record() + + trans_task = LoadTransTask( + req_obj=req, + page_list=page_list, + mem_indexes=mem_indexes, + sync_event=sync_event, + ) + return trans_task + def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> List[InferReq]: """ 将满足cpu kv cache 卸载条件的请求进行处理, 并返回真的满足退出条件的请求list。 @@ -176,15 +203,15 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L continue # 如果请求已经完成了 cpu cache 的任务,则满足了退出条件 - if req.cpu_cache_task_status.is_finished(): + if req.cpu_cache_offload_task_status.is_finished(): true_finished_reqs.append(req) continue # 如果请求已经发起过卸载任务且正在卸载过程中,则在当前轮不进行处理 - if req.cpu_cache_task_status.is_running(): + if req.cpu_cache_offload_task_status.is_running(): continue - assert req.cpu_cache_task_status.is_not_started() + assert req.cpu_cache_offload_task_status.is_not_started() if self.need_sync_compute_stream(): # TODO fa3 现在必须使用同步模式, 未来需要移除, 必须等待 overlap stream 上的计算任务完成,不然会崩溃 @@ -195,7 +222,7 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L # 根据是否成功创建了卸载任务,决定是否将请求加入到处理队列中 if trans_task is not None: - self.cpu_cache_handle_queue.append(trans_task) + self.cpu_cache_offload_task_handle_queue.append(trans_task) else: true_finished_reqs.append(req) @@ -207,7 +234,7 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L def _start_kv_cache_offload_task( self, req: InferReq, cpu_kv_cache_stream: torch.cuda.Stream - ) -> Optional["TransTask"]: + ) -> Optional["OffloadTransTask"]: with torch.cuda.stream(cpu_kv_cache_stream): # 综合考虑后只对prompt做缓存管理,不包含decode内容,这里与radix cache不一致 token_hash_list = req.shm_req.token_hash_list.get_all() @@ -226,7 +253,7 @@ def _start_kv_cache_offload_task( if move_block_size == 0: dist.broadcast_object_list([0], group=self.gloo_group, group_src=0) - req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED return None try: @@ -241,7 +268,7 @@ def _start_kv_cache_offload_task( item_size = len(page_list) if item_size == 0: dist.broadcast_object_list([0], group=self.gloo_group, group_src=0) - req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED return None broadcast_data = {"item_size": item_size, "page_list": page_list, "ready_list": ready_list} @@ -250,7 +277,7 @@ def _start_kv_cache_offload_task( recv_list = [None] dist.broadcast_object_list(recv_list, group=self.gloo_group, group_src=0) if isinstance(recv_list[0], int) and recv_list[0] == 0: - req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED return None broadcast_data = recv_list[0] item_size = broadcast_data["item_size"] @@ -285,8 +312,8 @@ def _start_kv_cache_offload_task( sync_event = torch.cuda.Event() sync_event.record() - req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING - trans_task = TransTask( + req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.RUNNING + trans_task = OffloadTransTask( move_token_num=move_token_num, page_indexes=page_indexes, page_readies=page_readies, @@ -314,15 +341,15 @@ def _handle_linear_att_last_page(self, req: InferReq, move_block_size: int, page return move_block_size - 1 return move_block_size - def update_cpu_cache_task_states(self): + def update_cpu_cache_offload_task_states(self): if self.backend.is_master_in_dp: trans_ok_tasks = [] - while len(self.cpu_cache_handle_queue) != 0: - task: TransTask = self.cpu_cache_handle_queue.popleft() + while len(self.cpu_cache_offload_task_handle_queue) != 0: + task: OffloadTransTask = self.cpu_cache_offload_task_handle_queue.popleft() if task.sync_event.query(): trans_ok_tasks.append(task) else: - self.cpu_cache_handle_queue.appendleft(task) + self.cpu_cache_offload_task_handle_queue.appendleft(task) break item_size = len(trans_ok_tasks) dist.broadcast_object_list([item_size], group=self.filter_group, group_src=0) @@ -330,7 +357,9 @@ def update_cpu_cache_task_states(self): recv_list = [None] dist.broadcast_object_list(recv_list, group=self.filter_group, group_src=0) item_size = recv_list[0] - trans_ok_tasks: List[TransTask] = [self.cpu_cache_handle_queue.popleft() for _ in range(item_size)] + trans_ok_tasks: List[OffloadTransTask] = [ + self.cpu_cache_offload_task_handle_queue.popleft() for _ in range(item_size) + ] if item_size > 0: page_array_list = [task.page_indexes.tolist() for task in trans_ok_tasks] @@ -347,14 +376,66 @@ def update_cpu_cache_task_states(self): ) self.cpu_cache_client.lock.release() for task in trans_ok_tasks: - task.req_obj.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + task.req_obj.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED + return + + def update_cpu_cache_load_task_states(self): + if self.backend.is_master_in_dp: + trans_ok_tasks = [] + while len(self.cpu_cache_load_task_handle_queue) != 0: + task: LoadTransTask = self.cpu_cache_load_task_handle_queue.popleft() + if task.sync_event.query(): + trans_ok_tasks.append(task) + else: + self.cpu_cache_load_task_handle_queue.appendleft(task) + break + item_size = len(trans_ok_tasks) + dist.broadcast_object_list([item_size], group=self.filter_group, group_src=0) + else: + recv_list = [None] + dist.broadcast_object_list(recv_list, group=self.filter_group, group_src=0) + item_size = recv_list[0] + trans_ok_tasks: List[LoadTransTask] = [ + self.cpu_cache_load_task_handle_queue.popleft() for _ in range(item_size) + ] + + if item_size > 0: + need_free_page_list = [] + for task in trans_ok_tasks: + need_free_page_list.extend(task.page_list) + # 更新 req 状态。 + req = task.req_obj + assert ( + len(task.mem_indexes) <= req.cur_kv_len + ), f"invalid load task mem_indexes length [{len(task.mem_indexes)}] <= cur_kv_len [{req.cur_kv_len}]" + g_infer_context.req_manager.req_to_token_indexs[ + req.req_idx, (req.cur_kv_len - len(task.mem_indexes)) : req.cur_kv_len + ] = task.mem_indexes + + if self.backend.is_master_in_dp: + req.shm_req.shm_cur_kv_len = req.cur_kv_len + + task.req_obj.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.FINISHED + + if self.backend.is_master_in_dp and need_free_page_list: + self.cpu_cache_client.lock.acquire_sleep1ms() + self.cpu_cache_client.deref_pages(page_list=need_free_page_list) + self.cpu_cache_client.lock.release() return @dataclasses.dataclass -class TransTask: +class OffloadTransTask: move_token_num: int page_indexes: torch.Tensor page_readies: torch.Tensor req_obj: InferReq sync_event: torch.cuda.Event + + +@dataclasses.dataclass +class LoadTransTask: + req_obj: InferReq + page_list: List[int] + mem_indexes: torch.Tensor + sync_event: torch.cuda.Event From 25931fb683df7397ceb4d356cccffffc2a006b5f Mon Sep 17 00:00:00 2001 From: wzj Date: Mon, 25 May 2026 13:47:21 +0000 Subject: [PATCH 02/10] fix --- .../server/router/model_infer/infer_batch.py | 16 +++++++++---- .../mode_backend/multi_level_kv_cache.py | 24 +++++++++++-------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 9d6d82102..13fbda49e 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -39,7 +39,8 @@ class InferenceContext: cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 - cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream + cpu_kv_cache_load_stream: torch.cuda.Stream = None # 用于将 kv cache 从 cpu 加载到 gpu 的 stream + cpu_kv_cache_offload_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream is_linear_att_mixed_model: bool = False # 标记模型是否是full att 混合 linear att 的混合模型。 def register( @@ -77,10 +78,15 @@ def get_overlap_stream(self) -> torch.cuda.Stream: self.overlap_stream = torch.cuda.Stream() return self.overlap_stream - def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream: - if self.cpu_kv_cache_stream is None: - self.cpu_kv_cache_stream = torch.cuda.Stream() - return self.cpu_kv_cache_stream + def get_cpu_kv_cache_offload_stream(self) -> torch.cuda.Stream: + if self.cpu_kv_cache_offload_stream is None: + self.cpu_kv_cache_offload_stream = torch.cuda.Stream() + return self.cpu_kv_cache_offload_stream + + def get_cpu_kv_cache_load_stream(self) -> torch.cuda.Stream: + if self.cpu_kv_cache_load_stream is None: + self.cpu_kv_cache_load_stream = torch.cuda.Stream() + return self.cpu_kv_cache_load_stream def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]: req_objs = [] diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index 9eb2d4df9..3b40929c0 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -98,7 +98,7 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): page_list=page_list, page_len_list=page_len_list, page_len_start_list=page_len_start_list, - cpu_kv_cache_stream=g_infer_context.get_cpu_kv_cache_stream(), + cpu_kv_cache_stream=g_infer_context.get_cpu_kv_cache_load_stream(), ) assert trans_task is not None req.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.RUNNING @@ -180,25 +180,29 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L """ 将满足cpu kv cache 卸载条件的请求进行处理, 并返回真的满足退出条件的请求list。 """ + # 过滤不适合进行 kv 卸载到 cpu cache 的请求。 + if g_infer_context.is_linear_att_mixed_model: + offload_limit_size = self.args.linear_att_hash_page_size + else: + offload_limit_size = self.args.cpu_cache_token_page_size + # 如果开启了cpu cache,将达到finished状态的请求开启将gpu kv cache 卸载到 cpu cache中的操作。 # 当 kv cache 卸载完成后,才会进行请求的真实退出操作。 true_finished_reqs = [] - cpu_stream = g_infer_context.get_cpu_kv_cache_stream() + offload_stream = g_infer_context.get_cpu_kv_cache_offload_stream() for req in finished_reqs: # 只有 group_req_id 和 request_id 相同的请求才会被卸载到 cpu cache 中。 # 这个限制是为了兼容 diverse 模式下的请求处理, 只有主请求才 offload kv 到 cpu # cache 中 if req.shm_req.group_req_id != req.shm_req.request_id: + assert req.cpu_cache_offload_task_status.is_not_started() + req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED true_finished_reqs.append(req) continue - # 过滤不适合进行 kv 卸载到 cpu cache 的请求。 - if g_infer_context.is_linear_att_mixed_model: - offload_limit_size = self.args.linear_att_hash_page_size - else: - offload_limit_size = self.args.cpu_cache_token_page_size - if req.cur_kv_len < offload_limit_size or req.shm_req.input_len <= offload_limit_size: + assert req.cpu_cache_offload_task_status.is_not_started() + req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED true_finished_reqs.append(req) continue @@ -218,7 +222,7 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L g_infer_context.get_overlap_stream().synchronize() # 发起将请求的 kv cache 卸载到 cpu cache 中的任务 - trans_task = self._start_kv_cache_offload_task(req=req, cpu_kv_cache_stream=cpu_stream) + trans_task = self._start_kv_cache_offload_task(req=req, cpu_kv_cache_stream=offload_stream) # 根据是否成功创建了卸载任务,决定是否将请求加入到处理队列中 if trans_task is not None: @@ -228,7 +232,7 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L if self.need_sync_compute_stream(): # TODO fa3 现在必须使用同步模式, 未来需要移除 - cpu_stream.synchronize() + offload_stream.synchronize() return true_finished_reqs From 666d99262fbef7e9ade4cc496d34ce43b1b85eef Mon Sep 17 00:00:00 2001 From: wzj Date: Mon, 25 May 2026 13:48:47 +0000 Subject: [PATCH 03/10] fix --- lightllm/server/router/model_infer/infer_batch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 13fbda49e..07404ffce 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -565,8 +565,8 @@ def __init__( InferReq._CpuCacheOffloadTaskStatus.NOT_STARTED ) else: - self.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.FINISHED - self.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED + self.cpu_cache_load_task_status = None + self.cpu_cache_offload_task_status = None # mtp_step 用来记录一个请求 draft模型每步需要生成的token数量 # 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量 From b1d9d83644faee31f472793c5c38e4be7c6144d2 Mon Sep 17 00:00:00 2001 From: wzj Date: Mon, 25 May 2026 14:14:21 +0000 Subject: [PATCH 04/10] fix --- .../mode_backend/multi_level_kv_cache.py | 51 +++++++------------ 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index 3b40929c0..f1d458224 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -4,7 +4,7 @@ import dataclasses import bisect from functools import lru_cache -from typing import Optional, List, Deque +from typing import Optional, List, Deque, Union from collections import deque from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient from lightllm.utils.config_utils import is_linear_att_mixed_model @@ -219,7 +219,7 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L if self.need_sync_compute_stream(): # TODO fa3 现在必须使用同步模式, 未来需要移除, 必须等待 overlap stream 上的计算任务完成,不然会崩溃 - g_infer_context.get_overlap_stream().synchronize() + offload_stream.wait_stream(g_infer_context.get_overlap_stream()) # 发起将请求的 kv cache 卸载到 cpu cache 中的任务 trans_task = self._start_kv_cache_offload_task(req=req, cpu_kv_cache_stream=offload_stream) @@ -345,15 +345,15 @@ def _handle_linear_att_last_page(self, req: InferReq, move_block_size: int, page return move_block_size - 1 return move_block_size - def update_cpu_cache_offload_task_states(self): + def _get_trans_ok_tasks_from_queue(self, task_queue: Deque) -> List[Union["OffloadTransTask", "LoadTransTask"]]: if self.backend.is_master_in_dp: trans_ok_tasks = [] - while len(self.cpu_cache_offload_task_handle_queue) != 0: - task: OffloadTransTask = self.cpu_cache_offload_task_handle_queue.popleft() + while len(task_queue) != 0: + task: Union[OffloadTransTask, LoadTransTask] = task_queue.popleft() if task.sync_event.query(): trans_ok_tasks.append(task) else: - self.cpu_cache_offload_task_handle_queue.appendleft(task) + task_queue.appendleft(task) break item_size = len(trans_ok_tasks) dist.broadcast_object_list([item_size], group=self.filter_group, group_src=0) @@ -361,11 +361,14 @@ def update_cpu_cache_offload_task_states(self): recv_list = [None] dist.broadcast_object_list(recv_list, group=self.filter_group, group_src=0) item_size = recv_list[0] - trans_ok_tasks: List[OffloadTransTask] = [ - self.cpu_cache_offload_task_handle_queue.popleft() for _ in range(item_size) + trans_ok_tasks: List[Union[OffloadTransTask, LoadTransTask]] = [ + task_queue.popleft() for _ in range(item_size) ] + return trans_ok_tasks - if item_size > 0: + def update_cpu_cache_offload_task_states(self): + trans_ok_tasks = self._get_trans_ok_tasks_from_queue(self.cpu_cache_offload_task_handle_queue) + if len(trans_ok_tasks) > 0: page_array_list = [task.page_indexes.tolist() for task in trans_ok_tasks] move_token_nums = [task.move_token_num for task in trans_ok_tasks] if self.backend.is_master_in_dp: @@ -384,26 +387,8 @@ def update_cpu_cache_offload_task_states(self): return def update_cpu_cache_load_task_states(self): - if self.backend.is_master_in_dp: - trans_ok_tasks = [] - while len(self.cpu_cache_load_task_handle_queue) != 0: - task: LoadTransTask = self.cpu_cache_load_task_handle_queue.popleft() - if task.sync_event.query(): - trans_ok_tasks.append(task) - else: - self.cpu_cache_load_task_handle_queue.appendleft(task) - break - item_size = len(trans_ok_tasks) - dist.broadcast_object_list([item_size], group=self.filter_group, group_src=0) - else: - recv_list = [None] - dist.broadcast_object_list(recv_list, group=self.filter_group, group_src=0) - item_size = recv_list[0] - trans_ok_tasks: List[LoadTransTask] = [ - self.cpu_cache_load_task_handle_queue.popleft() for _ in range(item_size) - ] - - if item_size > 0: + trans_ok_tasks = self._get_trans_ok_tasks_from_queue(self.cpu_cache_load_task_handle_queue) + if len(trans_ok_tasks) > 0: need_free_page_list = [] for task in trans_ok_tasks: need_free_page_list.extend(task.page_list) @@ -421,10 +406,10 @@ def update_cpu_cache_load_task_states(self): task.req_obj.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.FINISHED - if self.backend.is_master_in_dp and need_free_page_list: - self.cpu_cache_client.lock.acquire_sleep1ms() - self.cpu_cache_client.deref_pages(page_list=need_free_page_list) - self.cpu_cache_client.lock.release() + if self.backend.is_master_in_dp and need_free_page_list: + self.cpu_cache_client.lock.acquire_sleep1ms() + self.cpu_cache_client.deref_pages(page_list=need_free_page_list) + self.cpu_cache_client.lock.release() return From effa8e958fae1d372628f8be422acca520497db5 Mon Sep 17 00:00:00 2001 From: wzj Date: Mon, 25 May 2026 14:25:57 +0000 Subject: [PATCH 05/10] fix --- .../mode_backend/multi_level_kv_cache.py | 101 +++++++++--------- 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index f1d458224..34317ff95 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -60,6 +60,7 @@ def need_sync_compute_stream(self) -> bool: return False def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): + load_stream = g_infer_context.get_cpu_kv_cache_load_stream() idle_token_num = g_infer_context.get_can_alloc_token_num() need_free_page_list = [] is_master_in_dp = self.backend.is_master_in_dp @@ -92,20 +93,27 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): continue else: assert req.cpu_cache_load_task_status.is_not_started() + + if self.need_sync_compute_stream(): + # TODO fa3 现在必须使用同步模式, 未来需要移除 + load_stream.wait_stream(g_infer_context.get_overlap_stream()) + trans_task = self._start_kv_cache_load_task( req=req, need_token_num=need_token_num, page_list=page_list, page_len_list=page_len_list, page_len_start_list=page_len_start_list, - cpu_kv_cache_stream=g_infer_context.get_cpu_kv_cache_load_stream(), + load_stream=load_stream, ) assert trans_task is not None - req.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.RUNNING self.cpu_cache_load_task_handle_queue.append(trans_task) idle_token_num -= need_token_num + if self.need_sync_compute_stream(): + load_stream.synchronize() + dist.barrier(group=self.init_sync_group) if self.backend.is_master_in_dp: @@ -121,7 +129,7 @@ def _start_kv_cache_load_task( page_list: List[int], page_len_list: List[int], page_len_start_list: List[int], - cpu_kv_cache_stream: torch.cuda.Stream, + load_stream: torch.cuda.Stream, ) -> Optional["LoadTransTask"]: if self.backend.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num) @@ -133,48 +141,45 @@ def _start_kv_cache_load_task( mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num) - if self.need_sync_compute_stream(): - # TODO fa3 现在必须使用同步模式, 未来需要移除 - torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream()) - # g_infer_context.get_overlap_stream().synchronize() - - mem_manager = self.backend.model.mem_manager - req_manager = self.backend.model.req_manager - - mem_indexes_cuda = mem_indexes.cuda(non_blocking=True) - page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True) - # 因为在支持 linear att 以后,所有的页面加载必须要按照 page页面的整数倍来做, - # 不然可能导致页面数据不完整,导致无法从kv中恢复完整的 linear att状态,所以 - # 这里需要进行pad操作,使操作的页面是完整的。 - _start = page_len_start_list[ready_page_num] - - _end = req.cur_kv_len - assert 0 <= _start <= _end, f"invalid pad range [{_start}, {_end}]" - mem_indexes_cuda = torch.cat([req_manager.req_to_token_indexs[req.req_idx, _start:_end], mem_indexes_cuda]) - - assert len(mem_indexes_cuda) == page_len_list[len(page_list) - 1] - page_len_start_list[ready_page_num] - - # 这里需要先更新 cur_kv_len 再进行 load_cpu_cache_to_gpu 操作, - # 因为 load_cpu_cache_to_gpu 操作会使用到 cur_kv_len 的值,主要是linear att 会用到。 - req.cur_kv_len = req.cur_kv_len + need_token_num - - mem_manager.operator.load_cpu_cache_to_gpu( - mem_indexes=mem_indexes_cuda, - page_indexes=page_indexes_cuda, - cpu_cache_client=self.cpu_cache_client, - req=req, - ) - - sync_event = torch.cuda.Event() - sync_event.record() - - trans_task = LoadTransTask( - req_obj=req, - page_list=page_list, - mem_indexes=mem_indexes, - sync_event=sync_event, - ) - return trans_task + with torch.cuda.stream(load_stream): + mem_manager = self.backend.model.mem_manager + req_manager = self.backend.model.req_manager + + mem_indexes_cuda = mem_indexes.cuda(non_blocking=True) + page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True) + # 因为在支持 linear att 以后,所有的页面加载必须要按照 page页面的整数倍来做, + # 不然可能导致页面数据不完整,导致无法从kv中恢复完整的 linear att状态,所以 + # 这里需要进行pad操作,使操作的页面是完整的。 + _start = page_len_start_list[ready_page_num] + + _end = req.cur_kv_len + assert 0 <= _start <= _end, f"invalid pad range [{_start}, {_end}]" + mem_indexes_cuda = torch.cat([req_manager.req_to_token_indexs[req.req_idx, _start:_end], mem_indexes_cuda]) + + assert len(mem_indexes_cuda) == page_len_list[len(page_list) - 1] - page_len_start_list[ready_page_num] + + # 这里需要先更新 cur_kv_len 再进行 load_cpu_cache_to_gpu 操作, + # 因为 load_cpu_cache_to_gpu 操作会使用到 cur_kv_len 的值,主要是linear att 会用到。 + req.cur_kv_len = req.cur_kv_len + need_token_num + + mem_manager.operator.load_cpu_cache_to_gpu( + mem_indexes=mem_indexes_cuda, + page_indexes=page_indexes_cuda, + cpu_cache_client=self.cpu_cache_client, + req=req, + ) + + sync_event = torch.cuda.Event() + sync_event.record() + + trans_task = LoadTransTask( + req_obj=req, + page_list=page_list, + mem_indexes=mem_indexes, + sync_event=sync_event, + ) + req.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.RUNNING + return trans_task def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> List[InferReq]: """ @@ -222,7 +227,7 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L offload_stream.wait_stream(g_infer_context.get_overlap_stream()) # 发起将请求的 kv cache 卸载到 cpu cache 中的任务 - trans_task = self._start_kv_cache_offload_task(req=req, cpu_kv_cache_stream=offload_stream) + trans_task = self._start_kv_cache_offload_task(req=req, offload_stream=offload_stream) # 根据是否成功创建了卸载任务,决定是否将请求加入到处理队列中 if trans_task is not None: @@ -237,9 +242,9 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L return true_finished_reqs def _start_kv_cache_offload_task( - self, req: InferReq, cpu_kv_cache_stream: torch.cuda.Stream + self, req: InferReq, offload_stream: torch.cuda.Stream ) -> Optional["OffloadTransTask"]: - with torch.cuda.stream(cpu_kv_cache_stream): + with torch.cuda.stream(offload_stream): # 综合考虑后只对prompt做缓存管理,不包含decode内容,这里与radix cache不一致 token_hash_list = req.shm_req.token_hash_list.get_all() page_len_list = req.shm_req.token_hash_page_len_list.get_all() From b3bc2b4e0cfab7028a065830c067e22ea760a20f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 26 May 2026 01:27:13 +0000 Subject: [PATCH 06/10] fix --- .../model_infer/mode_backend/multi_level_kv_cache.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index 34317ff95..617e16054 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -61,6 +61,10 @@ def need_sync_compute_stream(self) -> bool: def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): load_stream = g_infer_context.get_cpu_kv_cache_load_stream() + # 当前流实际是用于处理req 初始化,radix cache,req manager 信息的流,这里需要等待当前流完成, + # 因为当前流中包含很多对 req_manager.req_to_token_indexs 的操作,而后续的操作中,需要从 + # 中读取相关信息,如果当前流没有完成,则可能导致后续操作中,读取到错误的信息。 + load_stream.wait_stream(torch.cuda.current_stream()) idle_token_num = g_infer_context.get_can_alloc_token_num() need_free_page_list = [] is_master_in_dp = self.backend.is_master_in_dp @@ -195,6 +199,11 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L # 当 kv cache 卸载完成后,才会进行请求的真实退出操作。 true_finished_reqs = [] offload_stream = g_infer_context.get_cpu_kv_cache_offload_stream() + # 当前流实际是用于处理req 初始化,radix cache,req manager 信息的流,这里需要等待当前流完成, + # 因为当前流中包含很多对 req_manager.req_to_token_indexs 的操作,而后续的操作中,需要从 + # 中读取相关信息,如果当前流没有完成,则可能导致后续操作中,读取到错误的信息。 + offload_stream.wait_stream(torch.cuda.current_stream()) + for req in finished_reqs: # 只有 group_req_id 和 request_id 相同的请求才会被卸载到 cpu cache 中。 # 这个限制是为了兼容 diverse 模式下的请求处理, 只有主请求才 offload kv 到 cpu From d39d3ec9e1292cdd8f341e312d2fb6e2352bad33 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 26 May 2026 01:32:29 +0000 Subject: [PATCH 07/10] fix --- .../model_infer/mode_backend/multi_level_kv_cache.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index 617e16054..f7551935e 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -7,11 +7,9 @@ from typing import Optional, List, Deque, Union from collections import deque from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient -from lightllm.utils.config_utils import is_linear_att_mixed_model from lightllm.utils.envs_utils import get_env_start_args from ..infer_batch import InferReq from lightllm.utils.dist_utils import create_new_group_for_current_dp -from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu, load_cpu_kv_to_gpu from lightllm.server.router.model_infer.infer_batch import g_infer_context from lightllm.utils.log_utils import init_logger @@ -31,6 +29,9 @@ def __init__(self, backend): self.offload_sync_group = create_new_group_for_current_dp("nccl") dist.barrier(group=self.offload_sync_group) self.offload_sync_tensor = torch.empty((1,), dtype=torch.int32, device="cuda") + self.load_sync_group = create_new_group_for_current_dp("nccl") + dist.barrier(group=self.load_sync_group) + self.load_sync_tensor = torch.empty((1,), dtype=torch.int32, device="cuda") self.page_index_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.int32, device="cuda") self.page_ready_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.bool, device="cuda") @@ -173,6 +174,10 @@ def _start_kv_cache_load_task( req=req, ) + if self.backend.dp_world_size > 1: + # 这里只是为了做一个同步,让sync_event 完成的时候,各个tp都必然已经完成了。 + dist.all_reduce(self.load_sync_tensor, op=dist.ReduceOp.MAX, group=self.load_sync_group) + sync_event = torch.cuda.Event() sync_event.record() From 8fc2f813512731bda0083abf0234778f1f26b6d1 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 26 May 2026 02:14:35 +0000 Subject: [PATCH 08/10] fix --- .../common/kv_cache_mem_manager/operator/base.py | 7 ++++++- .../kv_cache_mem_manager/operator/linear_att.py | 6 +++++- .../common/kv_cache_mem_manager/operator/normal.py | 2 ++ .../common/kv_cache_mem_manager/operator/quant.py | 2 ++ .../model_infer/mode_backend/multi_level_kv_cache.py | 12 +++--------- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/operator/base.py b/lightllm/common/kv_cache_mem_manager/operator/base.py index 682b8a5d6..dadef6829 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/base.py +++ b/lightllm/common/kv_cache_mem_manager/operator/base.py @@ -21,7 +21,12 @@ def copy_mem_to_mem(self, src_mem_index: torch.Tensor, dst_mem_index: torch.Tens # cpu cache 的相关操作接口 def load_cpu_cache_to_gpu( - self, mem_indexes: torch.Tensor, page_indexes: torch.Tensor, cpu_cache_client, req: "InferReq" + self, + move_token_num: int, + mem_indexes: torch.Tensor, + page_indexes: torch.Tensor, + cpu_cache_client, + req: "InferReq", ): raise NotImplementedError() diff --git a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py index 109e81322..4a9426fd8 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py +++ b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py @@ -26,11 +26,14 @@ def __init__(self, mem_manager): def load_cpu_cache_to_gpu( self, + move_token_num: int, mem_indexes: torch.Tensor, page_indexes: torch.Tensor, cpu_cache_client: "CpuKvCacheClient", req: "InferReq", ): + # mem_indexes 中包含pad的部分,所以真实需要的move_token_num 是小于 len(mem_indexes) 的。 + assert move_token_num <= len(mem_indexes) assert mem_indexes.is_cuda and page_indexes.is_cuda args = get_env_start_args() assert triton.cdiv(len(mem_indexes), args.cpu_cache_token_page_size) == len(page_indexes) @@ -41,7 +44,8 @@ def load_cpu_cache_to_gpu( mem_manager: Qwen3NextMemManager = self.mem_manager big_page_num = len(mem_indexes) // args.cpu_cache_token_page_size - max_kv_len = (req.cur_kv_len // args.cpu_cache_token_page_size) * args.cpu_cache_token_page_size + total_kv_len = req.cur_kv_len + move_token_num + max_kv_len = (total_kv_len // args.cpu_cache_token_page_size) * args.cpu_cache_token_page_size assert max_kv_len % args.cpu_cache_token_page_size == 0 big_page_buffer_ids_cpu = [] diff --git a/lightllm/common/kv_cache_mem_manager/operator/normal.py b/lightllm/common/kv_cache_mem_manager/operator/normal.py index 3c53ace07..9d41fc430 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/normal.py +++ b/lightllm/common/kv_cache_mem_manager/operator/normal.py @@ -26,11 +26,13 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: def load_cpu_cache_to_gpu( self, + move_token_num: int, mem_indexes: torch.Tensor, page_indexes: torch.Tensor, cpu_cache_client: "CpuKvCacheClient", req: "InferReq", ): + assert move_token_num <= len(mem_indexes) assert mem_indexes.is_cuda and page_indexes.is_cuda args = get_env_start_args() assert len(mem_indexes) % args.cpu_cache_token_page_size == 0 diff --git a/lightllm/common/kv_cache_mem_manager/operator/quant.py b/lightllm/common/kv_cache_mem_manager/operator/quant.py index a3a1c1d01..0cfe44901 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/quant.py +++ b/lightllm/common/kv_cache_mem_manager/operator/quant.py @@ -19,11 +19,13 @@ class QuantScaleMemOperator(BaseMemManagerOperator): def load_cpu_cache_to_gpu( self, + move_token_num: int, mem_indexes: torch.Tensor, page_indexes: torch.Tensor, cpu_cache_client: "CpuKvCacheClient", req: "InferReq", ): + assert move_token_num <= len(mem_indexes) assert mem_indexes.is_cuda and page_indexes.is_cuda args = get_env_start_args() assert len(mem_indexes) % args.cpu_cache_token_page_size == 0 diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index f7551935e..5617a8e21 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -163,11 +163,8 @@ def _start_kv_cache_load_task( assert len(mem_indexes_cuda) == page_len_list[len(page_list) - 1] - page_len_start_list[ready_page_num] - # 这里需要先更新 cur_kv_len 再进行 load_cpu_cache_to_gpu 操作, - # 因为 load_cpu_cache_to_gpu 操作会使用到 cur_kv_len 的值,主要是linear att 会用到。 - req.cur_kv_len = req.cur_kv_len + need_token_num - mem_manager.operator.load_cpu_cache_to_gpu( + move_token_num=need_token_num, mem_indexes=mem_indexes_cuda, page_indexes=page_indexes_cuda, cpu_cache_client=self.cpu_cache_client, @@ -413,13 +410,10 @@ def update_cpu_cache_load_task_states(self): need_free_page_list.extend(task.page_list) # 更新 req 状态。 req = task.req_obj - assert ( - len(task.mem_indexes) <= req.cur_kv_len - ), f"invalid load task mem_indexes length [{len(task.mem_indexes)}] <= cur_kv_len [{req.cur_kv_len}]" g_infer_context.req_manager.req_to_token_indexs[ - req.req_idx, (req.cur_kv_len - len(task.mem_indexes)) : req.cur_kv_len + req.req_idx, req.cur_kv_len : (req.cur_kv_len + len(task.mem_indexes)) ] = task.mem_indexes - + req.cur_kv_len += len(task.mem_indexes) if self.backend.is_master_in_dp: req.shm_req.shm_cur_kv_len = req.cur_kv_len From a22e1c719e15249bcebe78db70eba714ab6f1081 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 26 May 2026 02:28:33 +0000 Subject: [PATCH 09/10] fix --- .../model_infer/mode_backend/base_backend.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index c9b3c3405..09878b460 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -523,6 +523,24 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: """ return [g_infer_context.requests_mapping[request_id] for request_id in req_ids] + def _filter_cpu_cache_task_not_ready_req_ids(self, req_ids: List[int]) -> List[int]: + """ + 过滤出当前正在进行 cpu cache 操作的任务, 对齐进行拦截操作。 + """ + if not self.args.enable_cpu_cache: + return req_ids + + ready_req_ids = [] + for req_id in req_ids: + req: InferReq = g_infer_context.requests_mapping[req_id] + if req.cpu_cache_load_task_status.is_finished(): + if req.cpu_cache_offload_task_status.is_not_started(): + ready_req_ids.append(req_id) + elif req.cpu_cache_offload_task_status.is_finished(): + ready_req_ids.append(req_id) + + return ready_req_ids + def _timer_merge_radix_tree(self): self._radix_tree_merge_counter += 1 if ( @@ -575,6 +593,9 @@ def _get_classed_reqs( if req_ids is None: req_ids = g_infer_context.infer_req_ids + # 过滤出当前正在进行 cpu cache 操作的任务, 对齐进行拦截操作。 + req_ids = self._filter_cpu_cache_task_not_ready_req_ids(req_ids) + if len(req_ids) == 0: return [], [] From 7f06d6aab63cbce10a9210b5199d1f4a6625b09a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 26 May 2026 02:42:06 +0000 Subject: [PATCH 10/10] fix --- .../server/router/model_infer/infer_batch.py | 13 ++++-------- .../mode_backend/multi_level_kv_cache.py | 20 +++++++++---------- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 07404ffce..93e657ed1 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -490,7 +490,7 @@ def has_constraint_setting(self) -> bool: class InferReq: - class _CpuCacheOffloadTaskStatus(enum.Enum): + class _CpuCacheTaskStatus(enum.Enum): NOT_STARTED = 0 RUNNING = 1 FINISHED = 2 @@ -504,9 +504,6 @@ def is_running(self): def is_finished(self): return self == self.FINISHED - class _CpuCacheLoadTaskStatus(_CpuCacheOffloadTaskStatus): - pass - def __init__( self, req_id: int, @@ -558,11 +555,9 @@ def __init__( # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 if self.args.enable_cpu_cache: - self.cpu_cache_load_task_status: "InferReq._CpuCacheLoadTaskStatus" = ( - InferReq._CpuCacheLoadTaskStatus.NOT_STARTED - ) - self.cpu_cache_offload_task_status: "InferReq._CpuCacheOffloadTaskStatus" = ( - InferReq._CpuCacheOffloadTaskStatus.NOT_STARTED + self.cpu_cache_load_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED + self.cpu_cache_offload_task_status: "InferReq._CpuCacheTaskStatus" = ( + InferReq._CpuCacheTaskStatus.NOT_STARTED ) else: self.cpu_cache_load_task_status = None diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index 5617a8e21..3fdfebd97 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -93,7 +93,7 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): ) if not ok_to_start_load_task: - req.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.FINISHED + req.cpu_cache_load_task_status = InferReq._CpuCacheTaskStatus.FINISHED need_free_page_list.extend(page_list) continue else: @@ -184,7 +184,7 @@ def _start_kv_cache_load_task( mem_indexes=mem_indexes, sync_event=sync_event, ) - req.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.RUNNING + req.cpu_cache_load_task_status = InferReq._CpuCacheTaskStatus.RUNNING return trans_task def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> List[InferReq]: @@ -212,13 +212,13 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L # cache 中 if req.shm_req.group_req_id != req.shm_req.request_id: assert req.cpu_cache_offload_task_status.is_not_started() - req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED + req.cpu_cache_offload_task_status = InferReq._CpuCacheTaskStatus.FINISHED true_finished_reqs.append(req) continue if req.cur_kv_len < offload_limit_size or req.shm_req.input_len <= offload_limit_size: assert req.cpu_cache_offload_task_status.is_not_started() - req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED + req.cpu_cache_offload_task_status = InferReq._CpuCacheTaskStatus.FINISHED true_finished_reqs.append(req) continue @@ -273,7 +273,7 @@ def _start_kv_cache_offload_task( if move_block_size == 0: dist.broadcast_object_list([0], group=self.gloo_group, group_src=0) - req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED + req.cpu_cache_offload_task_status = InferReq._CpuCacheTaskStatus.FINISHED return None try: @@ -288,7 +288,7 @@ def _start_kv_cache_offload_task( item_size = len(page_list) if item_size == 0: dist.broadcast_object_list([0], group=self.gloo_group, group_src=0) - req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED + req.cpu_cache_offload_task_status = InferReq._CpuCacheTaskStatus.FINISHED return None broadcast_data = {"item_size": item_size, "page_list": page_list, "ready_list": ready_list} @@ -297,7 +297,7 @@ def _start_kv_cache_offload_task( recv_list = [None] dist.broadcast_object_list(recv_list, group=self.gloo_group, group_src=0) if isinstance(recv_list[0], int) and recv_list[0] == 0: - req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED + req.cpu_cache_offload_task_status = InferReq._CpuCacheTaskStatus.FINISHED return None broadcast_data = recv_list[0] item_size = broadcast_data["item_size"] @@ -332,7 +332,7 @@ def _start_kv_cache_offload_task( sync_event = torch.cuda.Event() sync_event.record() - req.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.RUNNING + req.cpu_cache_offload_task_status = InferReq._CpuCacheTaskStatus.RUNNING trans_task = OffloadTransTask( move_token_num=move_token_num, page_indexes=page_indexes, @@ -399,7 +399,7 @@ def update_cpu_cache_offload_task_states(self): ) self.cpu_cache_client.lock.release() for task in trans_ok_tasks: - task.req_obj.cpu_cache_offload_task_status = InferReq._CpuCacheOffloadTaskStatus.FINISHED + task.req_obj.cpu_cache_offload_task_status = InferReq._CpuCacheTaskStatus.FINISHED return def update_cpu_cache_load_task_states(self): @@ -417,7 +417,7 @@ def update_cpu_cache_load_task_states(self): if self.backend.is_master_in_dp: req.shm_req.shm_cur_kv_len = req.cur_kv_len - task.req_obj.cpu_cache_load_task_status = InferReq._CpuCacheLoadTaskStatus.FINISHED + task.req_obj.cpu_cache_load_task_status = InferReq._CpuCacheTaskStatus.FINISHED if self.backend.is_master_in_dp and need_free_page_list: self.cpu_cache_client.lock.acquire_sleep1ms()