Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions docs/features/weight_update.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ In FastDeploy >= 2.6, the underlying control-signal communication path is optimi
| `/v1/is_paused` | `GET` | none | Return `{"is_paused": bool}`. |
| `/v1/sleep` | `POST` | `?tags=weight,kv_cache` | Offload selected GPU memory objects. Supported tags are `weight` and `kv_cache`. If omitted, both are used. |
| `/v1/wakeup` | `POST` | `?tags=weight,kv_cache` | Reload previously offloaded weights and/or KV cache. On success, the engine resumes automatically. |
| `/v1/update_weights` | `POST` | JSON `{"version":"...", "rsync_config": {...}}` | Refresh weights in place through the worker control path. This API is intended for remote versioned updates, especially `load_strategy=rsync`. |
| `/v1/update_weights` | `POST` | JSON `{"version":"...", "verify_checksum": false}` | Refresh weights in place through the worker control path. This API is intended for remote versioned updates, especially `load_strategy=rsync`. |

### Compatibility Notes

Expand Down Expand Up @@ -114,7 +114,7 @@ After `wakeup` succeeds, FastDeploy automatically calls `resume`.
Current request fields:

- `version`: optional string. Used to choose a target checkpoint version.
- `rsync_config`: optional dictionary. Must contain `etcd_server` when provided.
- `verify_checksum`: optional boolean. Defaults to `false`. Set to `true` to verify data integrity during weight synchronization.

Important semantics:

Expand Down Expand Up @@ -186,9 +186,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \
-H "Content-Type: application/json" \
-d '{
"version": "global_step_1200",
"rsync_config": {
"etcd_server": "127.0.0.1:2379"
}
"verify_checksum": false
}'
```

Expand Down Expand Up @@ -261,9 +259,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \
-H "Content-Type: application/json" \
-d '{
"version": "global_step_1200",
"rsync_config": {
"etcd_server": "127.0.0.1:2379"
}
"verify_checksum": false
}'

# Resume the service after the update completes
Expand Down
12 changes: 4 additions & 8 deletions docs/zh/features/weight_update.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
| `/v1/is_paused` | `GET` | 无 | 返回 `{"is_paused": bool}`。 |
| `/v1/sleep` | `POST` | `?tags=weight,kv_cache` | 卸载指定 GPU 内存对象。支持 `weight` 与 `kv_cache`;不传时默认同时处理两者。 |
| `/v1/wakeup` | `POST` | `?tags=weight,kv_cache` | 重新加载之前被卸载的权重和/或 KV Cache。成功后会自动 `resume`。 |
| `/v1/update_weights` | `POST` | JSON `{"version":"...", "rsync_config": {...}}` | 通过 worker 控制链路原地刷新模型权重。该接口主要面向 `load_strategy=rsync` 的远端版本更新。 |
| `/v1/update_weights` | `POST` | JSON `{"version":"...", "verify_checksum": false}` | 通过 worker 控制链路原地刷新模型权重。该接口主要面向 `load_strategy=rsync` 的远端版本更新。 |

### 兼容性说明

Expand Down Expand Up @@ -113,7 +113,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
当前支持的请求字段:

- `version`:可选字符串,用于指定目标 checkpoint 版本。
- `rsync_config`:可选字典;如果传入,必须包含 `etcd_server`
- `verify_checksum`:可选布尔值;默认为 `false`。设置为 `true` 时,会在权重同步过程中校验数据完整性

关键语义:

Expand Down Expand Up @@ -185,9 +185,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \
-H "Content-Type: application/json" \
-d '{
"version": "global_step_1200",
"rsync_config": {
"etcd_server": "127.0.0.1:2379"
}
"verify_checksum": false
}'
```

Expand Down Expand Up @@ -260,9 +258,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \
-H "Content-Type: application/json" \
-d '{
"version": "global_step_1200",
"rsync_config": {
"etcd_server": "127.0.0.1:2379"
}
"verify_checksum": false
}'

# 更新完成后恢复服务
Expand Down
15 changes: 5 additions & 10 deletions fastdeploy/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,19 +461,14 @@ async def update_weights(request: Request) -> Response:
)
args["version"] = request_data["version"]

# Validate and extract rsync_config parameter
if "rsync_config" in request_data and request_data["rsync_config"] is not None:
if not isinstance(request_data["rsync_config"], dict):
# Validate and extract verify_checksum parameter
if "verify_checksum" in request_data and request_data["verify_checksum"] is not None:
if not isinstance(request_data["verify_checksum"], bool):
return JSONResponse(
status_code=400,
content={"error": "Invalid parameter type", "message": "rsync_config must be a dictionary"},
content={"error": "Invalid parameter type", "message": "verify_checksum must be a boolean"},
)
if "etcd_server" not in request_data["rsync_config"]:
return JSONResponse(
status_code=400,
content={"error": "Invalid parameter type", "message": "rsync_config must contain etcd_server"},
)
args["rsync_config"] = request_data["rsync_config"]
args["verify_checksum"] = request_data["verify_checksum"]

