Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions fastdeploy/cache_manager/cache_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class CacheStatus(Enum):
CPU = 3
GPU2STORAGE = 4
STORAGE2GPU = 5
DECODE_OFFLOAD = 6
DECODE_RESUME = 7
DECODE_CLEANUP = 8


class BlockNode:
Expand Down
17 changes: 17 additions & 0 deletions fastdeploy/cache_manager/cache_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,20 @@ class ReadStorageTask(CacheTask):
@dataclass(frozen=True, kw_only=True)
class WriteStorageTask(CacheTask):
timeout: float = 30.0


@dataclass(frozen=True, kw_only=True)
class DecodeOffloadTask:
    """Request to snapshot a decode request's GPU KV-cache blocks to host memory."""

    # Request identifier; used as the key for the stored snapshot.
    task_id: str
    # GPU cache block ids whose contents should be offloaded.
    gpu_block_ids: List[int]


@dataclass(frozen=True, kw_only=True)
class DecodeResumeTask:
    """Request to restore a previously offloaded snapshot back into GPU cache blocks."""

    # Request identifier; must match the task_id of an earlier DecodeOffloadTask.
    task_id: str
    # GPU cache block ids to write the restored data into.
    gpu_block_ids: List[int]


@dataclass(frozen=True, kw_only=True)
class DecodeCleanupTask:
    """Request to discard any host-side snapshot kept for the given request."""

    # Request identifier whose snapshot (if any) should be dropped.
    task_id: str
135 changes: 134 additions & 1 deletion fastdeploy/cache_manager/cache_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@

