diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py
index 84e7d804c32..ea273d49446 100644
--- a/fastdeploy/cache_manager/cache_data.py
+++ b/fastdeploy/cache_manager/cache_data.py
@@ -32,6 +32,9 @@ class CacheStatus(Enum):
     CPU = 3
     GPU2STORAGE = 4
     STORAGE2GPU = 5
+    DECODE_OFFLOAD = 6
+    DECODE_RESUME = 7
+    DECODE_CLEANUP = 8


 class BlockNode:
diff --git a/fastdeploy/cache_manager/cache_tasks.py b/fastdeploy/cache_manager/cache_tasks.py
index fe15263827a..a34294e28a0 100644
--- a/fastdeploy/cache_manager/cache_tasks.py
+++ b/fastdeploy/cache_manager/cache_tasks.py
@@ -35,3 +35,20 @@ class ReadStorageTask(CacheTask):
 @dataclass(frozen=True, kw_only=True)
 class WriteStorageTask(CacheTask):
     timeout: float = 30.0
+
+
+@dataclass(frozen=True, kw_only=True)
+class DecodeOffloadTask:
+    task_id: str
+    gpu_block_ids: List[int]
+
+
+@dataclass(frozen=True, kw_only=True)
+class DecodeResumeTask:
+    task_id: str
+    gpu_block_ids: List[int]
+
+
+@dataclass(frozen=True, kw_only=True)
+class DecodeCleanupTask:
+    task_id: str
diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py
index 74386f909af..a5434f00c5e 100644
--- a/fastdeploy/cache_manager/cache_transfer_manager.py
+++ b/fastdeploy/cache_manager/cache_transfer_manager.py
@@ -31,7 +31,13 @@
 from fastdeploy import envs
 from fastdeploy.cache_manager.cache_data import CacheStatus
-from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
+from fastdeploy.cache_manager.cache_tasks import (
+    DecodeCleanupTask,
+    DecodeOffloadTask,
+    DecodeResumeTask,
+    ReadStorageTask,
+    WriteStorageTask,
+)
 from fastdeploy.cache_manager.ops import (
     cuda_host_alloc,
     cuda_host_free,
@@ -199,9 +205,13 @@ def __init__(self, args):
         self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
         self.read_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
         self.write_back_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+        self.decode_offload_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+        self.decode_resume_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+        self.decode_cleanup_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
         self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
         self.transfer_task_queue = queue.Queue()  # 用来接收传输任务
         self.tansfer_done_queue = queue.Queue()  # 用来告知任务执行完毕
+        self.decode_offload_snapshots = {}
         address = (args.pod_ip, args.cache_queue_port)
         self.cache_task_queue = EngineCacheQueue(
@@ -1035,6 +1045,108 @@ def check_work_status(self, time_interval_threashold=envs.FD_CACHE_PROC_EXIT_TIM
         return True, ""

+    def _snapshot_blocks_to_cpu(self, gpu_cache_tensors, block_ids: List[int]):
+        snapshots = []
+        for cache_tensor in gpu_cache_tensors:
+            blocks = [cache_tensor[block_id] for block_id in block_ids]
+            layer_tensor = paddle.stack(blocks) if len(blocks) > 1 else blocks[0].unsqueeze(0)
+            snapshots.append(layer_tensor.to("cpu"))
+        return snapshots
+
+    def _restore_blocks_from_cpu(self, cpu_tensors, gpu_cache_tensors, block_ids: List[int]):
+        device = f"gpu:{self.device}"
+        for layer_id, cpu_tensor in enumerate(cpu_tensors):
+            gpu_tensor = gpu_cache_tensors[layer_id]
+            gpu_data = cpu_tensor.to(device)
+            for idx, block_id in enumerate(block_ids):
+                gpu_tensor[block_id] = gpu_data[idx]
+
+    def decode_offload_task(self, task: DecodeOffloadTask):
+        ok = False
+        meta = {
+            "rank": self.rank,
"num_blocks": len(task.gpu_block_ids), + "storage_level": "L2", + } + try: + if not task.gpu_block_ids: + raise ValueError(f"decode offload task {task.task_id} has empty gpu_block_ids") + snapshot = { + "key_caches": self._snapshot_blocks_to_cpu(self.gpu_cache_k_tensors, task.gpu_block_ids), + "value_caches": ( + self._snapshot_blocks_to_cpu(self.gpu_cache_v_tensors, task.gpu_block_ids) + if self.gpu_cache_v_tensors + else [] + ), + "key_scales": ( + self._snapshot_blocks_to_cpu(self.gpu_cache_scales_k_tensors, task.gpu_block_ids) + if self.gpu_cache_scales_k_tensors + else [] + ), + "value_scales": ( + self._snapshot_blocks_to_cpu(self.gpu_cache_scales_v_tensors, task.gpu_block_ids) + if self.gpu_cache_scales_v_tensors + else [] + ), + } + self.decode_offload_snapshots[task.task_id] = snapshot + ok = True + except Exception as e: + meta["error"] = str(e) + logger.error( + f"decode_offload_task failed for {task.task_id}, error: {e}, traceback:\n{traceback.format_exc()}" + ) + finally: + result = (CacheStatus.DECODE_OFFLOAD, task.task_id, self.rank, ok, meta) + self.cache_task_queue.put_transfer_done_signal(result) + + def decode_resume_task(self, task: DecodeResumeTask): + ok = False + meta = { + "rank": self.rank, + "num_blocks": len(task.gpu_block_ids), + } + try: + if task.task_id not in self.decode_offload_snapshots: + raise KeyError(f"snapshot for {task.task_id} not found") + snapshot = self.decode_offload_snapshots[task.task_id] + self._restore_blocks_from_cpu(snapshot["key_caches"], self.gpu_cache_k_tensors, task.gpu_block_ids) + if self.gpu_cache_v_tensors and snapshot["value_caches"]: + self._restore_blocks_from_cpu(snapshot["value_caches"], self.gpu_cache_v_tensors, task.gpu_block_ids) + if self.gpu_cache_scales_k_tensors and snapshot["key_scales"]: + self._restore_blocks_from_cpu( + snapshot["key_scales"], self.gpu_cache_scales_k_tensors, task.gpu_block_ids + ) + if self.gpu_cache_scales_v_tensors and snapshot["value_scales"]: + self._restore_blocks_from_cpu( + snapshot["value_scales"], self.gpu_cache_scales_v_tensors, task.gpu_block_ids + ) + del self.decode_offload_snapshots[task.task_id] + ok = True + except Exception as e: + meta["error"] = str(e) + logger.error( + f"decode_resume_task failed for {task.task_id}, error: {e}, traceback:\n{traceback.format_exc()}" + ) + finally: + result = (CacheStatus.DECODE_RESUME, task.task_id, self.rank, ok, meta) + self.cache_task_queue.put_transfer_done_signal(result) + + def decode_cleanup_task(self, task: DecodeCleanupTask): + ok = False + meta = {"rank": self.rank} + try: + self.decode_offload_snapshots.pop(task.task_id, None) + ok = True + except Exception as e: + meta["error"] = str(e) + logger.error( + f"decode_cleanup_task failed for {task.task_id}, error: {e}, traceback:\n{traceback.format_exc()}" + ) + finally: + result = (CacheStatus.DECODE_CLEANUP, task.task_id, self.rank, ok, meta) + self.cache_task_queue.put_transfer_done_signal(result) + def submit_task(self, thread_pool: concurrent.futures.ThreadPoolExecutor, task_fn, *args): def inflight_task(fn, *args): @@ -1129,6 +1241,27 @@ def do_data_transfer(self): self.write_back_storage_task, write_storage_task, ) + elif event_type.value == CacheStatus.DECODE_OFFLOAD.value: + decode_offload_task = event_args[0] + self.submit_task( + self.decode_offload_thread_pool, + self.decode_offload_task, + decode_offload_task, + ) + elif event_type.value == CacheStatus.DECODE_RESUME.value: + decode_resume_task = event_args[0] + self.submit_task( + self.decode_resume_thread_pool, + 
+                        self.decode_resume_task,
+                        decode_resume_task,
+                    )
+                elif event_type.value == CacheStatus.DECODE_CLEANUP.value:
+                    decode_cleanup_task = event_args[0]
+                    self.submit_task(
+                        self.decode_cleanup_thread_pool,
+                        self.decode_cleanup_task,
+                        decode_cleanup_task,
+                    )
                 else:
                     if self.n_ranks > 1:
                         self.cache_task_queue.barrier2.wait()
diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py
index b022f61c26e..0fcabc7f204 100644
--- a/fastdeploy/cache_manager/prefix_cache_manager.py
+++ b/fastdeploy/cache_manager/prefix_cache_manager.py
@@ -97,6 +97,7 @@ def __init__(
         self.task_write_back_event = {}
         self.task_prefetch_event = {}
         self.storage_prefetch_block_ids = {}
+        self.transfer_result_handlers = []

         # gpu cache data structure
         self.gpu_lru_leaf_heap = []
@@ -293,7 +294,11 @@ def launch_cache_manager(
             else:
                 storage_arg_str = " "

-        if self.cache_config.swap_space or self.cache_config.kvcache_storage_backend:
+        if (
+            self.cache_config.swap_space
+            or self.cache_config.kvcache_storage_backend
+            or getattr(self.config, "enable_decode_offload", False)
+        ):
             for i in range(tensor_parallel_size):
                 launch_cmd = (
                     "FLAGS_allocator_strategy=auto_growth "
@@ -352,7 +357,11 @@
         )

         # Start additional threads
-        if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0:
+        if (
+            cache_config.kvcache_storage_backend
+            or self.num_cpu_blocks > 0
+            or getattr(self.config, "enable_decode_offload", False)
+        ):
             logger.info("Enable hierarchical cache.")
             threading.Thread(target=self.recv_data_transfer_result, daemon=True).start()
         if cache_config.enable_prefix_caching:
@@ -1191,6 +1200,10 @@ def wait_prefetch_storage_task(self, req_id):
             del self.storage_prefetch_block_ids[req_id]
         return storage_block_ids

+    def register_transfer_result_handler(self, handler):
+        if handler not in self.transfer_result_handlers:
+            self.transfer_result_handlers.append(handler)
+
     def free_nodes_directly(self, node):
         with self.request_release_lock:
             try:
@@ -2058,6 +2071,16 @@ def recv_data_transfer_result(self):
                     time.sleep(0.001)
                     continue
                 event_type = data[0]
+                handled = False
+                for handler in self.transfer_result_handlers:
+                    try:
+                        if handler(data):
+                            handled = True
+                            break
+                    except Exception as e:
+                        logger.warning(f"transfer result handler failed: {e}")
+                if handled:
+                    continue

                 if event_type.value == CacheStatus.STORAGE2GPU.value:
                     logger.info(f"recv_data_transfer_result: {data}")
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 6d31e5ca616..91400d5af1f 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1805,6 +1805,7 @@ def __init__(
         tool_parser: str = None,
         test_mode=False,
         routing_replay_config: Optional[RoutingReplayConfig] = None,
+        enable_decode_offload: bool = False,
     ):
         self.model_config: ModelConfig = model_config  # type: ignore
         self.cache_config: CacheConfig = cache_config  # type: ignore
@@ -1821,6 +1822,7 @@
         self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
         self.router_config: RouterConfig = router_config
         self.routing_replay_config = routing_replay_config
+        self.enable_decode_offload = enable_decode_offload

         # Initialize cuda graph capture list
         max_capture_shape = self.scheduler_config.max_num_seqs
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 889b11cbdc2..c71815d9e77 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -551,6 +551,11 @@ class EngineArgs:
     Flag to enable prefill_use_worst_num_tokens.
     Default is False (disabled).
     """

+    enable_decode_offload: bool = False
+    """
+    Flag to enable decode offload. Default is False (disabled).
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.
@@ -1071,6 +1076,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.ep_prefill_use_worst_num_tokens,
             help="Enable prefill use worst num tokens for EP.",
         )
+        parallel_group.add_argument(
+            "--enable-decode-offload",
+            action="store_true",
+            default=EngineArgs.enable_decode_offload,
+            help="Enable decode offload.",
+        )

         # Load group
         load_group = parser.add_argument_group("Load Configuration")
@@ -1514,4 +1525,5 @@ def create_engine_config(self) -> FDConfig:
             plas_attention_config=plas_attention_config,
             early_stop_config=early_stop_cfg,
             routing_replay_config=routing_replay_config,
+            enable_decode_offload=self.enable_decode_offload,
         )
diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index 5a060a1cade..df707eddeff 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -1028,18 +1028,24 @@ def _fetch_request():
             if self.cfg.scheduler_config.splitwise_role == "decode":
                 for task in tasks:
                     if task.task_type == RequestType.PREEMPTED:
-                        msg = f"{task.request_id} decode not enough blocks, need to be rescheduled."
-                        self.llm_logger.error(msg)
-                        self.scheduler.put_results(
-                            [
-                                RequestOutput(
-                                    request_id=task.request_id,
-                                    finished=True,
-                                    error_code=500,
-                                    error_msg=msg,
-                                )
-                            ]
-                        )
+                        req = self.resource_manager.requests.get(task.request_id)
+                        if req is not None and req.is_offloaded:
+                            self.llm_logger.info(
+                                f"{task.request_id} decode request is preempted and offloaded, waiting for resume."
+                            )
+                        else:
+                            msg = f"{task.request_id} decode not enough blocks, need to be rescheduled."
+                            self.llm_logger.error(msg)
+                            self.scheduler.put_results(
+                                [
+                                    RequestOutput(
+                                        request_id=task.request_id,
+                                        finished=True,
+                                        error_code=500,
+                                        error_msg=msg,
+                                    )
+                                ]
+                            )
                 self.resource_manager.get_real_bsz()
             for task in tasks:
                 if task.task_type == RequestType.PREFILL:
diff --git a/fastdeploy/engine/offload_manager.py b/fastdeploy/engine/offload_manager.py
new file mode 100644
index 00000000000..dd2e17070b0
--- /dev/null
+++ b/fastdeploy/engine/offload_manager.py
@@ -0,0 +1,358 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import threading
+import time
+from typing import Dict, List, Optional, Tuple
+
+from fastdeploy import envs
+from fastdeploy.cache_manager.cache_data import CacheStatus
+from fastdeploy.cache_manager.cache_tasks import DecodeCleanupTask, DecodeOffloadTask, DecodeResumeTask
+from fastdeploy.engine.request import Request, RequestStatus
+from fastdeploy.utils import offload_logger
+
+
+class OffloadManager:
+    """
+    Decode request KV cache offload orchestrator.
+
+    Real KV cache snapshot/restore is executed inside cache_transfer_manager.
+    This class only manages request-level state, retry policy and task/result
+    synchronization.
+    """
+
+    STORAGE_LEVEL_CPU = "L2"
+    STORAGE_LEVEL_SSD = "L3"
+
+    def __init__(self, config=None, cache_manager=None, model_runner=None):
+        self.config = config
+        self.cache_manager = cache_manager
+        self.model_runner = model_runner
+
+        self.enable_offload = getattr(config, "enable_decode_offload", False) if config else False
+        self.min_steps = 20
+        self.cpu_offloading_chunk_size = getattr(envs, "FD_CPU_OFFLOAD_CHUNK_SIZE", 8192)
+        self.cpu_memory_limit = getattr(envs, "FD_CPU_MEMORY_LIMIT", 50 * 1024 * 1024 * 1024)
+
+        self._offloaded_requests: Dict[str, dict] = {}
+        self._lock = threading.Lock()
+        self._transfer_events: Dict[Tuple[int, str], threading.Event] = {}
+        self._transfer_results: Dict[Tuple[int, str], list] = {}
+        self._tensor_parallel_size = getattr(getattr(config, "parallel_config", None), "tensor_parallel_size", 1)
+
+        offload_logger.info(
+            f"[DEBUG: offload] OffloadManager initialized: enable_offload={self.enable_offload}, "
+            f"min_steps={self.min_steps}"
+        )
+        if self.cache_manager is not None and hasattr(self.cache_manager, "register_transfer_result_handler"):
+            self.cache_manager.register_transfer_result_handler(self._handle_transfer_result)
+
+    def _transfer_key(self, event_type, task_id: str) -> Tuple[int, str]:
+        return (event_type.value, task_id)
+
+    def _handle_transfer_result(self, data) -> bool:
+        event_type = data[0]
+        if event_type.value not in (
+            CacheStatus.DECODE_OFFLOAD.value,
+            CacheStatus.DECODE_RESUME.value,
+            CacheStatus.DECODE_CLEANUP.value,
+        ):
+            return False
+
+        task_id, rank, ok, meta = data[1:]
+        key = self._transfer_key(event_type, task_id)
+        with self._lock:
+            if key not in self._transfer_results:
+                self._transfer_results[key] = []
+            self._transfer_results[key].append(
+                {
+                    "rank": rank,
+                    "ok": ok,
+                    "meta": meta,
+                }
+            )
+            if len(self._transfer_results[key]) >= self._tensor_parallel_size:
+                event = self._transfer_events.get(key)
+                if event is not None:
+                    event.set()
+        return True
+
+    def _issue_transfer_task(self, event_type, task):
+        if self.cache_manager is None or not hasattr(self.cache_manager, "cache_task_queue"):
+            return None
+
+        key = self._transfer_key(event_type, task.task_id)
+        event = threading.Event()
+        with self._lock:
+            self._transfer_events[key] = event
+            self._transfer_results.pop(key, None)
+        self.cache_manager.cache_task_queue.put_transfer_task((event_type, task))
+        event.wait(timeout=30)
+        if not event.is_set():
+            offload_logger.error(f"Transfer task {task.task_id} timed out after 30s")
+            with self._lock:
+                self._transfer_results.pop(key, None)
+                self._transfer_events.pop(key, None)
+            return None
+        with self._lock:
+            results = self._transfer_results.pop(key, [])
+            self._transfer_events.pop(key, None)
+        return {
+            "ok": bool(results) and all(item["ok"] for item in results),
+            "results": results,
+        }
+
+    def can_offload(self, request: Request) -> bool:
+        if not self.enable_offload:
+            return False
+        if request.is_offloaded:
+            return False
+        if not request.block_tables:
+            return False
+        if request.need_prefill_tokens is None:
+            offload_logger.warning(
+                f"[DEBUG: can_offload] {request.request_id}: need_prefill_tokens is None, cannot offload"
+            )
+            return False
+        if request.num_computed_tokens < request.need_prefill_tokens:
+            offload_logger.warning(
+                f"[DEBUG: can_offload] {request.request_id} is not in decode phase, "
+                f"num_computed_tokens={request.num_computed_tokens}, "
+                f"need_prefill_tokens={request.need_prefill_tokens}, cannot offload"
+            )
cannot offload" + ) + return False + return True + + def can_resume(self, request: Request) -> bool: + if not self.enable_offload: + return False + if request.request_id not in self._offloaded_requests: + return False + + offloaded_info = self._offloaded_requests.get(request.request_id) + if offloaded_info is None or offloaded_info.get("snapshot_handle") is None: + return False + if self.cache_manager is None: + return False + + return self.cache_manager.can_allocate_gpu_blocks(offloaded_info.get("num_blocks_needed", 0)) + + def offload_decode(self, running_requests: List[Request], min_steps: int = 20) -> Tuple[List[Request], List[Request]]: + if not self.enable_offload: + return [], [] + + offloaded_reqs = [] + abort_reqs = [] + remaining_count = len(running_requests) + + for req in running_requests: + if not self.can_offload(req): + continue + + if self.offload_req(req): + offloaded_reqs.append(req) + remaining_count -= 1 + else: + abort_reqs.append(req) + + if self.cache_manager is not None and remaining_count > 0: + block_size = self.cache_manager.cache_config.block_size + blocks_needed_per_request = (min_steps + block_size - 1) // block_size + total_blocks_needed = remaining_count * blocks_needed_per_request + current_free_blocks = len(getattr(self.cache_manager, "gpu_free_block_list", [])) + if current_free_blocks >= total_blocks_needed: + break + + return offloaded_reqs, abort_reqs + + def offload_req(self, request: Request) -> bool: + if not self.enable_offload or self.cache_manager is None: + return False + if request.is_offloaded: + offload_logger.warning(f"[DEBUG: offload_req] Request {request.request_id} already offloaded") + return False + + start_time = time.perf_counter() + snapshot_task = DecodeOffloadTask(task_id=request.request_id, gpu_block_ids=list(request.block_tables)) + snapshot_result = self._issue_transfer_task(CacheStatus.DECODE_OFFLOAD, snapshot_task) + if snapshot_result is None or not snapshot_result.get("ok", False): + elapsed_ms = (time.perf_counter() - start_time) * 1000 + offload_logger.error( + f"[DEBUG: offload_req] Failed to snapshot request {request.request_id}, " + f"elapsed_ms={elapsed_ms:.2f}, result={snapshot_result}" + ) + return False + + with self._lock: + need_prefill_tokens = request.need_prefill_tokens + if need_prefill_tokens is None: + need_prefill_tokens = request.prompt_token_ids_len if request.prompt_token_ids_len else 0 + original_block_tables = list(request.block_tables) if request.block_tables else [] + self._offloaded_requests[request.request_id] = { + "storage_level": self.STORAGE_LEVEL_CPU, + "num_tokens": request.num_total_tokens, + "num_blocks_needed": len(original_block_tables), + "output_token_ids": list(request.output_token_ids), + "num_computed_tokens": request.num_computed_tokens, + "need_prefill_tokens": need_prefill_tokens, + "prompt_token_ids": list(request.prompt_token_ids) if request.prompt_token_ids else None, + "prompt_token_ids_len": request.prompt_token_ids_len, + "sampling_params": request.sampling_params, + "block_tables": original_block_tables, + "snapshot_handle": request.request_id, + } + + self.release_gpu_blocks(request) + request.is_offloaded = True + elapsed_ms = (time.perf_counter() - start_time) * 1000 + offload_logger.info( + f"[DEBUG: offload_req] Request {request.request_id} offloaded to {self.STORAGE_LEVEL_CPU}, " + f"blocks_needed={len(original_block_tables)}, offload_time_ms={elapsed_ms:.2f}" + ) + return True + + def offload_kv_cache(self, request: Request, target_level: str = "L2") -> bool: + """ + 
+        Compatibility shim for future multi-level offload.
+        """
+        if target_level == self.STORAGE_LEVEL_CPU:
+            return self.offload_req(request)
+        if target_level == self.STORAGE_LEVEL_SSD:
+            offload_logger.warning("[DEBUG: offload_kv_cache] SSD offload is not implemented in the first version")
+            return False
+        offload_logger.error(f"[DEBUG: offload_kv_cache] Invalid target_level: {target_level}")
+        return False
+
+    def release_gpu_blocks(self, request: Request) -> None:
+        if self.cache_manager is None:
+            return
+        if request.block_tables:
+            blocks_to_release = list(request.block_tables)
+            self.cache_manager.recycle_gpu_blocks(blocks_to_release, request.request_id)
+            request.block_tables = []
+
+    def save_to_storage(self, kv_cache_cpu) -> Optional[str]:
+        """Compatibility placeholder for future SSD offload support."""
+        offload_logger.warning("[DEBUG: save_to_storage] SSD offload is not implemented in the first version")
+        return None
+
+    def load_from_storage(self, storage_path: str) -> Optional[dict]:
+        """Compatibility placeholder for future SSD resume support."""
+        offload_logger.warning("[DEBUG: load_from_storage] SSD resume is not implemented in the first version")
+        return None
+
+    def resume_decode(self, request: Request) -> Tuple[bool, Optional[int]]:
+        if not self.enable_offload:
+            return False, None
+
+        start_time = time.perf_counter()
+        with self._lock:
+            offloaded_info = self._offloaded_requests.get(request.request_id)
+        if offloaded_info is None:
+            offload_logger.warning(f"[DEBUG: resume_decode] Request {request.request_id} is not offloaded")
+            return False, None
+
+        num_blocks_needed = offloaded_info["num_blocks_needed"]
+        saved_num_computed_tokens = offloaded_info["num_computed_tokens"]
+        saved_need_prefill_tokens = offloaded_info["need_prefill_tokens"]
+        snapshot_handle = offloaded_info.get("snapshot_handle")
+        output_token_ids = list(offloaded_info.get("output_token_ids", []))
+        need_prefill_tokens = offloaded_info.get("need_prefill_tokens")
+
+        if saved_num_computed_tokens <= saved_need_prefill_tokens:
+            offload_logger.warning(
+                f"[DEBUG: resume_decode] Request {request.request_id} has invalid state: "
+                f"num_computed_tokens={saved_num_computed_tokens} <= need_prefill_tokens={saved_need_prefill_tokens}"
+            )
+            return False, saved_num_computed_tokens
+        if self.cache_manager is None:
+            return False, saved_num_computed_tokens
+        if not self.cache_manager.can_allocate_gpu_blocks(num_blocks_needed):
+            offload_logger.debug(
+                f"[DEBUG: resume_decode] Not enough GPU blocks for {request.request_id}, "
+                f"need={num_blocks_needed}, will retry later"
+            )
+            return False, saved_num_computed_tokens
+
+        try:
+            if snapshot_handle is None:
+                offload_logger.warning(
+                    f"[DEBUG: resume_decode] Request {request.request_id} has no snapshot handle"
+                )
+                return False, saved_num_computed_tokens
+
+            new_block_ids = self.cache_manager.allocate_gpu_blocks(num_blocks_needed, request.request_id)
+            request.block_tables = new_block_ids
+            resume_task = DecodeResumeTask(task_id=snapshot_handle, gpu_block_ids=new_block_ids)
+            resume_result = self._issue_transfer_task(CacheStatus.DECODE_RESUME, resume_task)
+            if resume_result is None or not resume_result.get("ok", False):
+                elapsed_ms = (time.perf_counter() - start_time) * 1000
+                self.cache_manager.recycle_gpu_blocks(new_block_ids, request.request_id)
+                request.block_tables = []
+                offload_logger.warning(
+                    f"[DEBUG: resume_decode] Resume transfer failed for {request.request_id}, "
+                    f"elapsed_ms={elapsed_ms:.2f}, result={resume_result}"
+                )
+                return False, saved_num_computed_tokens
+
+            request.output_token_ids = output_token_ids
+            request.num_computed_tokens = saved_num_computed_tokens
+            request.need_prefill_tokens = need_prefill_tokens
+            request.status = RequestStatus.RUNNING
+            request.is_offloaded = False
+
+            with self._lock:
+                self._offloaded_requests.pop(request.request_id, None)
+
+            elapsed_ms = (time.perf_counter() - start_time) * 1000
+            offload_logger.info(
+                f"[DEBUG: resume_decode] Request {request.request_id} resumed successfully, "
+                f"resume_time_ms={elapsed_ms:.2f}"
+            )
+            return True, saved_num_computed_tokens
+        except Exception as e:
+            elapsed_ms = (time.perf_counter() - start_time) * 1000
+            offload_logger.error(
+                f"[DEBUG: resume_decode] Failed to resume request {request.request_id}, "
+                f"elapsed_ms={elapsed_ms:.2f}: {e}"
+            )
+            return False, saved_num_computed_tokens
+
+    def cleanup_offloaded_request(self, request_id: str) -> None:
+        with self._lock:
+            offloaded_info = self._offloaded_requests.pop(request_id, None)
+        if offloaded_info is None:
+            return
+
+        snapshot_handle = offloaded_info.get("snapshot_handle")
+        if self.cache_manager is not None and snapshot_handle is not None:
+            try:
+                self._issue_transfer_task(CacheStatus.DECODE_CLEANUP, DecodeCleanupTask(task_id=snapshot_handle))
+            except Exception as e:
+                offload_logger.warning(f"[DEBUG: offload] Failed to cleanup snapshot {snapshot_handle}: {e}")
+        offload_logger.info(f"[DEBUG: offload] Cleaned up offloaded request: {request_id}")
+
+    def get_offloaded_request_count(self) -> int:
+        with self._lock:
+            return len(self._offloaded_requests)
+
+    def get_offloaded_request_ids(self) -> List[str]:
+        with self._lock:
+            return list(self._offloaded_requests.keys())
+
+    def prefetch_ssd_to_cpu(self) -> int:
+        """Compatibility placeholder for future SSD prefetch support."""
+        return 0
diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index 1e2a53ed205..d17ddf6c7c8 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -185,6 +185,7 @@ def __init__(
         self.status = RequestStatus.WAITING
         self.task_type = RequestType.PREFILL
         self.has_been_preempted_before = False
+        self.is_offloaded = False
         self.idx = None
         self.need_prefill_tokens = self.prompt_token_ids_len
         self.audio_output_token_ids = []
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 9205f1c05c5..3db479710a7 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -49,7 +49,12 @@
 from fastdeploy.spec_decode import SpecMethod
 from fastdeploy.trace.constants import LoggingEventName
 from fastdeploy.trace.trace_logger import print as trace_print
-from fastdeploy.utils import download_from_bos, init_bos_client, llm_logger
+from fastdeploy.utils import (
+    download_from_bos,
+    init_bos_client,
+    llm_logger,
+    offload_logger,
+)


 @dataclass
@@ -220,6 +225,13 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
         # Scheduler-side requests that have not been moved into resource manager waiting queue yet.
         self.scheduler_unhandled_request_num = 0

+        # OffloadManager for decode instances
+        self.offload_manager = None
+        if config.scheduler_config.splitwise_role == "decode" and getattr(config, "enable_decode_offload", False):
+            from fastdeploy.engine.offload_manager import OffloadManager
+
+            self.offload_manager = OffloadManager(config, self.cache_manager, None)
+
     def allocated_slots(self, request: Request):
         return len(request.block_tables) * self.config.cache_config.block_size

@@ -246,6 +258,31 @@ def _prepare_decode_task(self, request):
     def _prepare_preempt_task(self, request):
         return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)

+    def _get_pending_preempt_slots(self) -> set[int]:
+        pending_slots = set()
+        request_id_set = getattr(self, "to_be_rescheduled_request_id_set", set())
+        requests = getattr(self, "requests", {})
+        for request_id in request_id_set:
+            request = requests.get(request_id)
+            if request is not None and request.idx is not None:
+                pending_slots.add(request.idx)
+        return pending_slots
+
+    def _assign_rescheduled_slot(self, request):
+        allocated_position = self.get_available_position()
+        request.idx = allocated_position
+        allocated_position = request.idx
+        if allocated_position is None:
+            allocated_position = self.get_available_position()
+            request.idx = allocated_position
+        self.tasks_list[allocated_position] = request
+        self.stop_flags[allocated_position] = False
+        self.req_dict[request.request_id] = allocated_position
+        return allocated_position
+
+    def available_batch(self):
+        return max(super().available_batch() - len(self._get_pending_preempt_slots()), 0)
+
     def reschedule_preempt_task(self, request_id, process_func=None):
         with self.lock:
             llm_logger.debug(f"reschedule {request_id} into waiting queue")
@@ -255,6 +292,7 @@ def reschedule_preempt_task(self, request_id, process_func=None):
                 request.metrics.preempted_count += 1
                 if process_func is not None:
                     process_func(request)
+                request.idx = None
                 llm_logger.debug(f"self.waiting append request:{request.request_id},req.type:{request.status}")
                 self.waiting.appendleft(request)
                 self.to_be_rescheduled_request_id_set.remove(request_id)
@@ -322,19 +360,42 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
                 if preempted_req.use_extend_tables:
                     self.running.insert(0, preempted_req)
                     continue
+
+                # Try offload for decode instance requests in decode phase
+                is_decode_phase = (
+                    preempted_req.num_computed_tokens >= preempted_req.need_prefill_tokens
+                    if preempted_req.need_prefill_tokens is not None
+                    else False
+                )
+                offloaded = False
+                if (
+                    self.config.scheduler_config.splitwise_role == "decode"
+                    and is_decode_phase
+                    and self.offload_manager is not None
+                    and self.offload_manager.can_offload(preempted_req)
+                ):
+                    if self.offload_manager.offload_req(preempted_req):
+                        offloaded = True
+                        offload_logger.info(
+                            f"Request {preempted_req.request_id} offloaded before preempt, "
+                            f"tokens={preempted_req.num_computed_tokens}"
+                        )
+
                 preempted_req.status = RequestStatus.PREEMPTED
-                preempted_req.num_computed_tokens = 0
+                if not offloaded:
+                    preempted_req.num_computed_tokens = 0
                 if self.config.scheduler_config.splitwise_role == "decode":
                     self.tasks_list[preempted_req.idx] = None
                     self.stop_flags[preempted_req.idx] = True
-                    if preempted_req.request_id in self.requests:
-                        del self.requests[preempted_req.request_id]
                     if preempted_req.request_id in self.req_dict:
                         del self.req_dict[preempted_req.request_id]
-                    self._free_blocks(preempted_req)
+                    if not offloaded:
+                        self._free_blocks(preempted_req)
+                    self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
                     llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
                 else:
-                    self._free_blocks(preempted_req)
+                    if not offloaded:
+                        self._free_blocks(preempted_req)
                     preempted_req.num_cached_blocks = 0
                     self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
                 trace_print(
@@ -469,6 +530,10 @@ def _get_num_new_tokens(self, request, token_budget):
                 f"need_prefill={request.need_prefill_tokens}, computed={request.num_computed_tokens}"
             )
         num_new_tokens = min(num_new_tokens, token_budget)
+        decode_chunk_limit = None
+        if self.config.scheduler_config.splitwise_role == "decode":
+            decode_chunk_limit = self.config.get_max_chunk_tokens(self.config.model_config.mm_max_tokens_per_item)
+            num_new_tokens = min(num_new_tokens, decode_chunk_limit)

         # Deterministic mode: align chunk boundaries to split_kv_size
         # This ensures batch-invariant attention by making each chunk
@@ -509,6 +574,8 @@
             request.with_image = False

         if not self.config.model_config.enable_mm:
+            if decode_chunk_limit is not None:
+                num_new_tokens = min(num_new_tokens, decode_chunk_limit)
             return num_new_tokens

         inputs = request.multimodal_inputs
@@ -679,6 +746,8 @@ def _compute_audio_prefix_count(end_idx, end_patch_idx):
             request.evict_mm_hashes = self.encoder_cache.apply_cache(cur_mm_hashes, cur_mm_positions)

         # Compatible with scenarios without images and videos.
+        if decode_chunk_limit is not None:
+            num_new_tokens = min(num_new_tokens, decode_chunk_limit)
         return num_new_tokens

     def exist_mm_prefill(self, scheduled_reqs):
@@ -956,6 +1025,38 @@ def _allocate_decode_and_extend():
                         self._free_blocks(request)
                         break
                 elif request.status == RequestStatus.PREEMPTED:
+                    # Try to resume offloaded request first
+                    if request.is_offloaded and self.offload_manager is not None:
+                        # Only attempt resume when running requests have finished
+                        # or there are enough free blocks to sustain all requests.
+                        # This prevents thrashing (immediate re-preempt after resume).
+                        if len(self.running) > 0:
+                            offloaded_info = self.offload_manager._offloaded_requests.get(request.request_id)
+                            num_blocks_for_resume = offloaded_info["num_blocks_needed"] if offloaded_info else 0
+                            block_size = self.cache_manager.cache_config.block_size
+                            min_steps = self.offload_manager.min_steps
+                            blocks_per_step = (min_steps + block_size - 1) // block_size
+                            total_running_after_resume = len(self.running) + 1
+                            total_blocks_needed = num_blocks_for_resume + total_running_after_resume * blocks_per_step
+                            free_blocks = len(getattr(self.cache_manager, "gpu_free_block_list", []))
+                            if free_blocks < total_blocks_needed:
+                                # Not enough blocks to resume without thrashing, wait for running to finish
+                                break
+
+                        resume_success, _ = self.offload_manager.resume_decode(request)
+                        if resume_success:
+                            offload_logger.info(f"Resumed offloaded request {request.request_id}")
+                            self.waiting.popleft()
+                            self._assign_rescheduled_slot(request)
+                            self.running.append(request)
+                            scheduled_reqs.append(self._prepare_decode_task(request))
+                            continue
+                        else:
+                            offload_logger.debug(
+                                f"Failed to resume offloaded request {request.request_id}, will retry"
+                            )
+                            break
+
                     request.need_prefill_tokens = (
                         request.num_total_tokens
                     )  # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
@@ -994,6 +1095,7 @@
                         )
                         request.block_tables.extend(extra_gpu_block_ids)
                     self.waiting.popleft()
+                    self._assign_rescheduled_slot(request)
                     self.running.append(request)
                     scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                     token_budget -= num_new_tokens
@@ -1138,9 +1240,10 @@ def download_bos_features(bos_client, features_urls):
                 inputs["audio_features"] = result

     def get_available_position(self) -> int:
+        pending_preempt_slots = self._get_pending_preempt_slots()
         position = 0
         while position < self.max_num_seqs:
-            if self.stop_flags[position] is True:
+            if self.stop_flags[position] is True and position not in pending_preempt_slots:
                 return position
             position += 1
         raise RuntimeError("No available position is available for new request")
@@ -1412,6 +1515,8 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]):
                 if request is None:
                     llm_logger.error(f"invalid request id: {req_id} self.requests: {self.requests}")
                     continue
+                if self.offload_manager is not None:
+                    self.offload_manager.cleanup_offloaded_request(req_id)
                 if request in self.waiting:
                     llm_logger.error(f"request {request.request_id} scheduled into waiting list, after finished")
                     continue
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 3ab5061e78e..9fb3c292a2e 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -169,6 +169,12 @@ def _validate_split_kv_size(value: int) -> int:
     "FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")),
     "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
     "FD_ENABLE_ASYNC_LLM": lambda: int(os.getenv("FD_ENABLE_ASYNC_LLM", "0")),
+    # Enable decode KV cache offload for preempted requests
+    "FD_ENABLE_DECODE_OFFLOAD": lambda: bool(int(os.getenv("FD_ENABLE_DECODE_OFFLOAD", "0"))),
+    # CPU memory limit in bytes for offload
+    "FD_CPU_MEMORY_LIMIT": lambda: int(os.getenv("FD_CPU_MEMORY_LIMIT", str(50 * 1024 * 1024 * 1024))),
+    "FD_CPU_OFFLOAD_CHUNK_SIZE": lambda: int(os.getenv("FD_CPU_OFFLOAD_CHUNK_SIZE", "8192")),
+    "FD_OFFLOAD_STORAGE_PATH": lambda: os.getenv("FD_OFFLOAD_STORAGE_PATH", "/tmp/fastdeploy_offload"),
     "FD_GUIDANCE_DISABLE_ADDITIONAL": lambda: bool(int(os.getenv("FD_GUIDANCE_DISABLE_ADDITIONAL", "1"))),
     "FD_LLGUIDANCE_LOG_LEVEL": lambda: int(os.getenv("FD_LLGUIDANCE_LOG_LEVEL", "0")),
     # "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index 662b1cfc8a3..b5f72860601 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -201,6 +201,8 @@ def _reschedule_preempt_task_use_zmq(self, datas):
                 batch_id_set.add(data.batch_id)
         llm_logger.debug(f"_reschedule_preempt_task_use_zmq batch_id_set {batch_id_set}")
         for request_id in need_to_be_reschedule_req_ids:
+            if request_id not in self.resource_manager.requests:
+                continue
             if (
                 self.resource_manager.requests[request_id].idx not in batch_id_set
             ):  # No more token generated for preempted request
@@ -212,6 +214,15 @@
                     f"finish reschedule_preempt_task request_id {request_id} at {self.resource_manager.requests[request_id].idx}"
                 )

+    def _find_pending_preempt_request_id_by_idx(self, idx: int):
+        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            return None
+        for request_id in list(self.resource_manager.to_be_rescheduled_request_id_set):
+            request = self.resource_manager.requests.get(request_id)
+            if request is not None and getattr(request, "idx", None) == idx:
+                return request_id
+        return None
+
     def _process_per_token(self, task, batch_id: int, token_ids: np.ndarray, result: RequestOutput, is_prefill: bool):
         """
         process output token by token
@@ -278,6 +289,13 @@ def _process_batch_output_use_zmq(self, receive_datas):
         for _, stream_data in enumerate(receive_datas):
             i = stream_data.batch_id
             if self.resource_manager.stop_flags[i]:
+                pending_preempt_request_id = self._find_pending_preempt_request_id_by_idx(i)
+                if pending_preempt_request_id is None:
+                    continue
+                token_ids = stream_data.tokens
+                if token_ids is not None and token_ids[-1] == PREEMPTED_TOKEN_ID:
+                    llm_logger.info(f"sync preemption for request_id {pending_preempt_request_id} done.")
+                    self.resource_manager.reschedule_preempt_task(pending_preempt_request_id)
                 continue

             task: Request = self.resource_manager.tasks_list[i]
@@ -740,7 +758,21 @@ def _process_batch_output(self):
         batch_result = list()
         # reschedule
         for i in range(batch):
+            pending_preempt_request_id = None
             if self.resource_manager.stop_flags[i]:
+                pending_preempt_request_id = self._find_pending_preempt_request_id_by_idx(i)
+                if pending_preempt_request_id is None:
+                    continue
+                if self.cfg.speculative_config.method:
+                    if accept_num[i] == PREEMPTED_TOKEN_ID:
+                        llm_logger.info(f"sync preemption for request_id {pending_preempt_request_id} done.")
+                        self.resource_manager.reschedule_preempt_task(pending_preempt_request_id)
+                    continue
+
+                token_id = int(tokens[i, 0])
+                if token_id == PREEMPTED_TOKEN_ID:
+                    llm_logger.info(f"sync preemption for request_id {pending_preempt_request_id} done.")
+                    self.resource_manager.reschedule_preempt_task(pending_preempt_request_id)
                 continue

             recovery_stop = False
@@ -821,6 +853,7 @@
                     if (
                         task_id in self.resource_manager.to_be_rescheduled_request_id_set
                         and token_id == PREEMPTED_TOKEN_ID
+                        and task_id in self.resource_manager.requests
                     ):
                         llm_logger.info(f"sync preemption for request_id {task_id} done.")
                         self.resource_manager.reschedule_preempt_task(task_id)
@@ -979,6 +1012,19 @@
             llm_logger.debug(f"get response from infer: {result}")
             batch_result.append(result)

+        # Reschedule preempted requests whose idx >= batch (not covered by range(batch))
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            batch_id_set = set(range(batch))
+            for request_id in list(self.resource_manager.to_be_rescheduled_request_id_set):
+                if request_id not in self.resource_manager.requests:
+                    continue
+                req = self.resource_manager.requests[request_id]
+                if getattr(req, "idx", None) not in batch_id_set:
+                    llm_logger.debug(
+                        f"reschedule_preempt_task request_id {request_id} at idx {getattr(req, 'idx', None)} (out of batch range {batch})"
+                    )
+                    self.resource_manager.reschedule_preempt_task(request_id)
+
         if self.cfg.speculative_config.method:
             self._record_speculative_decoding_metrics(accept_num)
         self.postprocess(batch_result, mtype)
diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py
index f09082364f4..06a79251e82 100644
--- a/fastdeploy/utils.py
+++ b/fastdeploy/utils.py
@@ -1163,6 +1163,7 @@ def _bos_download(bos_client, link):
 router_logger = get_logger("router", "router.log")
 fmq_logger = get_logger("fmq", "fmq.log")
 obj_logger = get_logger("obj", "obj.log")  # debug内存问题
+offload_logger = get_logger("offload", "offload_manager.log")  # debug offload


 def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 940e37a9421..edd58631870 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -868,6 +868,11 @@ def parse_args():
         action="store_true",
         help="enable chunked moe",
     )
+    parser.add_argument(
+        "--enable_decode_offload",
+        action="store_true",
+        help="enable decode KV cache offload for preempted requests",
+    )
     parser.add_argument(
         "--chunked_moe_size",
         type=int,
@@ -1223,6 +1228,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         scheduler_config=scheduler_config,
         ips=args.ips,
         plas_attention_config=plas_attention_config,
+        enable_decode_offload=args.enable_decode_offload,
         structured_outputs_config=structured_outputs_config,
         eplb_config=eplb_config,
         routing_replay_config=routing_replay_config,
diff --git a/tests/engine/test_resource_manager_v1.py b/tests/engine/test_resource_manager_v1.py
index 0031a2e4f69..3e6459dacbe 100644
--- a/tests/engine/test_resource_manager_v1.py
+++ b/tests/engine/test_resource_manager_v1.py
@@ -99,6 +99,84 @@ def test_preempted_all_with_normal_requests(self):
         self.assertEqual(len(self.manager.waiting), 0)
         self.assertEqual(len(self.manager.to_be_rescheduled_request_id_set), 2)

+    def test_schedule_resumed_offloaded_request_reassigns_slot(self):
+        engine_args = EngineArgs(
+            model=MODEL_NAME,
+            max_model_len=8192,
+            tensor_parallel_size=1,
+            engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")),
+            cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")),
+        )
+        mock_config = engine_args.create_engine_config()
+        manager = ResourceManagerV1(
+            max_num_seqs=2,
+            config=mock_config,
+            tensor_parallel_size=1,
+            splitwise_role="decode",
+            local_data_parallel_id=0,
+        )
+        manager.cache_manager = Mock()
+        manager.cache_manager.can_allocate_gpu_blocks.return_value = True
+        manager.offload_manager = Mock()
+
+        request = Mock(spec=Request)
+        request.request_id = "req-offloaded"
+        request.status = RequestStatus.PREEMPTED
+        request.is_offloaded = True
+        request.idx = 0
+        request.block_tables = [7, 8]
+        request.num_total_tokens = 32
+        request.need_prefill_tokens = 16
+        request.num_computed_tokens = 32
+
+        def _resume(req):
+            req.block_tables = [11, 12]
+            return True, req.num_computed_tokens
+
+        manager.offload_manager.resume_decode.side_effect = _resume
+        manager.requests[request.request_id] = request
+        manager.waiting.append(request)
+        manager.tasks_list[0] = Mock()
+        manager.stop_flags[0] = False
+        manager.tasks_list[1] = None
+        manager.stop_flags[1] = True
+
+        scheduled_reqs, error_reqs = manager.schedule()
+
+        self.assertEqual(error_reqs, [])
+        self.assertEqual(len(scheduled_reqs), 1)
+        self.assertEqual(scheduled_reqs[0].request_id, request.request_id)
+        self.assertEqual(scheduled_reqs[0].idx, 1)
+        self.assertEqual(request.idx, 1)
+        self.assertEqual(manager.req_dict[request.request_id], 1)
+        self.assertIs(manager.tasks_list[1], request)
+        self.assertIsNot(manager.tasks_list[0], request)
+        self.assertFalse(manager.stop_flags[1])
+        manager.need_block_num_signal.clear()
+
+    def test_pending_preempt_slot_is_reserved_until_ack(self):
+        request = Mock(spec=Request)
+        request.request_id = "req-pending"
+        request.idx = 0
+        request.status = RequestStatus.PREEMPTED
+        request.has_been_preempted_before = False
+        request.metrics = Mock()
+        request.metrics.preempted_count = 0
+
+        self.manager.stop_flags = [True, True, True, True]
+        self.manager.requests[request.request_id] = request
+        self.manager.to_be_rescheduled_request_id_set.add(request.request_id)
+
+        self.assertEqual(self.manager.available_batch(), 3)
+        self.assertEqual(self.manager.get_available_position(), 1)
+
+        self.manager.reschedule_preempt_task(request.request_id)
+
+        self.assertEqual(request.idx, None)
+        self.assertEqual(self.manager.available_batch(), 4)
+        self.assertEqual(self.manager.get_available_position(), 0)
+        self.assertEqual(self.manager.waiting[0], request)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py
index e8ff821a268..e49ee1ace41 100644
--- a/tests/output/test_token_processor.py
+++ b/tests/output/test_token_processor.py
@@ -25,6 +25,7 @@
 import pytest

 from fastdeploy import envs
+from fastdeploy.config import PREEMPTED_TOKEN_ID
 from fastdeploy.engine.request import Request, RequestMetrics, RequestOutput
 from fastdeploy.output import token_processor
 from fastdeploy.output.token_processor import (
@@ -253,6 +254,20 @@ def test_reschedule_preempt_task_use_zmq_reschedules_missing_batch():
     assert "reschedule-req-a" in rm.recycled


+def test_process_batch_output_use_zmq_reschedules_preempted_stopped_slot():
+    processor, rm, _, _ = _make_processor()
+    rm.stop_flags[0] = True
+    rm.to_be_rescheduled_request_id_set = {"req-a"}
+    rm.requests = {"req-a": types.SimpleNamespace(idx=0)}
+    receive_datas = [types.SimpleNamespace(batch_id=0, tokens=np.array([PREEMPTED_TOKEN_ID]), pooler_output=None)]
+
+    with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True):
+        batch_result = processor._process_batch_output_use_zmq(receive_datas)
+
+    assert batch_result == []
+    assert "reschedule-req-a" in rm.recycled
+
+
 def test_process_batch_draft_tokens_collects_top_logprobs():
     processor, rm, _, _ = _make_processor(speculative_method="mtp", enable_logprob=True)
     rm.tasks_list[0] = types.SimpleNamespace(request_id="task-0", block_tables=[1])
@@ -996,6 +1011,24 @@ def test_process_batch_output_skips_already_stopped_slot():
     assert processor.cached_generated_tokens.put_results.called


+def test_process_batch_output_reschedules_stopped_slot_negative_token():
+    processor, rm, _, _ = _make_processor()
+    task_id = "req-stopped-neg"
+    rm.stop_flags[0] = True
+    rm.requests = {task_id: types.SimpleNamespace(idx=0)}
+    rm.to_be_rescheduled_request_id_set = {task_id}
+    processor.output_tokens[1, 0] = 1
+    processor.output_tokens[2, 0] = -9
+
+    with (
+        mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", True),
+        mock.patch.object(token_processor, "main_process_metrics", _Metrics()),
+    ):
+        processor._process_batch_output()
+
+    assert rm.recycled[-1] == f"reschedule-{task_id}"
+
+
 def test_process_batch_output_speculative_negative_token_reschedules():
     processor, rm, _, _ = _make_processor(speculative_method="mtp")
     task_id = "req-spec-neg"
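
Note on the result-synchronization pattern used above: an offload, resume, or cleanup task issued by OffloadManager is only treated as complete once every tensor-parallel rank's cache transfer process has reported back via the transfer-result handler. The following is a minimal, self-contained sketch of that per-rank acknowledgement pattern; it is illustrative only, not part of the patch, and the names (AckCollector, n_ranks) are hypothetical.

import threading

class AckCollector:
    """Collect one ack per rank for a task and signal when all ranks have answered."""

    def __init__(self, n_ranks: int):
        self.n_ranks = n_ranks
        self._lock = threading.Lock()
        self._results = {}  # task_id -> list of (rank, ok)
        self._events = {}   # task_id -> Event set once every rank has reported

    def issue(self, task_id: str) -> threading.Event:
        # Called by the side that submits a transfer task.
        event = threading.Event()
        with self._lock:
            self._events[task_id] = event
            self._results[task_id] = []
        return event

    def on_result(self, task_id: str, rank: int, ok: bool) -> None:
        # Called once per rank when its cache transfer process reports back.
        with self._lock:
            self._results[task_id].append((rank, ok))
            if len(self._results[task_id]) >= self.n_ranks:
                self._events[task_id].set()

    def wait(self, task_id: str, timeout: float = 30.0) -> bool:
        # True only if every rank answered in time and every ack was ok.
        event = self._events[task_id]
        if not event.wait(timeout=timeout):
            return False
        with self._lock:
            results = self._results.pop(task_id)
            self._events.pop(task_id)
        return all(ok for _, ok in results)


if __name__ == "__main__":
    collector = AckCollector(n_ranks=2)
    collector.issue("req-0")
    for rank in range(2):
        threading.Thread(target=collector.on_result, args=("req-0", rank, True)).start()
    print(collector.wait("req-0"))  # True once both ranks have acked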