control_request = ControlRequest(request_id, "update_weights", args)
control_response = await app.state.engine_client.run_control_method(control_request)
Expand Down
67 changes: 25 additions & 42 deletions fastdeploy/rl/dynamic_weight_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import gc
import glob
import io
import os
import re
import time
Expand All @@ -31,30 +30,6 @@
from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus


def sync_weights_by_rdma(config, step, rank):
from checkpoint_transfer.core import RDMAWeightsDownloader

downloader = RDMAWeightsDownloader(config)
downloader.initialize()
logger.info(f"Fetching weights for step:{step}, rank:{rank}...")
data = downloader.get_weights(step, rank)
if data is None:
logger.error("Failed to get weights!")
raise Exception("Failed to rsync weights through checkpoint_transfer")
logger.info(f"Successfully retrieved data. Type: {type(data)}")
if isinstance(data, np.ndarray):
data_bytes = data.tobytes()
elif isinstance(data, (bytes, bytearray)):
data_bytes = data
else:
data_bytes = bytes(data)
logger.info(f"Data size: {len(data_bytes)} bytes")

buffer = io.BytesIO(data_bytes)
new_state_dict = paddle.load(buffer)
return new_state_dict


class DynamicWeightManager:
"""Manages model weights loading, updating and shared state across processes."""

Expand All @@ -75,6 +50,7 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int):
else:
self.model_list = models
self._capture_model_state()
self.rdma_handle = None
if self.load_config.load_strategy == "rsync":
self.update_weights_by_rdma()
else:
Expand All @@ -91,10 +67,12 @@ def _capture_model_state(self):
"""Capture and store initial model parameters state."""
for model in self.model_list:
for name, param in model.state_dict().items():
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
if hasattr(param, "_is_initialized") and not param._is_initialized():
param.initialize()
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}, place={param.place}")
self.state_dict[name] = param

def update_weights_by_rdma(self, version: str = None, rsync_config: Dict[str, Any] = None):
def update_weights_by_rdma(self, version: str = None, verify_checksum: bool = False):
def valid_parameters(old_state_dict, new_state_dict):
is_valid = True
for key in old_state_dict:
Expand All @@ -110,17 +88,11 @@ def valid_parameters(old_state_dict, new_state_dict):
)
elif old_state_dict[key].dtype != new_state_dict[key].dtype:
is_valid = False
logger.error(f"Invalid parameter: {key} dtype mismatch")
logger.error(
f"Invalid parameter: {key} dtype mismatch, old:{old_state_dict[key].dtype}, new:{new_state_dict[key].dtype}"
)
return is_valid

if rsync_config is None:
rsync_config = self.fd_config.load_config.rsync_config
if rsync_config is None or len(rsync_config) == 0:
raise Exception(
"rsync config not set, please set it in 1) launch arguments '--rsync-config' "
"or 2) interface arguments 'rsync_config'"
)

if version is None or version == "":
version = self.read_model_version_from_file()
if version is None or version == "":
Expand All @@ -129,11 +101,23 @@ def valid_parameters(old_state_dict, new_state_dict):
"or 2) interface arguments 'version'"
)

logger.info(f"START update_weights_by_rdma, version:{version}, rsync_config:{rsync_config}")
rank = self.local_rank
logger.info(
f"START rank:{self.local_rank}/{self.nranks} update_weights_by_rdma, "
f"version:{version}, verify_checksum:{verify_checksum}"
)

if self.rdma_handle is None:
from checkpoint_transfer import CheckpointTransfer

config = self.fd_config.load_config.rsync_config
logger.info(f"CheckpointTransfer rsync config:{config}")
self.rdma_handle = CheckpointTransfer(**config, local_rank=self.local_rank, group_size=self.nranks)
self.rdma_handle.initialize()

sync_start = time.perf_counter()
new_state_dict = sync_weights_by_rdma(rsync_config, version, rank)
new_state_dict = dict()
for key, param in self.rdma_handle.receive_stream(step_id=version, verify_checksum=verify_checksum):
new_state_dict[key] = param
sync_cost = time.perf_counter() - sync_start
logger.info(f"weights sync cost {sync_cost:.2f} seconds")

Expand All @@ -148,18 +132,17 @@ def valid_parameters(old_state_dict, new_state_dict):
param.set_value(new_state_dict[name])
update_cost = time.perf_counter() - update_start
logger.info(f"params set value cost {update_cost:.2f} seconds")