from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
from fastdeploy.cache_manager.cache_tasks import (
DecodeCleanupTask,
DecodeOffloadTask,
DecodeResumeTask,
ReadStorageTask,
WriteStorageTask,
)
from fastdeploy.cache_manager.ops import (
cuda_host_alloc,
cuda_host_free,
Expand Down Expand Up @@ -199,9 +205,13 @@ def __init__(self, args):
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.read_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.write_back_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.decode_offload_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.decode_resume_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.decode_cleanup_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
self.transfer_task_queue = queue.Queue() # 用来接收传输任务
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
self.decode_offload_snapshots = {}

address = (args.pod_ip, args.cache_queue_port)
self.cache_task_queue = EngineCacheQueue(
Expand Down Expand Up @@ -1035,6 +1045,108 @@ def check_work_status(self, time_interval_threashold=envs.FD_CACHE_PROC_EXIT_TIM

return True, ""

def _snapshot_blocks_to_cpu(self, gpu_cache_tensors, block_ids: List[int]):
    """Copy the given cache blocks to host memory, one tensor per layer.

    For each per-layer GPU cache tensor, gathers the rows addressed by
    ``block_ids`` into a single stacked tensor of shape (len(block_ids), ...)
    and moves it to CPU. Returns the list of per-layer CPU tensors.
    """
    cpu_layers = []
    for layer_cache in gpu_cache_tensors:
        selected = [layer_cache[bid] for bid in block_ids]
        # Single-block case keeps a leading blocks axis via unsqueeze so the
        # result shape matches the stacked multi-block case.
        if len(selected) > 1:
            gathered = paddle.stack(selected)
        else:
            gathered = selected[0].unsqueeze(0)
        cpu_layers.append(gathered.to("cpu"))
    return cpu_layers

def _restore_blocks_from_cpu(self, cpu_tensors, gpu_cache_tensors, block_ids: List[int]):
    """Write per-layer CPU snapshots back into the GPU cache at ``block_ids``.

    ``cpu_tensors[i]`` holds the blocks snapshotted from layer ``i`` (leading
    axis = block index within ``block_ids``); each layer is staged to this
    rank's device once, then scattered row-by-row into the cache tensor.
    """
    target_device = f"gpu:{self.device}"
    for layer_id, cpu_layer in enumerate(cpu_tensors):
        staged = cpu_layer.to(target_device)
        dst = gpu_cache_tensors[layer_id]
        for offset, bid in enumerate(block_ids):
            dst[bid] = staged[offset]

def decode_offload_task(self, task: DecodeOffloadTask):
    """Snapshot the request's GPU KV-cache blocks into host memory.

    On success the per-layer CPU copies are stored under ``task.task_id`` in
    ``self.decode_offload_snapshots``. Whether or not the snapshot succeeds,
    a ``(DECODE_OFFLOAD, task_id, rank, ok, meta)`` completion signal is
    always pushed to the cache task queue; failures carry the error string
    in ``meta["error"]``.
    """
    succeeded = False
    meta = {
        "rank": self.rank,
        "num_blocks": len(task.gpu_block_ids),
        "storage_level": "L2",
    }
    try:
        blocks = task.gpu_block_ids
        if not blocks:
            raise ValueError(f"decode offload task {task.task_id} has empty gpu_block_ids")
        # Key caches are always present; value and scale caches may be empty
        # lists, in which case an empty snapshot entry is recorded instead.
        snapshot = {
            "key_caches": self._snapshot_blocks_to_cpu(self.gpu_cache_k_tensors, blocks),
        }
        optional_sources = (
            ("value_caches", self.gpu_cache_v_tensors),
            ("key_scales", self.gpu_cache_scales_k_tensors),
            ("value_scales", self.gpu_cache_scales_v_tensors),
        )
        for name, tensors in optional_sources:
            snapshot[name] = self._snapshot_blocks_to_cpu(tensors, blocks) if tensors else []
        self.decode_offload_snapshots[task.task_id] = snapshot
        succeeded = True
    except Exception as e:
        meta["error"] = str(e)
        logger.error(
            f"decode_offload_task failed for {task.task_id}, error: {e}, traceback:\n{traceback.format_exc()}"
        )
    finally:
        # Always ack so the manager side is never left waiting on this task.
        self.cache_task_queue.put_transfer_done_signal(
            (CacheStatus.DECODE_OFFLOAD, task.task_id, self.rank, succeeded, meta)
        )

def decode_resume_task(self, task: DecodeResumeTask):
    """Restore a previously offloaded snapshot back into the GPU cache.

    Looks up the snapshot stored by :meth:`decode_offload_task` under
    ``task.task_id``, writes each cache family back into the corresponding
    GPU tensors at ``task.gpu_block_ids``, and removes the snapshot only
    after a fully successful restore (so a failed resume can be retried).
    A ``(DECODE_RESUME, task_id, rank, ok, meta)`` signal is always sent.
    """
    succeeded = False
    meta = {
        "rank": self.rank,
        "num_blocks": len(task.gpu_block_ids),
    }
    try:
        snapshot = self.decode_offload_snapshots.get(task.task_id)
        if snapshot is None:
            raise KeyError(f"snapshot for {task.task_id} not found")
        blocks = task.gpu_block_ids
        # Key caches are restored unconditionally; the other families only
        # when both the GPU tensors and the snapshotted data exist.
        self._restore_blocks_from_cpu(snapshot["key_caches"], self.gpu_cache_k_tensors, blocks)
        optional_targets = (
            ("value_caches", self.gpu_cache_v_tensors),
            ("key_scales", self.gpu_cache_scales_k_tensors),
            ("value_scales", self.gpu_cache_scales_v_tensors),
        )
        for name, gpu_tensors in optional_targets:
            if gpu_tensors and snapshot[name]:
                self._restore_blocks_from_cpu(snapshot[name], gpu_tensors, blocks)
        del self.decode_offload_snapshots[task.task_id]
        succeeded = True
    except Exception as e:
        meta["error"] = str(e)
        logger.error(
            f"decode_resume_task failed for {task.task_id}, error: {e}, traceback:\n{traceback.format_exc()}"
        )
    finally:
        # Always ack so the manager side is never left waiting on this task.
        self.cache_task_queue.put_transfer_done_signal(
            (CacheStatus.DECODE_RESUME, task.task_id, self.rank, succeeded, meta)
        )

def decode_cleanup_task(self, task: DecodeCleanupTask):
    """Discard any host-side snapshot kept for ``task.task_id``.

    Missing snapshots are not an error — cleanup is idempotent. Always pushes
    a ``(DECODE_CLEANUP, task_id, rank, ok, meta)`` completion signal.
    """
    succeeded = False
    meta = {"rank": self.rank}
    try:
        if task.task_id in self.decode_offload_snapshots:
            del self.decode_offload_snapshots[task.task_id]
        succeeded = True
    except Exception as e:
        meta["error"] = str(e)
        logger.error(
            f"decode_cleanup_task failed for {task.task_id}, error: {e}, traceback:\n{traceback.format_exc()}"
        )
    finally:
        self.cache_task_queue.put_transfer_done_signal(
            (CacheStatus.DECODE_CLEANUP, task.task_id, self.rank, succeeded, meta)
        )

def submit_task(self, thread_pool: concurrent.futures.ThreadPoolExecutor, task_fn, *args):

def inflight_task(fn, *args):
Expand Down Expand Up @@ -1129,6 +1241,27 @@ def do_data_transfer(self):
self.write_back_storage_task,
write_storage_task,
)
elif event_type.value == CacheStatus.DECODE_OFFLOAD.value:
decode_offload_task = event_args[0]
self.submit_task(
self.decode_offload_thread_pool,
self.decode_offload_task,
decode_offload_task,
)
elif event_type.value == CacheStatus.DECODE_RESUME.value:
decode_resume_task = event_args[0]
self.submit_task(
self.decode_resume_thread_pool,
self.decode_resume_task,
decode_resume_task,
)
elif event_type.value == CacheStatus.DECODE_CLEANUP.value:
decode_cleanup_task = event_args[0]
self.submit_task(
self.decode_cleanup_thread_pool,
self.decode_cleanup_task,
decode_cleanup_task,
)
else:
if self.n_ranks > 1:
self.cache_task_queue.barrier2.wait()
Expand Down
27 changes: 25 additions & 2 deletions fastdeploy/cache_manager/prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
self.task_write_back_event = {}
self.task_prefetch_event = {}
self.storage_prefetch_block_ids = {}
self.transfer_result_handlers = []

# gpu cache data structure
self.gpu_lru_leaf_heap = []
Expand Down Expand Up @@ -293,7 +294,11 @@ def launch_cache_manager(
else:
storage_arg_str = " "

if self.cache_config.swap_space or self.cache_config.kvcache_storage_backend:
if (
self.cache_config.swap_space
or self.cache_config.kvcache_storage_backend
or getattr(self.config, "enable_decode_offload", False)
):
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
Expand Down Expand Up @@ -352,7 +357,11 @@ def launch_cache_manager(
)

# Start additional threads
if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0:
if (
cache_config.kvcache_storage_backend
or self.num_cpu_blocks > 0
or getattr(self.config, "enable_decode_offload", False)
):
logger.info("Enable hierarchical cache.")
threading.Thread(target=self.recv_data_transfer_result, daemon=True).start()
if cache_config.enable_prefix_caching:
Expand Down Expand Up @@ -1191,6 +1200,10 @@ def wait_prefetch_storage_task(self, req_id):
del self.storage_prefetch_block_ids[req_id]
return storage_block_ids

def register_transfer_result_handler(self, handler):
    """Register a callback to intercept transfer results; duplicates are ignored.

    Handlers are consulted in registration order when transfer-done signals
    arrive; registering the same handler twice is a no-op.
    """
    if handler in self.transfer_result_handlers:
        return
    self.transfer_result_handlers.append(handler)

def free_nodes_directly(self, node):
with self.request_release_lock:
try:
Expand Down Expand Up @@ -2058,6 +2071,16 @@ def recv_data_transfer_result(self):
time.sleep(0.001)
continue
event_type = data[0]
handled = False
for handler in self.transfer_result_handlers:
try:
if handler(data):
handled = True
break
except Exception as e:
logger.warning(f"transfer result handler failed: {e}")
if handled:
continue

if event_type.value == CacheStatus.STORAGE2GPU.value:
logger.info(f"recv_data_transfer_result: {data}")
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,6 +1805,7 @@ def __init__(
tool_parser: str = None,
test_mode=False,
routing_replay_config: Optional[RoutingReplayConfig] = None,
enable_decode_offload: bool = False,
):
self.model_config: ModelConfig = model_config # type: ignore
self.cache_config: CacheConfig = cache_config # type: ignore
Expand All @@ -1821,6 +1822,7 @@ def __init__(
self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
self.router_config: RouterConfig = router_config
self.routing_replay_config = routing_replay_config
self.enable_decode_offload = enable_decode_offload

# Initialize cuda graph capture list
max_capture_shape = self.scheduler_config.max_num_seqs
Expand Down
12 changes: 12 additions & 0 deletions fastdeploy/engine/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,11 @@ class EngineArgs:
Flag to enable prefill_use_worst_num_tokens. Default is False (disabled).
"""

enable_decode_offload: bool = False
"""
Flag to enable decode offload. Default is False (disabled).
"""

def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
Expand Down Expand Up @@ -1071,6 +1076,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.ep_prefill_use_worst_num_tokens,
help="Enable prefill use worst num tokens for EP.",
)
parallel_group.add_argument(
"--enable-decode-offload",
action="store_true",
default=EngineArgs.enable_decode_offload,
help="Enable decode offload.",
)

# Load group
load_group = parser.add_argument_group("Load Configuration")
Expand Down Expand Up @@ -1514,4 +1525,5 @@ def create_engine_config(self) -> FDConfig:
plas_attention_config=plas_attention_config,
early_stop_config=early_stop_cfg,
routing_replay_config=routing_replay_config,
enable_decode_offload=self.enable_decode_offload,
)
30 changes: 18 additions & 12 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,18 +1028,24 @@ def _fetch_request():
if self.cfg.scheduler_config.splitwise_role == "decode":
for task in tasks:
if task.task_type == RequestType.PREEMPTED:
msg = f"{task.request_id} decode not enough blocks, need to be rescheduled."
self.llm_logger.error(msg)
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
req = self.resource_manager.requests.get(task.request_id)
if req is not None and req.is_offloaded:
self.llm_logger.info(
f"{task.request_id} decode request is preempted and offloaded, waiting for resume."
)
else:
msg = f"{task.request_id} decode not enough blocks, need to be rescheduled."
self.llm_logger.error(msg)
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
Comment on lines +1038 to +1048
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里对 RequestType.PREEMPTED(decode 侧)仍向 scheduler.put_results() 发送 finished=True 的 500 错误(仅 is_offloaded 跳过)。但本 PR 同时引入了 preempt 后在本地 reschedule 并继续输出 token 的逻辑;scheduler 一旦因 finished=True 提前回收该 request_id,后续输出会被当作 expired response 丢弃,导致请求无法恢复。建议将 PREEMPTED 作为纯内部控制任务,不要向 scheduler 发送 finished=True 的错误结果(必要时用非 finished 的状态事件/独立通道通知)。

Suggested change
self.llm_logger.error(msg)
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
# Treat PREEMPTED as an internal control signal on decode side.
# Do not send a finished=True error result to scheduler here,
# otherwise the request_id would be reclaimed and could not be resumed.
self.llm_logger.info(msg)

Copilot uses AI. Check for mistakes.
self.resource_manager.get_real_bsz()
for task in tasks:
if task.task_type == RequestType.PREFILL:
Expand Down
Loading
Loading