total_cost = time.perf_counter() - sync_start
logger.info(
f"END update_weights_by_rdma, cost {total_cost:.2f} seconds"
f" version:{version}, rsync_config: {rsync_config}",
f" version:{version}, verify_checksum: {verify_checksum}, local_rank: {self.local_rank}",
)
return {
"sync_cost": sync_cost,
"update_cost": update_cost,
"total_cost": total_cost,
"version": version,
"rank": rank,
"rank": self.local_rank,
}

def update_parameters(self, pid: int = 0, restart_process_group=False) -> None:
Expand Down
6 changes: 3 additions & 3 deletions fastdeploy/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time
from concurrent.futures import Future
from threading import Thread
from typing import Any, Dict, List, Optional, cast
from typing import Dict, List, Optional, cast

import numpy as np
import paddle
Expand Down Expand Up @@ -2692,8 +2692,8 @@ def update_parameters(self, pid):

self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config)
def update_weights(self, version: str = None, verify_checksum: bool = False):
return self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum)

def sleep(self, tags):

Expand Down
6 changes: 3 additions & 3 deletions fastdeploy/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import gc
import time
from typing import Any, Dict, List, Optional
from typing import List, Optional

import paddle
import pynvml
Expand Down Expand Up @@ -192,9 +192,9 @@ def initialize_cache(self, num_gpu_blocks: int) -> None:
if self.fd_config.routing_replay_config.enable_routing_replay:
self.model_runner.initialize_routing_replay_manager()

def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
def update_weights(self, version: str = None, verify_checksum: bool = False):
"""update weights in place"""
return self.model_runner.update_weights(version, rsync_config)
return self.model_runner.update_weights(version, verify_checksum)

def sleep(self, **kwargs) -> None:
"""Offload memory from GPU"""
Expand Down
6 changes: 3 additions & 3 deletions fastdeploy/worker/metax_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time
from concurrent.futures import Future
from threading import Thread
from typing import Any, Dict, List, Optional, cast
from typing import List, Optional, cast

import numpy as np
import paddle
Expand Down Expand Up @@ -2550,8 +2550,8 @@ def update_parameters(self, pid):

self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config)
def update_weights(self, version: str = None, verify_checksum: bool = False):
return self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum)

def padding_cudagraph_inputs(self) -> None:
"""
Expand Down
6 changes: 3 additions & 3 deletions fastdeploy/worker/metax_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import gc
import os
import time
from typing import Any, Dict, List, Optional
from typing import List, Optional

import paddle
from paddle import nn
Expand Down Expand Up @@ -191,9 +191,9 @@ def initialize_cache(self, num_gpu_blocks: int) -> None:
# accurate cache size
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
def update_weights(self, version: str = None, verify_checksum: bool = False):
"""update weights in place"""
return self.model_runner.update_weights(version, rsync_config)
return self.model_runner.update_weights(version, verify_checksum)

def execute_model(
self,
Expand Down
16 changes: 8 additions & 8 deletions tests/entrypoints/openai/test_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,25 +604,25 @@ async def test_update_weights_route_validation():
api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_control_response)

valid_req = MagicMock()
valid_req.body = AsyncMock(return_value=b'{"version":"v2","rsync_config":{"etcd_server":"127.0.0.1"}}')
valid_req.json = AsyncMock(return_value={"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}})
valid_req.body = AsyncMock(return_value=b'{"version":"v2","verify_checksum":true}')
valid_req.json = AsyncMock(return_value={"version": "v2", "verify_checksum": True})
valid_resp = await api_server.update_weights(valid_req)
assert valid_resp.status_code == 200
control_request = api_server.app.state.engine_client.run_control_method.await_args.args[0]
assert control_request.method == "update_weights"
assert control_request.args == {"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}}
assert control_request.args == {"version": "v2", "verify_checksum": True}

invalid_version_req = MagicMock()
invalid_version_req.body = AsyncMock(return_value=b'{"version":1}')
invalid_version_req.json = AsyncMock(return_value={"version": 1})
invalid_version_resp = await api_server.update_weights(invalid_version_req)
assert invalid_version_resp.status_code == 400

invalid_rsync_req = MagicMock()
invalid_rsync_req.body = AsyncMock(return_value=b'{"rsync_config":{"user":"u"}}')
invalid_rsync_req.json = AsyncMock(return_value={"rsync_config": {"user": "u"}})
invalid_rsync_resp = await api_server.update_weights(invalid_rsync_req)
assert invalid_rsync_resp.status_code == 400
invalid_checksum_req = MagicMock()
invalid_checksum_req.body = AsyncMock(return_value=b'{"verify_checksum":"true"}')
invalid_checksum_req.json = AsyncMock(return_value={"verify_checksum": "true"})
invalid_checksum_resp = await api_server.update_weights(invalid_checksum_req)
assert invalid_checksum_resp.status_code == 400


@pytest.mark.asyncio
Expand Down
Loading