From 0947b1ea5cf66c0722dd0f43728b62a173b2fb31 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 29 May 2026 12:59:56 +0800 Subject: [PATCH 01/34] refactor(server): extract state module and introduce StateBackend abstraction - Move utils/state/ to server/state/ as top-level module - Fix all 8 import references (no re-export compatibility layer) - Add StateBackend ABC with set/get/delete/exists/keys/count/set_nx/close/health_check - Implement MemoryBackend (sync in-memory, compatible with Ray Actor) - Refactor ConfigManager to use StateBackend with 'config::' key prefix - Inject optional backend parameter into ServerState and get_server_state factory - Add unified exception hierarchy (TwinkleServerError and subclasses) - Create telemetry/ skeleton directory for Phase 2 --- .../server/transformer/server_config.yaml | 68 +++++++-------- src/twinkle/server/common/router.py | 2 +- src/twinkle/server/exceptions.py | 23 +++++ src/twinkle/server/gateway/server.py | 2 +- src/twinkle/server/model/app.py | 2 +- src/twinkle/server/processor/app.py | 2 +- src/twinkle/server/sampler/app.py | 2 +- .../server/{utils => }/state/__init__.py | 0 src/twinkle/server/state/backend/__init__.py | 4 + src/twinkle/server/state/backend/base.py | 56 +++++++++++++ .../server/state/backend/memory_backend.py | 83 ++++++++++++++++++ src/twinkle/server/{utils => }/state/base.py | 0 src/twinkle/server/state/config_manager.py | 81 ++++++++++++++++++ .../{utils => }/state/future_manager.py | 0 .../server/{utils => }/state/model_manager.py | 0 .../server/{utils => }/state/models.py | 0 .../{utils => }/state/sampling_manager.py | 0 .../server/{utils => }/state/server_state.py | 84 ++++++++++++++++++- .../{utils => }/state/session_manager.py | 0 src/twinkle/server/telemetry/__init__.py | 0 src/twinkle/server/utils/lifecycle/base.py | 2 +- .../server/utils/state/config_manager.py | 53 ------------ src/twinkle/server/utils/task_queue/mixin.py | 2 +- src/twinkle/server/utils/task_queue/worker.py | 2 +- 24 files changed, 370 insertions(+), 98 deletions(-) create mode 100644 src/twinkle/server/exceptions.py rename src/twinkle/server/{utils => }/state/__init__.py (100%) create mode 100644 src/twinkle/server/state/backend/__init__.py create mode 100644 src/twinkle/server/state/backend/base.py create mode 100644 src/twinkle/server/state/backend/memory_backend.py rename src/twinkle/server/{utils => }/state/base.py (100%) create mode 100644 src/twinkle/server/state/config_manager.py rename src/twinkle/server/{utils => }/state/future_manager.py (100%) rename src/twinkle/server/{utils => }/state/model_manager.py (100%) rename src/twinkle/server/{utils => }/state/models.py (100%) rename src/twinkle/server/{utils => }/state/sampling_manager.py (100%) rename src/twinkle/server/{utils => }/state/server_state.py (85%) rename src/twinkle/server/{utils => }/state/session_manager.py (100%) create mode 100644 src/twinkle/server/telemetry/__init__.py delete mode 100644 src/twinkle/server/utils/state/config_manager.py diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 570142afa..0a97a8896 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -64,43 +64,43 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_TRUST_REMOTE_CODE: "1" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). - - name: sampler-Qwen3.5-4B - route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - import_path: sampler - args: - model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - nproc_per_node: 2 # Number of GPU processes per node - sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - engine_args: # vLLM engine-specific settings - max_model_len: 4096 # Maximum sequence length the engine supports - gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - enable_lora: true # Allow loading LoRA adapters during inference - logprobs_mode: processed_logprobs # Logprobs mode for sampling results - device_group: # Logical device group for the sampler - name: sampler - ranks: 1 # Number of GPUs to use - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 1 - queue_config: - rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second - deployments: - - name: SamplerManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + # - name: sampler-Qwen3.5-4B + # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + # import_path: sampler + # args: + # model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + # nproc_per_node: 2 # Number of GPU processes per node + # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + # engine_args: # vLLM engine-specific settings + # max_model_len: 4096 # Maximum sequence length the engine supports + # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + # enable_lora: true # Allow loading LoRA adapters during inference + # logprobs_mode: processed_logprobs # Logprobs mode for sampling results + # device_group: # Logical device group for the sampler + # name: sampler + # ranks: 1 # Number of GPUs to use + # device_type: cuda + # device_mesh: + # device_type: cuda + # dp_size: 1 + # queue_config: + # rps_limit: 100 # Max requests per second + # tps_limit: 100000 # Max tokens per second + # deployments: + # - name: SamplerManagement + # autoscaling_config: + # min_replicas: 1 + # max_replicas: 1 + # target_ongoing_requests: 16 + # ray_actor_options: + # num_cpus: 0.1 + # runtime_env: + # env_vars: + # TWINKLE_TRUST_REMOTE_CODE: "1" # 4. Processor Service - name: processor diff --git a/src/twinkle/server/common/router.py b/src/twinkle/server/common/router.py index dee1bd36e..5ecf55d28 100644 --- a/src/twinkle/server/common/router.py +++ b/src/twinkle/server/common/router.py @@ -4,7 +4,7 @@ RequestRouter, RunningReplica) from typing import Dict, List, Optional -from twinkle.server.utils.state import ServerStateProxy, get_server_state +from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.utils.logger import get_logger logger = get_logger() diff --git a/src/twinkle/server/exceptions.py b/src/twinkle/server/exceptions.py new file mode 100644 index 000000000..f96d42500 --- /dev/null +++ b/src/twinkle/server/exceptions.py @@ -0,0 +1,23 @@ +"""Twinkle Server unified exception hierarchy.""" + +from __future__ import annotations + + +class TwinkleServerError(Exception): + """所有 Twinkle Server 异常的基类""" + pass + + +class StateBackendError(TwinkleServerError): + """状态后端操作失败(连接断开、超时、数据序列化错误等)""" + pass + + +class ConfigMismatchError(TwinkleServerError): + """配置签名不匹配 — 重启后检测到配置变更,持久化数据可能与当前配置不兼容""" + pass + + +class ResourceExhaustedError(TwinkleServerError): + """资源耗尽 — 队列满、内存不足、连接池耗尽等""" + pass diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index 755c5d2b4..92897af6e 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -15,7 +15,7 @@ import twinkle_client.types as types from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.utils.state import get_server_state +from twinkle.server.state import get_server_state from twinkle.server.utils.validation import verify_request_token from twinkle.utils.logger import get_logger from .proxy import ServiceProxy diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 5d0bc2285..e8da0078e 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -17,7 +17,7 @@ from twinkle import DeviceGroup, DeviceMesh from twinkle.server.utils.lifecycle import AdapterManagerMixin from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.utils.state import ServerStateProxy, get_server_state +from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 40fdadbea..1db7c59dc 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -22,7 +22,7 @@ from twinkle import DeviceGroup, DeviceMesh, get_logger from twinkle.server.utils.lifecycle import ProcessorManagerMixin from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.utils.state import ServerStateProxy, get_server_state +from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.utils.validation import verify_request_token from .twinkle_handlers import _register_processor_routes diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index 177344727..c1a283bd6 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -14,7 +14,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.utils.state import ServerStateProxy, get_server_state +from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger diff --git a/src/twinkle/server/utils/state/__init__.py b/src/twinkle/server/state/__init__.py similarity index 100% rename from src/twinkle/server/utils/state/__init__.py rename to src/twinkle/server/state/__init__.py diff --git a/src/twinkle/server/state/backend/__init__.py b/src/twinkle/server/state/backend/__init__.py new file mode 100644 index 000000000..f3baab6f5 --- /dev/null +++ b/src/twinkle/server/state/backend/__init__.py @@ -0,0 +1,4 @@ +from .base import StateBackend +from .memory_backend import MemoryBackend + +__all__ = ['StateBackend', 'MemoryBackend'] diff --git a/src/twinkle/server/state/backend/base.py b/src/twinkle/server/state/backend/base.py new file mode 100644 index 000000000..6a44d7944 --- /dev/null +++ b/src/twinkle/server/state/backend/base.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + + +class StateBackend(ABC): + """状态存储后端的统一接口。 + + 所有状态管理操作通过此接口进行,支持多种后端实现(内存、文件、Redis)。 + """ + + @abstractmethod + async def set(self, key: str, value: Any, ttl: int | None = None) -> None: + """存储键值对,可选 TTL(秒)""" + ... + + @abstractmethod + async def get(self, key: str) -> Any | None: + """获取值,不存在或已过期返回 None""" + ... + + @abstractmethod + async def delete(self, key: str) -> None: + """删除键,不存在时静默忽略""" + ... + + @abstractmethod + async def exists(self, key: str) -> bool: + """检查键是否存在且未过期""" + ... + + @abstractmethod + async def keys(self, pattern: str) -> list[str]: + """按模式匹配返回所有键名。pattern 支持 * 通配符(如 'session::*')""" + ... + + @abstractmethod + async def count(self, pattern: str) -> int: + """按模式匹配计数""" + ... + + @abstractmethod + async def set_nx(self, key: str, value: Any) -> bool: + """Set if not exists. 返回 True 如果成功设置,False 如果键已存在""" + ... + + @abstractmethod + async def close(self) -> None: + """关闭后端连接/释放资源""" + ... + + @abstractmethod + async def health_check(self) -> bool: + """检查后端是否健康可用""" + ... diff --git a/src/twinkle/server/state/backend/memory_backend.py b/src/twinkle/server/state/backend/memory_backend.py new file mode 100644 index 000000000..a960f1e78 --- /dev/null +++ b/src/twinkle/server/state/backend/memory_backend.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import time +from fnmatch import fnmatch +from typing import Any + +from .base import StateBackend + + +class MemoryBackend(StateBackend): + """基于内存字典的状态后端实现。 + + 使用 ``dict[str, tuple[Any, float | None]]`` 存储 (value, expire_at)。 + 过期检查在 get/exists 时进行(惰性过期),适用于 Ray Actor 单线程模型。 + """ + + def __init__(self) -> None: + self._store: dict[str, tuple[Any, float | None]] = {} + + def _is_expired(self, key: str) -> bool: + """检查键是否已过期。如已过期则删除并返回 True。""" + entry = self._store.get(key) + if entry is None: + return True + _, expire_at = entry + if expire_at is not None and time.time() >= expire_at: + del self._store[key] + return True + return False + + async def set(self, key: str, value: Any, ttl: int | None = None) -> None: + """存储键值对,可选 TTL(秒)""" + expire_at = (time.time() + ttl) if ttl is not None else None + self._store[key] = (value, expire_at) + + async def get(self, key: str) -> Any | None: + """获取值,不存在或已过期返回 None""" + if self._is_expired(key): + return None + value, _ = self._store[key] + return value + + async def delete(self, key: str) -> None: + """删除键,不存在时静默忽略""" + self._store.pop(key, None) + + async def exists(self, key: str) -> bool: + """检查键是否存在且未过期""" + return not self._is_expired(key) + + async def keys(self, pattern: str) -> list[str]: + """按模式匹配返回所有键名。pattern 支持 * 通配符。""" + result: list[str] = [] + # 遍历时收集过期键,避免在迭代中修改字典 + expired_keys: list[str] = [] + for key, (_, expire_at) in self._store.items(): + if expire_at is not None and time.time() >= expire_at: + expired_keys.append(key) + continue + if fnmatch(key, pattern): + result.append(key) + for key in expired_keys: + del self._store[key] + return result + + async def count(self, pattern: str) -> int: + """按模式匹配计数""" + return len(await self.keys(pattern)) + + async def set_nx(self, key: str, value: Any) -> bool: + """Set if not exists. 返回 True 如果成功设置,False 如果键已存在。""" + if not self._is_expired(key): + return False + self._store[key] = (value, None) + return True + + async def close(self) -> None: + """关闭后端,清空存储""" + self._store.clear() + + async def health_check(self) -> bool: + """检查后端是否健康可用,内存后端始终返回 True""" + return True diff --git a/src/twinkle/server/utils/state/base.py b/src/twinkle/server/state/base.py similarity index 100% rename from src/twinkle/server/utils/state/base.py rename to src/twinkle/server/state/base.py diff --git a/src/twinkle/server/state/config_manager.py b/src/twinkle/server/state/config_manager.py new file mode 100644 index 000000000..ecf76f63b --- /dev/null +++ b/src/twinkle/server/state/config_manager.py @@ -0,0 +1,81 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +from typing import Any + +from .backend import StateBackend + +# Key prefix used to namespace configuration entries inside the backend. +_CONFIG_PREFIX = 'config::' +_CONFIG_PATTERN = f'{_CONFIG_PREFIX}*' + + +class ConfigManager: + """ + Manages key-value configuration entries via a :class:`StateBackend`. + + Configuration entries have no expiry; they persist until explicitly removed + or cleared. This manager does not inherit from BaseManager because config + values are arbitrary Python objects rather than Pydantic models, and all + storage is delegated to the injected backend. + + Methods are ``async`` because :class:`StateBackend` operations are async. + ConfigManager is expected to run inside a single-threaded Ray Actor, so + there is no need to add additional locking on top of the backend. + """ + + def __init__(self, backend: StateBackend) -> None: + self._backend = backend + + @staticmethod + def _make_key(key: str) -> str: + return f'{_CONFIG_PREFIX}{key}' + + # ----- CRUD ----- + + async def add(self, key: str, value: Any) -> None: + """Add or overwrite a configuration value.""" + await self._backend.set(self._make_key(key), value) + + async def add_or_get(self, key: str, value: Any) -> Any: + """Add a value if the key does not exist; otherwise return the existing value. + + Args: + key: Configuration key. + value: Value to store if the key is absent. + + Returns: + The existing or newly stored value. + """ + backend_key = self._make_key(key) + existing = await self._backend.get(backend_key) + if existing is not None: + return existing + # Use set_nx for atomicity within a single backend; if another + # writer already populated the key we return the winning value. + if await self._backend.set_nx(backend_key, value): + return value + return await self._backend.get(backend_key) + + async def get(self, key: str) -> Any | None: + """Return the configuration value for key, or None.""" + return await self._backend.get(self._make_key(key)) + + async def pop(self, key: str) -> Any | None: + """Remove and return the configuration value for key, or None.""" + backend_key = self._make_key(key) + value = await self._backend.get(backend_key) + if value is None: + return None + await self._backend.delete(backend_key) + return value + + async def clear(self) -> None: + """Remove all configuration entries.""" + keys = await self._backend.keys(_CONFIG_PATTERN) + for backend_key in keys: + await self._backend.delete(backend_key) + + async def count(self) -> int: + """Return the number of stored configuration entries.""" + return await self._backend.count(_CONFIG_PATTERN) diff --git a/src/twinkle/server/utils/state/future_manager.py b/src/twinkle/server/state/future_manager.py similarity index 100% rename from src/twinkle/server/utils/state/future_manager.py rename to src/twinkle/server/state/future_manager.py diff --git a/src/twinkle/server/utils/state/model_manager.py b/src/twinkle/server/state/model_manager.py similarity index 100% rename from src/twinkle/server/utils/state/model_manager.py rename to src/twinkle/server/state/model_manager.py diff --git a/src/twinkle/server/utils/state/models.py b/src/twinkle/server/state/models.py similarity index 100% rename from src/twinkle/server/utils/state/models.py rename to src/twinkle/server/state/models.py diff --git a/src/twinkle/server/utils/state/sampling_manager.py b/src/twinkle/server/state/sampling_manager.py similarity index 100% rename from src/twinkle/server/utils/state/sampling_manager.py rename to src/twinkle/server/state/sampling_manager.py diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/state/server_state.py similarity index 85% rename from src/twinkle/server/utils/state/server_state.py rename to src/twinkle/server/state/server_state.py index fd3a76269..2af95f856 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -11,6 +11,7 @@ from twinkle.server.utils.metrics import get_resource_metrics from twinkle.utils.logger import get_logger +from .backend import MemoryBackend, StateBackend from .config_manager import ConfigManager from .future_manager import FutureManager from .model_manager import ModelManager @@ -37,15 +38,19 @@ class ServerState: def __init__( self, + backend: StateBackend | None = None, expiration_timeout: float = 86400.0, # 24 hours in seconds cleanup_interval: float = 3600.0, # 1 hour in seconds per_token_model_limit: int = 30, **kwargs) -> None: + # Backend is currently consumed only by ConfigManager; other managers + # will be migrated in subsequent tasks. + self._backend: StateBackend = backend if backend is not None else MemoryBackend() self._session_mgr = SessionManager(expiration_timeout) self._model_mgr = ModelManager(expiration_timeout, per_token_model_limit) self._sampling_mgr = SamplingSessionManager(expiration_timeout) self._future_mgr = FutureManager(expiration_timeout) - self._config_mgr = ConfigManager() + self._config_mgr = ConfigManager(self._backend) self.expiration_timeout = expiration_timeout self.cleanup_interval = cleanup_interval @@ -251,6 +256,32 @@ async def store_future_status( queue_state_reason=queue_state_reason, ) + # ----- Configuration Management ----- + + async def add_config(self, key: str, value: Any) -> None: + """Add or overwrite a configuration value.""" + await self._config_mgr.add(key, value) + + async def add_or_get_config(self, key: str, value: Any) -> Any: + """Add a config value if absent; otherwise return the existing value.""" + return await self._config_mgr.add_or_get(key, value) + + async def get_config(self, key: str) -> Any | None: + """Return the configuration value for key, or None.""" + return await self._config_mgr.get(key) + + async def pop_config(self, key: str) -> Any | None: + """Remove and return the configuration value for key, or None.""" + return await self._config_mgr.pop(key) + + async def clear_config(self) -> None: + """Remove all configuration entries.""" + await self._config_mgr.clear() + + async def count_config(self) -> int: + """Return the number of stored configuration entries.""" + return await self._config_mgr.count() + # ----- Resource Cleanup ----- async def cleanup_expired_resources(self) -> dict[str, int]: @@ -429,6 +460,26 @@ async def create_sampling_session(self, payload: dict[str, Any], sampling_sessio async def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: return await self._actor.get_sampling_session.remote(sampling_session_id) + # ----- Configuration Management ----- + + async def add_config(self, key: str, value: Any) -> None: + await self._actor.add_config.remote(key, value) + + async def add_or_get_config(self, key: str, value: Any) -> Any: + return await self._actor.add_or_get_config.remote(key, value) + + async def get_config(self, key: str) -> Any | None: + return await self._actor.get_config.remote(key) + + async def pop_config(self, key: str) -> Any | None: + return await self._actor.pop_config.remote(key) + + async def clear_config(self) -> None: + await self._actor.clear_config.remote() + + async def count_config(self) -> int: + return await self._actor.count_config.remote() + # ----- Future Management ----- async def get_future(self, request_id: str) -> dict[str, Any] | None: @@ -448,6 +499,26 @@ async def store_future_status( await self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state, queue_state_reason) + # ----- Configuration Management ----- + + async def add_config(self, key: str, value: Any) -> None: + await self._actor.add_config.remote(key, value) + + async def add_or_get_config(self, key: str, value: Any) -> Any: + return await self._actor.add_or_get_config.remote(key, value) + + async def get_config(self, key: str) -> Any | None: + return await self._actor.get_config.remote(key) + + async def pop_config(self, key: str) -> Any | None: + return await self._actor.pop_config.remote(key) + + async def clear_config(self) -> None: + await self._actor.clear_config.remote() + + async def count_config(self) -> int: + return await self._actor.count_config.remote() + # ----- Resource Cleanup ----- async def cleanup_expired_resources(self) -> dict[str, int]: @@ -468,7 +539,10 @@ async def get_cleanup_stats(self) -> dict[str, Any]: # --------------------------------------------------------------------------- -def get_server_state(actor_name: str = 'twinkle_server_state', **kwargs) -> ServerStateProxy: +def get_server_state( + actor_name: str = 'twinkle_server_state', + backend: StateBackend | None = None, + **kwargs) -> ServerStateProxy: """Get or create the ServerState Ray actor. Ensures only one ServerState actor exists with the given name. Uses a @@ -476,6 +550,9 @@ def get_server_state(actor_name: str = 'twinkle_server_state', **kwargs) -> Serv Args: actor_name: Name for the Ray actor (default: 'twinkle_server_state'). + backend: Optional :class:`StateBackend` injected into the ServerState + actor. When ``None`` the actor falls back to an in-process + :class:`MemoryBackend`. **kwargs: Additional keyword arguments passed to ServerState constructor (e.g., expiration_timeout, cleanup_interval). @@ -487,7 +564,8 @@ def get_server_state(actor_name: str = 'twinkle_server_state', **kwargs) -> Serv except ValueError: try: _ServerState = ray.remote(ServerState) - actor = _ServerState.options(name=actor_name, lifetime='detached').remote(**kwargs) + actor = _ServerState.options(name=actor_name, lifetime='detached').remote( + backend=backend, **kwargs) try: ray.get(actor.start_cleanup_task.remote()) except Exception as e: diff --git a/src/twinkle/server/utils/state/session_manager.py b/src/twinkle/server/state/session_manager.py similarity index 100% rename from src/twinkle/server/utils/state/session_manager.py rename to src/twinkle/server/state/session_manager.py diff --git a/src/twinkle/server/telemetry/__init__.py b/src/twinkle/server/telemetry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/twinkle/server/utils/lifecycle/base.py b/src/twinkle/server/utils/lifecycle/base.py index 6c1c1b57a..2d947bdd0 100644 --- a/src/twinkle/server/utils/lifecycle/base.py +++ b/src/twinkle/server/utils/lifecycle/base.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from twinkle.server.utils.state import ServerStateProxy + from twinkle.server.state import ServerStateProxy from twinkle.utils.logger import get_logger diff --git a/src/twinkle/server/utils/state/config_manager.py b/src/twinkle/server/utils/state/config_manager.py deleted file mode 100644 index e1aa3bcea..000000000 --- a/src/twinkle/server/utils/state/config_manager.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from __future__ import annotations - -from typing import Any - - -class ConfigManager: - """ - Manages key-value configuration entries. - - Configuration entries have no expiry; they persist until explicitly removed - or cleared. This manager does not inherit from BaseManager because config - values are arbitrary Python objects rather than Pydantic models. - """ - - def __init__(self) -> None: - self._store: dict[str, Any] = {} - - # ----- CRUD ----- - - def add(self, key: str, value: Any) -> None: - """Add or overwrite a configuration value.""" - self._store[key] = value - - def add_or_get(self, key: str, value: Any) -> Any: - """Add a value if the key does not exist; otherwise return the existing value. - - Args: - key: Configuration key. - value: Value to store if the key is absent. - - Returns: - The existing or newly stored value. - """ - if key not in self._store: - self._store[key] = value - return self._store[key] - - def get(self, key: str) -> Any | None: - """Return the configuration value for key, or None.""" - return self._store.get(key) - - def pop(self, key: str) -> Any | None: - """Remove and return the configuration value for key, or None.""" - return self._store.pop(key, None) - - def clear(self) -> None: - """Remove all configuration entries.""" - self._store.clear() - - def count(self) -> int: - """Return the number of stored configuration entries.""" - return len(self._store) diff --git a/src/twinkle/server/utils/task_queue/mixin.py b/src/twinkle/server/utils/task_queue/mixin.py index a5ecbc7e8..5c8d77b92 100644 --- a/src/twinkle/server/utils/task_queue/mixin.py +++ b/src/twinkle/server/utils/task_queue/mixin.py @@ -22,7 +22,7 @@ from .worker import ComputeWorker if TYPE_CHECKING: - from twinkle.server.utils.state import ServerStateProxy + from twinkle.server.state import ServerStateProxy logger = get_logger() diff --git a/src/twinkle/server/utils/task_queue/worker.py b/src/twinkle/server/utils/task_queue/worker.py index 77740cb72..f8121736b 100644 --- a/src/twinkle/server/utils/task_queue/worker.py +++ b/src/twinkle/server/utils/task_queue/worker.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: from twinkle.server.utils.metrics import TaskMetrics - from twinkle.server.utils.state import ServerStateProxy + from twinkle.server.state import ServerStateProxy logger = get_logger() From 7df3d66ce79a2b79ff44ae49add4fc0912829562 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 29 May 2026 13:16:39 +0800 Subject: [PATCH 02/34] feat(server): replace Ray metrics with OpenTelemetry observability (trace/metric/log) - Add telemetry/provider.py: OTEL TracerProvider/MeterProvider/LoggerProvider init with debug (console) and OTLP export modes, graceful shutdown - Add telemetry/metrics.py: MetricsRegistry singleton facade over OTEL meters (low-invasiveness: business code uses MetricsRegistry.get() only) - Add telemetry/tracing.py: get_tracer/inject_context/extract_context with noop fallback when OTEL SDK is not installed - Rewrite utils/metrics.py as thin adapter layer: _Counter/_Histogram/_Gauge map Ray-style API (inc/set/observe) to MetricsRegistry OTEL instruments - Update server_state.py _metrics_loop to use MetricsRegistry UpDownCounter with delta calculation - Inject trace context in gateway/proxy.py for distributed tracing - All OTEL imports are guarded (optional dependency): server starts normally without opentelemetry packages installed (NoOp fallback) - Completely remove ray.util.metrics dependency (zero residual references) --- src/twinkle/server/gateway/proxy.py | 5 + src/twinkle/server/state/server_state.py | 28 +- src/twinkle/server/telemetry/__init__.py | 25 ++ src/twinkle/server/telemetry/metrics.py | 86 ++++++ src/twinkle/server/telemetry/provider.py | 251 ++++++++++++++++++ src/twinkle/server/telemetry/tracing.py | 57 ++++ src/twinkle/server/utils/metrics.py | 243 +++++++---------- .../server/utils/task_queue/rate_limiter.py | 3 +- 8 files changed, 536 insertions(+), 162 deletions(-) create mode 100644 src/twinkle/server/telemetry/metrics.py create mode 100644 src/twinkle/server/telemetry/provider.py create mode 100644 src/twinkle/server/telemetry/tracing.py diff --git a/src/twinkle/server/gateway/proxy.py b/src/twinkle/server/gateway/proxy.py index 0978b8e03..14dc35424 100644 --- a/src/twinkle/server/gateway/proxy.py +++ b/src/twinkle/server/gateway/proxy.py @@ -12,6 +12,7 @@ from fastapi import Request, Response from typing import Any +from twinkle.server.telemetry.tracing import inject_context from twinkle.utils.logger import get_logger logger = get_logger() @@ -97,6 +98,10 @@ async def proxy_request( target_url = self._build_target_url(service_type, base_model, endpoint) headers = self._prepare_headers(request.headers) + # Inject current trace context into outgoing headers for distributed tracing. + # When telemetry is not initialized, this is a noop. + inject_context(headers) + try: logger.debug( 'proxy_request service=%s endpoint=%s target_url=%s request_id=%s', diff --git a/src/twinkle/server/state/server_state.py b/src/twinkle/server/state/server_state.py index 2af95f856..f3521981a 100644 --- a/src/twinkle/server/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -9,7 +9,7 @@ from datetime import datetime from typing import Any -from twinkle.server.utils.metrics import get_resource_metrics +from twinkle.server.telemetry import MetricsRegistry from twinkle.utils.logger import get_logger from .backend import MemoryBackend, StateBackend from .config_manager import ConfigManager @@ -328,15 +328,29 @@ async def _cleanup_loop(self) -> None: continue async def _metrics_loop(self) -> None: - """Background task that updates resource gauge metrics every N seconds.""" - resource_metrics = get_resource_metrics() + """Background task that updates resource gauge metrics every N seconds. + + OTEL up/down counters take *deltas*, not absolute values, so we keep + track of the last reported count per resource and emit only the + difference on each tick. + """ + registry = MetricsRegistry.get() + sources = ( + ('active_sessions', self._session_mgr.count), + ('active_models', self._model_mgr.count), + ('active_sampling_sessions', self._sampling_mgr.count), + ('active_futures', self._future_mgr.count), + ) + last_values: dict[str, int] = {name: 0 for name, _ in sources} while self._metrics_running: try: await asyncio.sleep(self._metrics_update_interval) - resource_metrics.active_sessions.set(self._session_mgr.count()) - resource_metrics.active_models.set(self._model_mgr.count()) - resource_metrics.active_sampling_sessions.set(self._sampling_mgr.count()) - resource_metrics.active_futures.set(self._future_mgr.count()) + for name, counter in sources: + current = counter() + delta = current - last_values[name] + if delta != 0: + getattr(registry, name).add(delta) + last_values[name] = current except asyncio.CancelledError: break except Exception as e: diff --git a/src/twinkle/server/telemetry/__init__.py b/src/twinkle/server/telemetry/__init__.py index e69de29bb..dd2c3bcfc 100644 --- a/src/twinkle/server/telemetry/__init__.py +++ b/src/twinkle/server/telemetry/__init__.py @@ -0,0 +1,25 @@ +from .metrics import MetricsRegistry +from .provider import ( + TelemetryConfig, + get_meter, + init_telemetry, + shutdown_telemetry, +) +from .tracing import ( + get_tracer, + inject_context, + extract_context, + get_current_span, +) + +__all__ = [ + "MetricsRegistry", + "TelemetryConfig", + "get_meter", + "init_telemetry", + "shutdown_telemetry", + "get_tracer", + "inject_context", + "extract_context", + "get_current_span", +] diff --git a/src/twinkle/server/telemetry/metrics.py b/src/twinkle/server/telemetry/metrics.py new file mode 100644 index 000000000..e95b8242d --- /dev/null +++ b/src/twinkle/server/telemetry/metrics.py @@ -0,0 +1,86 @@ +"""Twinkle Server metrics registry — low-invasiveness facade over OpenTelemetry metrics.""" + +from __future__ import annotations + +from .provider import get_meter + + +class MetricsRegistry: + """集中声明所有指标。业务代码通过 MetricsRegistry.get() 获取单例使用。 + + 当 telemetry 未初始化时,OTEL 返回 NoOp meter,所有记录操作自动静默。 + """ + + _instance: MetricsRegistry | None = None + + def __init__(self) -> None: + meter = get_meter("twinkle-server") + + # === HTTP 请求 === + self.requests_total = meter.create_counter( + "twinkle.http.requests.total", + description="Total HTTP requests received", + ) + self.request_duration_seconds = meter.create_histogram( + "twinkle.http.request.duration_seconds", + description="HTTP request duration in seconds", + unit="s", + ) + + # === 任务队列 === + self.queue_depth = meter.create_up_down_counter( + "twinkle.queue.depth", + description="Current task queue depth", + ) + self.task_execution_seconds = meter.create_histogram( + "twinkle.task.execution_seconds", + description="Task execution duration in seconds", + unit="s", + ) + self.task_wait_seconds = meter.create_histogram( + "twinkle.task.wait_seconds", + description="Task wait time in queue before execution", + unit="s", + ) + self.rate_limit_rejections = meter.create_counter( + "twinkle.rate_limit.rejections.total", + description="Total requests rejected by rate limiter", + ) + self.tasks_total = meter.create_counter( + "twinkle.tasks.total", + description="Total task completions, partitioned by status", + ) + self.rate_limiter_active_tokens = meter.create_up_down_counter( + "twinkle.rate_limiter.active_tokens", + description="Number of tokens currently tracked by the rate limiter", + ) + + # === 资源 === + self.active_sessions = meter.create_up_down_counter( + "twinkle.sessions.active", + description="Number of active client sessions", + ) + self.active_models = meter.create_up_down_counter( + "twinkle.models.active", + description="Number of registered models", + ) + self.active_sampling_sessions = meter.create_up_down_counter( + "twinkle.sampling_sessions.active", + description="Number of active sampling sessions", + ) + self.active_futures = meter.create_up_down_counter( + "twinkle.futures.active", + description="Number of pending futures/tasks", + ) + + @classmethod + def get(cls) -> MetricsRegistry: + """获取全局 MetricsRegistry 单例。首次调用时创建。""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def reset(cls) -> None: + """重置单例(用于测试或 telemetry 重新初始化)""" + cls._instance = None diff --git a/src/twinkle/server/telemetry/provider.py b/src/twinkle/server/telemetry/provider.py new file mode 100644 index 000000000..71bcc3734 --- /dev/null +++ b/src/twinkle/server/telemetry/provider.py @@ -0,0 +1,251 @@ +"""OpenTelemetry provider initialization for Twinkle server. + +Bootstraps the three OTEL pillars (traces, metrics, logs) with either an OTLP +or a console exporter. Designed to be a thin, side-effect-driven module that +exposes: + +- ``TelemetryConfig``: pydantic configuration model +- ``init_telemetry``: entry point that wires up global providers +- ``shutdown_telemetry``: graceful teardown for the global providers +- ``get_meter``: convenience accessor used by ``MetricsRegistry`` +""" + +from __future__ import annotations + +import logging +from typing import Any, Optional + +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Optional OTEL imports — keep them lazy/guarded so that a missing optional +# dependency does not break the rest of the server. +# --------------------------------------------------------------------------- +try: + from opentelemetry import metrics, trace + from opentelemetry._logs import set_logger_provider + from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler + from opentelemetry.sdk._logs.export import ( + BatchLogRecordProcessor, + ConsoleLogExporter, + ) + from opentelemetry.sdk.metrics import MeterProvider + from opentelemetry.sdk.metrics.export import ( + ConsoleMetricExporter, + PeriodicExportingMetricReader, + ) + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import ( + BatchSpanProcessor, + ConsoleSpanExporter, + ) + + _OTEL_AVAILABLE = True + _OTEL_IMPORT_ERROR: Optional[BaseException] = None +except Exception as exc: # pragma: no cover - defensive fallback + _OTEL_AVAILABLE = False + _OTEL_IMPORT_ERROR = exc + +# OTLP exporters are a separate optional dependency from the SDK itself. +try: + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import ( + OTLPLogExporter, + ) + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( + OTLPMetricExporter, + ) + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter, + ) + + _OTLP_AVAILABLE = True +except Exception: # pragma: no cover - defensive fallback + _OTLP_AVAILABLE = False + +# Logging instrumentor is also optional. +try: + from opentelemetry.instrumentation.logging import LoggingInstrumentor + + _LOGGING_INSTRUMENTOR_AVAILABLE = True +except Exception: # pragma: no cover - defensive fallback + _LOGGING_INSTRUMENTOR_AVAILABLE = False + + +# --------------------------------------------------------------------------- +# Module-level state for shutdown. +# --------------------------------------------------------------------------- +_tracer_provider: Optional[Any] = None +_meter_provider: Optional[Any] = None +_logger_provider: Optional[Any] = None +_logging_handler: Optional[Any] = None +_initialized: bool = False + + +class TelemetryConfig(BaseModel): + """Configuration for the OpenTelemetry pipeline.""" + + enabled: bool = False + service_name: str = "twinkle-server" + otlp_endpoint: str = "http://localhost:4317" + debug: bool = False # True: Console Exporter; False: OTLP Exporter + export_interval_ms: int = 30000 + resource_attributes: dict = Field(default_factory=dict) + + +def init_telemetry(config: TelemetryConfig) -> None: + """Initialize the three OTEL pillars (traces, metrics, logs). + + No-op when ``config.enabled`` is False or the OTEL SDK is missing. + """ + global _tracer_provider, _meter_provider, _logger_provider + global _logging_handler, _initialized + + if not config.enabled: + return + + if not _OTEL_AVAILABLE: + logger.warning( + "OpenTelemetry SDK not available, skipping telemetry init: %s", + _OTEL_IMPORT_ERROR, + ) + return + + if _initialized: + logger.debug("Telemetry already initialized; skipping re-init.") + return + + # ---- Resource ------------------------------------------------------- + resource_attrs: dict = {"service.name": config.service_name} + if config.resource_attributes: + resource_attrs.update(config.resource_attributes) + resource = Resource.create(resource_attrs) + + use_console = config.debug or not _OTLP_AVAILABLE + if config.debug is False and not _OTLP_AVAILABLE: + logger.warning( + "OTLP exporters not available; falling back to console exporters." + ) + + # ---- Traces --------------------------------------------------------- + if use_console: + span_exporter = ConsoleSpanExporter() + else: + span_exporter = OTLPSpanExporter(endpoint=config.otlp_endpoint) + + tracer_provider = TracerProvider(resource=resource) + tracer_provider.add_span_processor(BatchSpanProcessor(span_exporter)) + trace.set_tracer_provider(tracer_provider) + _tracer_provider = tracer_provider + + # ---- Metrics -------------------------------------------------------- + if use_console: + metric_exporter = ConsoleMetricExporter() + else: + metric_exporter = OTLPMetricExporter(endpoint=config.otlp_endpoint) + + metric_reader = PeriodicExportingMetricReader( + metric_exporter, + export_interval_millis=config.export_interval_ms, + ) + meter_provider = MeterProvider( + resource=resource, + metric_readers=[metric_reader], + ) + metrics.set_meter_provider(meter_provider) + _meter_provider = meter_provider + + # ---- Logs ----------------------------------------------------------- + if use_console: + log_exporter = ConsoleLogExporter() + else: + log_exporter = OTLPLogExporter(endpoint=config.otlp_endpoint) + + logger_provider = LoggerProvider(resource=resource) + logger_provider.add_log_record_processor( + BatchLogRecordProcessor(log_exporter) + ) + set_logger_provider(logger_provider) + _logger_provider = logger_provider + + # Bridge Python logging -> OTEL logs. + if _LOGGING_INSTRUMENTOR_AVAILABLE: + try: + LoggingInstrumentor().instrument(set_logging_format=True) + except Exception as exc: # pragma: no cover - defensive + logger.warning("LoggingInstrumentor failed to instrument: %s", exc) + + handler = LoggingHandler( + level=logging.NOTSET, logger_provider=logger_provider + ) + logging.getLogger().addHandler(handler) + _logging_handler = handler + + _initialized = True + logger.info( + "Telemetry initialized (service=%s, debug=%s, otlp_endpoint=%s)", + config.service_name, + config.debug, + config.otlp_endpoint, + ) + + +def shutdown_telemetry() -> None: + """Shutdown all OTEL providers and detach the logging handler.""" + global _tracer_provider, _meter_provider, _logger_provider + global _logging_handler, _initialized + + if _logging_handler is not None: + try: + logging.getLogger().removeHandler(_logging_handler) + except Exception as exc: # pragma: no cover - defensive + logger.warning("Failed to detach logging handler: %s", exc) + _logging_handler = None + + if _tracer_provider is not None: + try: + _tracer_provider.shutdown() + except Exception as exc: # pragma: no cover - defensive + logger.warning("TracerProvider shutdown failed: %s", exc) + _tracer_provider = None + + if _meter_provider is not None: + try: + _meter_provider.shutdown() + except Exception as exc: # pragma: no cover - defensive + logger.warning("MeterProvider shutdown failed: %s", exc) + _meter_provider = None + + if _logger_provider is not None: + try: + _logger_provider.shutdown() + except Exception as exc: # pragma: no cover - defensive + logger.warning("LoggerProvider shutdown failed: %s", exc) + _logger_provider = None + + _initialized = False + + +class _NoopInstrument: + """No-op instrument for when OTEL SDK is not available.""" + def add(self, *args, **kwargs): pass + def record(self, *args, **kwargs): pass + + +class _NoopMeter: + """No-op meter for when OTEL SDK is not available.""" + def create_counter(self, *args, **kwargs): return _NoopInstrument() + def create_up_down_counter(self, *args, **kwargs): return _NoopInstrument() + def create_histogram(self, *args, **kwargs): return _NoopInstrument() + + +_noop_meter = _NoopMeter() + + +def get_meter(name: str = "twinkle-server"): + """Return an OTEL meter. Returns NoOp meter if OTEL SDK is not available.""" + if not _OTEL_AVAILABLE: + return _noop_meter + return metrics.get_meter(name) diff --git a/src/twinkle/server/telemetry/tracing.py b/src/twinkle/server/telemetry/tracing.py new file mode 100644 index 000000000..9982af3cb --- /dev/null +++ b/src/twinkle/server/telemetry/tracing.py @@ -0,0 +1,57 @@ +"""Twinkle Server tracing utilities — thin wrapper over OpenTelemetry tracing.""" + +from __future__ import annotations + +try: + from opentelemetry import trace + from opentelemetry.propagate import inject, extract + from opentelemetry.context import Context + _OTEL_AVAILABLE = True +except Exception: + _OTEL_AVAILABLE = False + + +def get_tracer(name: str = "twinkle-server"): + """获取 tracer 实例。OTEL 未安装时返回 NoOp tracer。""" + if not _OTEL_AVAILABLE: + return _NoopTracer() + return trace.get_tracer(name) + + +def inject_context(carrier: dict) -> None: + """将当前 trace context 注入到 carrier。OTEL 未安装时为 noop。""" + if not _OTEL_AVAILABLE: + return + inject(carrier) + + +def extract_context(carrier: dict): + """从 carrier 中提取 trace context。OTEL 未安装时返回空 context。""" + if not _OTEL_AVAILABLE: + return None + return extract(carrier) + + +def get_current_span(): + """获取当前活跃的 span。OTEL 未安装时返回 noop span。""" + if not _OTEL_AVAILABLE: + return _NoopSpan() + return trace.get_current_span() + + +class _NoopSpan: + """Minimal noop span for when OTEL is not available.""" + def set_attribute(self, *args, **kwargs): pass + def set_status(self, *args, **kwargs): pass + def add_event(self, *args, **kwargs): pass + def end(self, *args, **kwargs): pass + def __enter__(self): return self + def __exit__(self, *args): pass + + +class _NoopTracer: + """Minimal noop tracer for when OTEL is not available.""" + def start_as_current_span(self, name, **kwargs): + return _NoopSpan() + def start_span(self, name, **kwargs): + return _NoopSpan() diff --git a/src/twinkle/server/utils/metrics.py b/src/twinkle/server/utils/metrics.py index eee915d78..c02da78e5 100644 --- a/src/twinkle/server/utils/metrics.py +++ b/src/twinkle/server/utils/metrics.py @@ -2,98 +2,110 @@ """ Central metrics module for Twinkle server observability. -Provides ray.util.metrics instruments that feed both the Ray Dashboard -(port 8265) and Prometheus (via /api/prometheus). +This module is a *thin adapter layer* over the OpenTelemetry instruments +declared in :class:`twinkle.server.telemetry.metrics.MetricsRegistry`. +It preserves the legacy Ray-style API (``.inc()`` / ``.set()`` / ``.observe()`` +with ``tags=`` keyword) so that existing call sites do not need to change +their call patterns, while routing every measurement through OTEL. -All metric names use the ``twinkle_`` prefix. Metric instances are -cached per deployment to avoid duplicate registration. - -Public entry-points: +Public entry-points (unchanged signatures): * ``create_metrics_middleware(deployment)`` – FastAPI HTTP middleware -* ``get_task_metrics(deployment)`` – task-queue / rate-limit gauges -* ``get_resource_metrics()`` – ServerState resource gauges +* ``get_task_metrics(deployment)`` – task-queue / rate-limit gauges """ from __future__ import annotations import time from pydantic import BaseModel, ConfigDict -from ray.util.metrics import Counter, Gauge, Histogram from typing import Any, Callable +from twinkle.server.telemetry import MetricsRegistry from twinkle.utils.logger import get_logger logger = get_logger() # --------------------------------------------------------------------------- -# Histogram bucket boundaries (seconds) – shared by all histograms +# Lazy caches – populated on first call per deployment # --------------------------------------------------------------------------- -_HISTOGRAM_BOUNDARIES = [ - 0.01, - 0.05, - 0.1, - 0.25, - 0.5, - 1.0, - 2.5, - 5.0, - 10.0, - 30.0, - 60.0, - 120.0, - 300.0, -] +_task_metrics_cache: dict[str, 'TaskMetrics'] = {} +_request_metrics_cache: dict[str, '_RequestMetrics'] = {} -# --------------------------------------------------------------------------- -# Lazy caches – populated on first call per deployment / globally -# --------------------------------------------------------------------------- -_task_metrics_cache: dict[str, TaskMetrics] = {} -_resource_metrics_cache: ResourceMetrics | None = None -_request_metrics_cache: dict[str, _RequestMetrics] = {} # --------------------------------------------------------------------------- -# Pydantic models for structured metric access +# Adapter classes – wrap OTEL instruments to expose the legacy Ray-style API +# (``.inc(tags=...)`` / ``.set(value, tags=...)`` / ``.observe(value, tags=...)``) +# while delegating all measurements to OpenTelemetry. # --------------------------------------------------------------------------- -class TaskMetrics(BaseModel): - """Task queue metrics container. +class _Counter: + """Adapter mapping ``.inc(value, tags=...)`` to ``otel_counter.add()``.""" - Attributes: - queue_depth: Current number of queued tasks. - tasks_total: Total task completions. - execution_seconds: Pure task execution time in seconds. - queue_wait_seconds: Time from enqueue to execution start. - rate_limit_rejections: Total rate-limit rejections. - rate_limiter_active_tokens: Tokens tracked by rate limiter. + def __init__(self, instrument: Any) -> None: + self._instrument = instrument + + def inc(self, value: float = 1.0, tags: dict[str, str] | None = None) -> None: + self._instrument.add(value, attributes=tags or {}) + + +class _Histogram: + """Adapter mapping ``.observe(value, tags=...)`` to ``otel_histogram.record()``.""" + + def __init__(self, instrument: Any) -> None: + self._instrument = instrument + + def observe(self, value: float, tags: dict[str, str] | None = None) -> None: + self._instrument.record(value, attributes=tags or {}) + + +class _Gauge: + """Adapter mapping ``.set(value, tags=...)`` onto an OTEL UpDownCounter. + + OpenTelemetry up/down counters take *deltas*, not absolute values, so we + track the last reported value per attribute combination and emit the + incremental change. State is held per adapter instance (= per deployment), + keyed by the frozen attribute tuple. """ - model_config = ConfigDict(arbitrary_types_allowed=True) + def __init__(self, instrument: Any) -> None: + self._instrument = instrument + self._last: dict[tuple, float] = {} - queue_depth: Gauge - tasks_total: Counter - execution_seconds: Histogram - queue_wait_seconds: Histogram - rate_limit_rejections: Counter - rate_limiter_active_tokens: Gauge + def set(self, value: float, tags: dict[str, str] | None = None) -> None: + attrs = tags or {} + key = tuple(sorted(attrs.items())) + last = self._last.get(key, 0.0) + delta = value - last + if delta != 0: + self._instrument.add(delta, attributes=attrs) + self._last[key] = value -class ResourceMetrics(BaseModel): - """Resource gauge metrics container. +# --------------------------------------------------------------------------- +# Pydantic containers for structured metric access +# --------------------------------------------------------------------------- + + +class TaskMetrics(BaseModel): + """Task queue metrics container. Attributes: - active_sessions: Current active session count. - active_models: Current registered model count. - active_sampling_sessions: Current sampling session count. - active_futures: Current future/request count. + queue_depth: Current number of queued tasks (gauge). + tasks_total: Total task completions (counter). + execution_seconds: Pure task execution time in seconds (histogram). + queue_wait_seconds: Time from enqueue to execution start (histogram). + rate_limit_rejections: Total rate-limit rejections (counter). + rate_limiter_active_tokens: Tokens tracked by rate limiter (gauge). """ model_config = ConfigDict(arbitrary_types_allowed=True) - active_sessions: Gauge - active_models: Gauge - active_sampling_sessions: Gauge - active_futures: Gauge + queue_depth: _Gauge + tasks_total: _Counter + execution_seconds: _Histogram + queue_wait_seconds: _Histogram + rate_limit_rejections: _Counter + rate_limiter_active_tokens: _Gauge class _RequestMetrics(BaseModel): @@ -101,8 +113,8 @@ class _RequestMetrics(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - requests_total: Counter - request_duration_seconds: Histogram + requests_total: _Counter + request_duration_seconds: _Histogram # --------------------------------------------------------------------------- @@ -111,22 +123,14 @@ class _RequestMetrics(BaseModel): def _get_request_metrics(deployment: str) -> _RequestMetrics: - """Return (or create) per-deployment HTTP request metrics.""" + """Return (or create) per-deployment HTTP request metric adapters.""" if deployment in _request_metrics_cache: return _request_metrics_cache[deployment] + reg = MetricsRegistry.get() metrics = _RequestMetrics( - requests_total=Counter( - 'twinkle_requests_total', - description='Total HTTP requests.', - tag_keys=('deployment', 'method', 'status'), - ), - request_duration_seconds=Histogram( - 'twinkle_request_duration_seconds', - description='End-to-end HTTP request latency in seconds.', - boundaries=_HISTOGRAM_BOUNDARIES, - tag_keys=('deployment', 'method'), - ), + requests_total=_Counter(reg.requests_total), + request_duration_seconds=_Histogram(reg.request_duration_seconds), ) _request_metrics_cache[deployment] = metrics return metrics @@ -174,94 +178,25 @@ async def metrics_middleware(request: Any, call_next: Callable) -> Any: def get_task_metrics(deployment: str) -> TaskMetrics: - """Return (or create) per-deployment task-queue metrics. + """Return (or create) per-deployment task-queue metric adapters. - Returns a :class:`TaskMetrics` Pydantic model with: - - - ``queue_depth`` – Gauge - - ``tasks_total`` – Counter - - ``execution_seconds`` – Histogram - - ``queue_wait_seconds`` – Histogram - - ``rate_limit_rejections`` – Counter - - ``rate_limiter_active_tokens`` – Gauge + Returns a :class:`TaskMetrics` container of adapter objects; the + adapters delegate every measurement to the OTEL instruments held by + :class:`twinkle.server.telemetry.metrics.MetricsRegistry`. A separate + adapter instance is cached per deployment so that gauge-state tracking + (last value per attribute set) stays isolated. """ if deployment in _task_metrics_cache: return _task_metrics_cache[deployment] + reg = MetricsRegistry.get() metrics = TaskMetrics( - queue_depth=Gauge( - 'twinkle_task_queue_depth', - description='Current number of queued tasks.', - tag_keys=('deployment', ), - ), - tasks_total=Counter( - 'twinkle_tasks_total', - description='Total task completions.', - tag_keys=('deployment', 'task_type', 'status'), - ), - execution_seconds=Histogram( - 'twinkle_task_execution_seconds', - description='Pure task execution time in seconds.', - boundaries=_HISTOGRAM_BOUNDARIES, - tag_keys=('deployment', 'task_type'), - ), - queue_wait_seconds=Histogram( - 'twinkle_task_queue_wait_seconds', - description='Time from enqueue to execution start in seconds.', - boundaries=_HISTOGRAM_BOUNDARIES, - tag_keys=('deployment', 'task_type'), - ), - rate_limit_rejections=Counter( - 'twinkle_rate_limit_rejections_total', - description='Total rate-limit rejections.', - tag_keys=('deployment', ), - ), - rate_limiter_active_tokens=Gauge( - 'twinkle_rate_limiter_active_tokens', - description='Number of tokens tracked by the rate limiter.', - tag_keys=('deployment', ), - ), + queue_depth=_Gauge(reg.queue_depth), + tasks_total=_Counter(reg.tasks_total), + execution_seconds=_Histogram(reg.task_execution_seconds), + queue_wait_seconds=_Histogram(reg.task_wait_seconds), + rate_limit_rejections=_Counter(reg.rate_limit_rejections), + rate_limiter_active_tokens=_Gauge(reg.rate_limiter_active_tokens), ) _task_metrics_cache[deployment] = metrics return metrics - - -# --------------------------------------------------------------------------- -# D. Resource gauges (ServerState actor, updated every 15 s) -# --------------------------------------------------------------------------- - - -def get_resource_metrics() -> ResourceMetrics: - """Return (or create) global resource gauge metrics. - - Returns a :class:`ResourceMetrics` Pydantic model with: - - - ``active_sessions`` – Gauge - - ``active_models`` – Gauge - - ``active_sampling_sessions`` – Gauge - - ``active_futures`` – Gauge - """ - global _resource_metrics_cache - if _resource_metrics_cache is not None: - return _resource_metrics_cache - - metrics = ResourceMetrics( - active_sessions=Gauge( - 'twinkle_active_sessions', - description='Current active session count.', - ), - active_models=Gauge( - 'twinkle_active_models', - description='Current registered model count.', - ), - active_sampling_sessions=Gauge( - 'twinkle_active_sampling_sessions', - description='Current sampling session count.', - ), - active_futures=Gauge( - 'twinkle_active_futures', - description='Current future/request count.', - ), - ) - _resource_metrics_cache = metrics - return metrics diff --git a/src/twinkle/server/utils/task_queue/rate_limiter.py b/src/twinkle/server/utils/task_queue/rate_limiter.py index 229428801..d3370371a 100644 --- a/src/twinkle/server/utils/task_queue/rate_limiter.py +++ b/src/twinkle/server/utils/task_queue/rate_limiter.py @@ -55,7 +55,8 @@ def __init__( will be removed. Default is 10.0 (10x the window). token_cleanup_interval: How often to run the cleanup task in seconds. Default is 60.0 (every minute). - active_tokens_gauge: Optional ray.util.metrics Gauge for tracking active token count. + active_tokens_gauge: Optional gauge adapter (see twinkle.server.utils.metrics) + for tracking the active token count. deployment_name: Deployment name for metrics labels. """ self.rps_limit = rps_limit From cc7eed8946fedff521c9c1f205ff06fdecb5202e Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 29 May 2026 13:28:00 +0800 Subject: [PATCH 03/34] fix(server): remove duplicate config methods in ServerStateProxy --- src/twinkle/server/state/server_state.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/twinkle/server/state/server_state.py b/src/twinkle/server/state/server_state.py index f3521981a..26b425410 100644 --- a/src/twinkle/server/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -513,26 +513,6 @@ async def store_future_status( await self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state, queue_state_reason) - # ----- Configuration Management ----- - - async def add_config(self, key: str, value: Any) -> None: - await self._actor.add_config.remote(key, value) - - async def add_or_get_config(self, key: str, value: Any) -> Any: - return await self._actor.add_or_get_config.remote(key, value) - - async def get_config(self, key: str) -> Any | None: - return await self._actor.get_config.remote(key) - - async def pop_config(self, key: str) -> Any | None: - return await self._actor.pop_config.remote(key) - - async def clear_config(self) -> None: - await self._actor.clear_config.remote() - - async def count_config(self) -> int: - return await self._actor.count_config.remote() - # ----- Resource Cleanup ----- async def cleanup_expired_resources(self) -> dict[str, int]: From 81bcd6d712f0a2b08d79ad7cc846bf8bf0b9668f Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 29 May 2026 13:53:50 +0800 Subject: [PATCH 04/34] feat(server): add persistence layer with FileBackend, RedisBackend, and config signature - Implement FileBackend (JSON file storage with atomic write, fcntl lock, TTL) - Implement RedisBackend (redis.asyncio, optional dependency with guard) - Add PersistenceConfig + create_backend factory (memory/file/redis modes) - Adapt all Managers (Session/Model/Sampling/Future) to use StateBackend - BaseManager: async CRUD via backend with key prefix isolation - ModelManager: hybrid mode (records persisted, indexes in-memory with rebuild) - Add config signature validation (SHA256 hash, warn/clear/abort policies) - Fix ABORT policy exception propagation in get_server_state - Add comprehensive unit tests (62 passed) --- src/twinkle/server/state/__init__.py | 13 + src/twinkle/server/state/backend/__init__.py | 12 +- src/twinkle/server/state/backend/factory.py | 53 +++ .../server/state/backend/file_backend.py | 168 ++++++++ .../server/state/backend/redis_backend.py | 98 +++++ src/twinkle/server/state/base.py | 90 ++-- src/twinkle/server/state/config_signature.py | 99 +++++ src/twinkle/server/state/future_manager.py | 22 +- src/twinkle/server/state/model_manager.py | 61 ++- src/twinkle/server/state/sampling_manager.py | 14 +- src/twinkle/server/state/server_state.py | 117 ++++-- src/twinkle/server/state/session_manager.py | 29 +- tests/server/__init__.py | 0 tests/server/state/__init__.py | 0 tests/server/state/test_config_signature.py | 153 +++++++ tests/server/state/test_factory.py | 85 ++++ tests/server/state/test_file_backend.py | 211 ++++++++++ tests/server/state/test_managers.py | 383 ++++++++++++++++++ tests/server/state/test_redis_backend.py | 198 +++++++++ 19 files changed, 1703 insertions(+), 103 deletions(-) create mode 100644 src/twinkle/server/state/backend/factory.py create mode 100644 src/twinkle/server/state/backend/file_backend.py create mode 100644 src/twinkle/server/state/backend/redis_backend.py create mode 100644 src/twinkle/server/state/config_signature.py create mode 100644 tests/server/__init__.py create mode 100644 tests/server/state/__init__.py create mode 100644 tests/server/state/test_config_signature.py create mode 100644 tests/server/state/test_factory.py create mode 100644 tests/server/state/test_file_backend.py create mode 100644 tests/server/state/test_managers.py create mode 100644 tests/server/state/test_redis_backend.py diff --git a/src/twinkle/server/state/__init__.py b/src/twinkle/server/state/__init__.py index 0e34697ad..f79f39121 100644 --- a/src/twinkle/server/state/__init__.py +++ b/src/twinkle/server/state/__init__.py @@ -1,6 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from .backend import PersistenceConfig, create_backend from .base import BaseManager from .config_manager import ConfigManager +from .config_signature import ( + SignatureMismatchPolicy, + compute_signature, + validate_config_signature, +) from .future_manager import FutureManager from .model_manager import ModelManager from .models import FutureRecord, ModelRecord, SamplingSessionRecord, SessionRecord @@ -26,4 +32,11 @@ 'ServerState', 'ServerStateProxy', 'get_server_state', + # Persistence backend factory + 'PersistenceConfig', + 'create_backend', + # Config signature validation + 'compute_signature', + 'validate_config_signature', + 'SignatureMismatchPolicy', ] diff --git a/src/twinkle/server/state/backend/__init__.py b/src/twinkle/server/state/backend/__init__.py index f3baab6f5..1007e53ae 100644 --- a/src/twinkle/server/state/backend/__init__.py +++ b/src/twinkle/server/state/backend/__init__.py @@ -1,4 +1,14 @@ from .base import StateBackend +from .factory import PersistenceConfig, create_backend +from .file_backend import FileBackend from .memory_backend import MemoryBackend +from .redis_backend import RedisBackend -__all__ = ['StateBackend', 'MemoryBackend'] +__all__ = [ + 'StateBackend', + 'FileBackend', + 'MemoryBackend', + 'RedisBackend', + 'PersistenceConfig', + 'create_backend', +] diff --git a/src/twinkle/server/state/backend/factory.py b/src/twinkle/server/state/backend/factory.py new file mode 100644 index 000000000..2e03f72d7 --- /dev/null +++ b/src/twinkle/server/state/backend/factory.py @@ -0,0 +1,53 @@ +"""Backend factory for creating StateBackend instances based on configuration.""" +from __future__ import annotations + +import logging +from typing import Literal + +from pydantic import BaseModel + +from .base import StateBackend +from .memory_backend import MemoryBackend + +logger = logging.getLogger(__name__) + + +class PersistenceConfig(BaseModel): + """Configuration for state persistence backend.""" + mode: Literal['memory', 'file', 'redis'] = 'memory' + file_path: str | None = None # required for file mode + redis_url: str | None = None # required for redis mode + key_prefix: str = '' # optional global key prefix + + +def create_backend(config: PersistenceConfig | None = None) -> StateBackend: + """Create a StateBackend instance based on persistence configuration. + + Args: + config: Persistence configuration. Defaults to memory mode if None. + + Returns: + A configured StateBackend instance. + + Raises: + ValueError: If required config fields are missing for the selected mode. + ImportError: If required packages are not installed. + """ + if config is None: + config = PersistenceConfig() + + match config.mode: + case 'memory': + return MemoryBackend() + case 'file': + if not config.file_path: + raise ValueError('file_path is required for file persistence mode') + from .file_backend import FileBackend + return FileBackend(config.file_path) + case 'redis': + if not config.redis_url: + raise ValueError('redis_url is required for redis persistence mode') + from .redis_backend import RedisBackend + return RedisBackend(config.redis_url, key_prefix=config.key_prefix) + case _: + raise ValueError(f'Unknown persistence mode: {config.mode}') diff --git a/src/twinkle/server/state/backend/file_backend.py b/src/twinkle/server/state/backend/file_backend.py new file mode 100644 index 000000000..0ddb80324 --- /dev/null +++ b/src/twinkle/server/state/backend/file_backend.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import asyncio +import fcntl +import json +import os +import tempfile +import time +from fnmatch import fnmatch +from typing import Any + +from .base import StateBackend + + +class FileBackend(StateBackend): + """基于本地 JSON 文件的持久化状态后端实现。 + + 存储格式为单个 JSON 文件:``{key: {"value": ..., "expire_at": float|null}}``。 + 文件读写通过 ``asyncio.to_thread`` 包装,避免阻塞事件循环。 + 写入使用临时文件 + ``os.replace`` 原子替换,并通过 ``fcntl.flock`` 防止多进程并发写入。 + """ + + def __init__(self, file_path: str) -> None: + self._file_path = file_path + self._init_file() + + def _init_file(self) -> None: + """如果文件或目录不存在,自动创建。""" + dir_path = os.path.dirname(self._file_path) + if dir_path and not os.path.exists(dir_path): + os.makedirs(dir_path, exist_ok=True) + if not os.path.exists(self._file_path): + with open(self._file_path, 'w', encoding='utf-8') as f: + json.dump({}, f) + + def _load_sync(self) -> dict[str, dict[str, Any]]: + """同步读取 JSON 文件,返回完整 data dict。""" + try: + with open(self._file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + except (json.JSONDecodeError, FileNotFoundError): + data = {} + return data + + def _save_sync(self, data: dict[str, dict[str, Any]]) -> None: + """同步写入:清理过期键 → 写临时文件 → flock → os.replace。""" + # 写入前清理过期键 + now = time.time() + data = { + k: v for k, v in data.items() + if v.get('expire_at') is None or v['expire_at'] > now + } + + dir_path = os.path.dirname(self._file_path) or '.' + fd = tempfile.NamedTemporaryFile( + mode='w', + suffix='.tmp', + dir=dir_path, + delete=False, + encoding='utf-8', + ) + try: + json.dump(data, fd, ensure_ascii=False) + fd.flush() + os.fsync(fd.fileno()) + fd.close() + + # 对临时文件加排他锁后原子替换 + with open(fd.name, 'r') as lock_f: + fcntl.flock(lock_f.fileno(), fcntl.LOCK_EX) + os.replace(fd.name, self._file_path) + fcntl.flock(lock_f.fileno(), fcntl.LOCK_UN) + except BaseException: + # 清理临时文件 + if os.path.exists(fd.name): + os.unlink(fd.name) + raise + + async def _load(self) -> dict[str, dict[str, Any]]: + return await asyncio.to_thread(self._load_sync) + + async def _save(self, data: dict[str, dict[str, Any]]) -> None: + await asyncio.to_thread(self._save_sync, data) + + def _is_expired(self, entry: dict[str, Any]) -> bool: + """检查条目是否已过期。""" + expire_at = entry.get('expire_at') + return expire_at is not None and time.time() >= expire_at + + async def set(self, key: str, value: Any, ttl: int | None = None) -> None: + """存储键值对,可选 TTL(秒)""" + expire_at = (time.time() + ttl) if ttl is not None else None + data = await self._load() + data[key] = {'value': value, 'expire_at': expire_at} + await self._save(data) + + async def get(self, key: str) -> Any | None: + """获取值,不存在或已过期返回 None""" + data = await self._load() + entry = data.get(key) + if entry is None: + return None + if self._is_expired(entry): + del data[key] + await self._save(data) + return None + return entry['value'] + + async def delete(self, key: str) -> None: + """删除键,不存在时静默忽略""" + data = await self._load() + if key in data: + del data[key] + await self._save(data) + + async def exists(self, key: str) -> bool: + """检查键是否存在且未过期""" + data = await self._load() + entry = data.get(key) + if entry is None: + return False + if self._is_expired(entry): + del data[key] + await self._save(data) + return False + return True + + async def keys(self, pattern: str) -> list[str]: + """按模式匹配返回所有键名。pattern 支持 * 通配符。""" + data = await self._load() + result: list[str] = [] + expired_keys: list[str] = [] + for key, entry in data.items(): + if self._is_expired(entry): + expired_keys.append(key) + continue + if fnmatch(key, pattern): + result.append(key) + if expired_keys: + for key in expired_keys: + del data[key] + await self._save(data) + return result + + async def count(self, pattern: str) -> int: + """按模式匹配计数""" + return len(await self.keys(pattern)) + + async def set_nx(self, key: str, value: Any) -> bool: + """Set if not exists. 返回 True 如果成功设置,False 如果键已存在且未过期。""" + data = await self._load() + entry = data.get(key) + if entry is not None and not self._is_expired(entry): + return False + data[key] = {'value': value, 'expire_at': None} + await self._save(data) + return True + + async def close(self) -> None: + """关闭后端,文件后端无需持久连接,no-op。""" + pass + + async def health_check(self) -> bool: + """检查文件路径是否可写""" + try: + return os.access(self._file_path, os.W_OK) + except OSError: + return False diff --git a/src/twinkle/server/state/backend/redis_backend.py b/src/twinkle/server/state/backend/redis_backend.py new file mode 100644 index 000000000..bcdc1a136 --- /dev/null +++ b/src/twinkle/server/state/backend/redis_backend.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import json +from typing import Any + +from .base import StateBackend + +try: + import redis.asyncio as aioredis + + _REDIS_AVAILABLE = True +except ImportError: + _REDIS_AVAILABLE = False + + +class RedisBackend(StateBackend): + """基于 Redis 的持久化状态后端实现。 + + 使用 ``redis.asyncio`` 客户端,值通过 JSON 序列化存储为 Redis string。 + TTL 由 Redis 原生 EXPIRE 机制管理。 + """ + + def __init__(self, redis_url: str, key_prefix: str = "") -> None: + if not _REDIS_AVAILABLE: + raise ImportError( + "redis package required. Install with: pip install redis" + ) + self._client = aioredis.from_url(redis_url, decode_responses=True) + self._prefix = key_prefix + + def _make_key(self, key: str) -> str: + """为 key 添加命名空间前缀。""" + return f"{self._prefix}{key}" if self._prefix else key + + def _strip_prefix(self, key: str) -> str: + """从完整 key 中移除命名空间前缀。""" + return key[len(self._prefix):] if self._prefix else key + + async def set(self, key: str, value: Any, ttl: int | None = None) -> None: + """存储键值对,可选 TTL(秒)""" + real_key = self._make_key(key) + data = json.dumps(value) + if ttl is not None: + await self._client.set(real_key, data, ex=ttl) + else: + await self._client.set(real_key, data) + + async def get(self, key: str) -> Any | None: + """获取值,不存在或已过期返回 None""" + real_key = self._make_key(key) + raw = await self._client.get(real_key) + if raw is None: + return None + return json.loads(raw) + + async def delete(self, key: str) -> None: + """删除键,不存在时静默忽略""" + real_key = self._make_key(key) + await self._client.delete(real_key) + + async def exists(self, key: str) -> bool: + """检查键是否存在且未过期""" + real_key = self._make_key(key) + return bool(await self._client.exists(real_key)) + + async def keys(self, pattern: str) -> list[str]: + """按模式匹配返回所有键名。pattern 支持 * 通配符。 + + 注意:生产环境高 key 数量时建议改用 SCAN 以避免阻塞。 + """ + real_pattern = self._make_key(pattern) + raw_keys = await self._client.keys(real_pattern) + return [self._strip_prefix(k) for k in raw_keys] + + async def count(self, pattern: str) -> int: + """按模式匹配计数""" + return len(await self.keys(pattern)) + + async def set_nx(self, key: str, value: Any, ttl: int | None = None) -> bool: + """Set if not exists. 返回 True 如果成功设置,False 如果键已存在。""" + real_key = self._make_key(key) + data = json.dumps(value) + if ttl is not None: + result = await self._client.set(real_key, data, nx=True, ex=ttl) + else: + result = await self._client.set(real_key, data, nx=True) + return result is not None + + async def close(self) -> None: + """关闭 Redis 连接""" + await self._client.aclose() + + async def health_check(self) -> bool: + """检查 Redis 是否健康可用""" + try: + return await self._client.ping() + except Exception: + return False diff --git a/src/twinkle/server/state/base.py b/src/twinkle/server/state/base.py index c7480ec78..5c4723ed9 100644 --- a/src/twinkle/server/state/base.py +++ b/src/twinkle/server/state/base.py @@ -1,51 +1,85 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations +import logging import time from abc import ABC, abstractmethod -from datetime import datetime -from pydantic import BaseModel +from datetime import datetime, timezone from typing import Generic, TypeVar +from pydantic import BaseModel + +from twinkle.server.state.backend.base import StateBackend + T = TypeVar('T', bound=BaseModel) +logger = logging.getLogger(__name__) class BaseManager(ABC, Generic[T]): - """ - Abstract base class for resource managers. + """Abstract base class for resource managers using StateBackend. - Provides common CRUD operations and timestamp parsing. + Provides common async CRUD operations and timestamp parsing. Subclasses must implement `cleanup_expired`. """ - def __init__(self, expiration_timeout: float) -> None: - self._store: dict[str, T] = {} + def __init__(self, backend: StateBackend, key_prefix: str, record_type: type[T], expiration_timeout: float): + self._backend = backend + self._prefix = key_prefix # e.g. "session::", "model::", "sampling::", "future::" + self._record_type = record_type self.expiration_timeout = expiration_timeout - # ----- CRUD ----- - - def add(self, resource_id: str, record: T) -> None: - """Store a record under the given ID.""" - self._store[resource_id] = record + def _make_key(self, resource_id: str) -> str: + return f"{self._prefix}{resource_id}" - def get(self, resource_id: str) -> T | None: - """Return the record for the given ID, or None.""" - return self._store.get(resource_id) + def _strip_prefix(self, key: str) -> str: + return key[len(self._prefix):] - def remove(self, resource_id: str) -> bool: - """Remove a record by ID. Returns True if it existed.""" - return self._store.pop(resource_id, None) is not None + # ----- CRUD ----- - def count(self) -> int: - """Return the number of stored records.""" - return len(self._store) + async def add(self, resource_id: str, record: T) -> None: + """Store a record in the backend.""" + await self._backend.set(self._make_key(resource_id), record.model_dump()) + + async def get(self, resource_id: str) -> T | None: + """Retrieve a record by ID.""" + data = await self._backend.get(self._make_key(resource_id)) + if data is None: + return None + return self._record_type.model_validate(data) + + async def remove(self, resource_id: str) -> bool: + """Remove a record. Returns True if it existed.""" + key = self._make_key(resource_id) + exists = await self._backend.exists(key) + if exists: + await self._backend.delete(key) + return exists + + async def count(self) -> int: + """Count all records managed by this manager.""" + return await self._backend.count(f"{self._prefix}*") + + async def keys(self) -> list[str]: + """Get all resource IDs (without prefix).""" + raw_keys = await self._backend.keys(f"{self._prefix}*") + return [self._strip_prefix(k) for k in raw_keys] + + async def get_all(self) -> dict[str, T]: + """Load all records from backend. Used for index rebuilding.""" + all_keys = await self._backend.keys(f"{self._prefix}*") + result = {} + for key in all_keys: + data = await self._backend.get(key) + if data is not None: + resource_id = self._strip_prefix(key) + result[resource_id] = self._record_type.model_validate(data) + return result # ----- Cleanup ----- @abstractmethod - def cleanup_expired(self, cutoff_time: float) -> int: - """ - Remove all records older than cutoff_time. + async def cleanup_expired(self, cutoff_time: float, **kwargs) -> int: + """Remove all records older than cutoff_time. Args: cutoff_time: Unix timestamp; records with activity before this are removed. @@ -53,6 +87,7 @@ def cleanup_expired(self, cutoff_time: float) -> int: Returns: Number of records removed. """ + ... # ----- Helpers ----- @@ -63,6 +98,9 @@ def _parse_timestamp(self, timestamp_str: str) -> float: never accidentally kept alive forever. """ try: - return datetime.fromisoformat(timestamp_str).timestamp() - except (ValueError, AttributeError): + dt = datetime.fromisoformat(timestamp_str) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.timestamp() + except (ValueError, TypeError, AttributeError): return time.time() diff --git a/src/twinkle/server/state/config_signature.py b/src/twinkle/server/state/config_signature.py new file mode 100644 index 000000000..228c927ce --- /dev/null +++ b/src/twinkle/server/state/config_signature.py @@ -0,0 +1,99 @@ +"""Configuration signature validation for state persistence integrity.""" +from __future__ import annotations + +import hashlib +import json +import logging +from enum import Enum +from typing import Any + +from twinkle.server.state.backend.base import StateBackend + +logger = logging.getLogger(__name__) + +_SIGNATURE_KEY = "_meta::config_signature" + + +class SignatureMismatchPolicy(str, Enum): + """Policy for handling config signature mismatches.""" + WARN = "warn" # Log warning and continue + CLEAR = "clear" # Clear all backend data and continue + ABORT = "abort" # Raise error, refuse to start + + +def compute_signature(config: dict[str, Any]) -> str: + """Compute a SHA256 hash of the configuration dictionary. + + Args: + config: Configuration dictionary to hash. + + Returns: + Hex string of SHA256 hash. + """ + # Sort keys for deterministic serialization + serialized = json.dumps(config, sort_keys=True, default=str) + return hashlib.sha256(serialized.encode()).hexdigest() + + +async def validate_config_signature( + backend: StateBackend, + current_config: dict[str, Any], + policy: SignatureMismatchPolicy = SignatureMismatchPolicy.WARN, +) -> bool: + """Validate configuration signature against stored value. + + Compares the current config's hash with the previously stored hash. + On first run (no stored hash), stores the current hash and returns True. + + Args: + backend: State backend to read/write signature. + current_config: Current configuration dictionary. + policy: Action to take on mismatch. + + Returns: + True if signature matches or is new. False if mismatch with WARN/CLEAR policy. + + Raises: + ConfigMismatchError: If policy is ABORT and signature doesn't match. + """ + current_sig = compute_signature(current_config) + stored_sig = await backend.get(_SIGNATURE_KEY) + + if stored_sig is None: + # First run — store signature + logger.info("No previous config signature found. Storing current signature.") + await backend.set(_SIGNATURE_KEY, current_sig) + return True + + if stored_sig == current_sig: + logger.debug("Config signature matches stored value.") + return True + + # Mismatch detected + logger.warning(f"Config signature mismatch! " + f"Stored: {stored_sig[:12]}..., Current: {current_sig[:12]}... " + f"Policy: {policy.value}") + + if policy == SignatureMismatchPolicy.WARN: + # Update to new signature and continue + await backend.set(_SIGNATURE_KEY, current_sig) + return False + + elif policy == SignatureMismatchPolicy.CLEAR: + # Clear all data except meta keys, store new signature + logger.warning("Clearing all backend data due to config signature mismatch.") + all_keys = await backend.keys("*") + for key in all_keys: + if not key.startswith("_meta::"): + await backend.delete(key) + await backend.set(_SIGNATURE_KEY, current_sig) + return False + + elif policy == SignatureMismatchPolicy.ABORT: + from twinkle.server.exceptions import ConfigMismatchError + raise ConfigMismatchError( + f"Configuration signature mismatch. " + f"Stored: {stored_sig[:12]}..., Current: {current_sig[:12]}... " + f"Use policy='warn' or 'clear' to allow startup with changed config.") + + return False diff --git a/src/twinkle/server/state/future_manager.py b/src/twinkle/server/state/future_manager.py index 0af069a86..331fba5cd 100644 --- a/src/twinkle/server/state/future_manager.py +++ b/src/twinkle/server/state/future_manager.py @@ -4,20 +4,23 @@ from datetime import datetime from typing import Any +from .backend.base import StateBackend from .base import BaseManager from .models import FutureRecord class FutureManager(BaseManager[FutureRecord]): - """ - Manages async task futures / request statuses. + """Manages async task futures / request statuses. Expiry is based on `updated_at` (falls back to `created_at`). """ + def __init__(self, backend: StateBackend, expiration_timeout: float) -> None: + super().__init__(backend, "future::", FutureRecord, expiration_timeout) + # ----- Future-specific operations ----- - def store_status( + async def store_status( self, request_id: str, status: str, @@ -45,7 +48,7 @@ def store_status( result = result.model_dump() now = datetime.now().isoformat() - existing = self._store.get(request_id) + existing = await self.get(request_id) if existing is not None: existing.status = status @@ -59,8 +62,9 @@ def store_status( existing.queue_state = queue_state if queue_state_reason is not None: existing.queue_state_reason = queue_state_reason + await self.add(request_id, existing) else: - self._store[request_id] = FutureRecord( + record = FutureRecord( status=status, model_id=model_id, reason=reason, @@ -70,10 +74,11 @@ def store_status( created_at=now, updated_at=now, ) + await self.add(request_id, record) # ----- Cleanup ----- - def cleanup_expired(self, cutoff_time: float) -> int: + async def cleanup_expired(self, cutoff_time: float, **kwargs) -> int: """Remove futures whose last update is older than cutoff_time. Args: @@ -82,14 +87,15 @@ def cleanup_expired(self, cutoff_time: float) -> int: Returns: Number of futures removed. """ + all_records = await self.get_all() expired_ids = [] - for request_id, record in self._store.items(): + for request_id, record in all_records.items(): timestamp_str = record.updated_at or record.created_at timestamp = self._parse_timestamp(timestamp_str) if timestamp < cutoff_time: expired_ids.append(request_id) for request_id in expired_ids: - del self._store[request_id] + await self.remove(request_id) return len(expired_ids) diff --git a/src/twinkle/server/state/model_manager.py b/src/twinkle/server/state/model_manager.py index 2d5345f7a..e006d8e43 100644 --- a/src/twinkle/server/state/model_manager.py +++ b/src/twinkle/server/state/model_manager.py @@ -1,13 +1,13 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations +from .backend.base import StateBackend from .base import BaseManager from .models import ModelRecord class ModelManager(BaseManager[ModelRecord]): - """ - Manages registered models. + """Manages registered models. Expiry is based on `created_at`. A model is also considered expired if its owning session has already been removed (cascade expiry). @@ -16,10 +16,14 @@ class ModelManager(BaseManager[ModelRecord]): Also tracks replica registrations so the router can query which replicas still have capacity (i.e. their loaded-model count < max_loras). + + Uses a **hybrid mode**: primary records (ModelRecord) are persisted in the + StateBackend, while derived indexes are kept in memory for fast lookups. + On startup, `rebuild_indexes()` loads all records and rebuilds the indexes. """ - def __init__(self, expiration_timeout: float, per_token_model_limit: int = 30) -> None: - super().__init__(expiration_timeout) + def __init__(self, backend: StateBackend, expiration_timeout: float, per_token_model_limit: int = 30) -> None: + super().__init__(backend, "model::", ModelRecord, expiration_timeout) self._per_token_model_limit = per_token_model_limit # token -> set of model_ids owned by that token self._token_models: dict[str, set[str]] = {} @@ -28,6 +32,24 @@ def __init__(self, expiration_timeout: float, per_token_model_limit: int = 30) - # replica_id -> max_loras limit declared at registration time self._replica_max_loras: dict[str, int] = {} + # ----- Index Rebuild ----- + + async def rebuild_indexes(self) -> None: + """Rebuild in-memory indexes from all records in the backend. + + Should be called once after startup (e.g. in ServerState.start_cleanup_task). + """ + all_records = await self.get_all() + self._token_models.clear() + self._replica_models.clear() + for model_id, record in all_records.items(): + token = record.token + self._token_models.setdefault(token, set()).add(model_id) + if record.replica_id is not None: + self._replica_models.setdefault(record.replica_id, set()).add(model_id) + + # ----- Capacity Info ----- + def get_capacity_info(self) -> dict[str, int]: """Return global LoRA capacity across all registered replicas. @@ -54,14 +76,19 @@ def register_replica(self, replica_id: str, max_loras: int) -> None: self._replica_max_loras[replica_id] = max_loras self._replica_models.setdefault(replica_id, set()) - def unregister_replica(self, replica_id: str) -> None: + async def unregister_replica(self, replica_id: str) -> None: """Remove a replica from the registry. - Any model associations for this replica are also cleared. + Any model associations for this replica are also cleared from both + the backend and the in-memory indexes. Args: replica_id: Unique identifier for the replica to remove. """ + # Remove models associated with this replica + model_ids = list(self._replica_models.get(replica_id, set())) + for model_id in model_ids: + await self.remove(model_id) self._replica_max_loras.pop(replica_id, None) self._replica_models.pop(replica_id, None) @@ -92,7 +119,7 @@ def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: # ----- CRUD ----- - def add(self, model_id: str, record: ModelRecord) -> None: + async def add(self, model_id: str, record: ModelRecord) -> None: """Store a record under the given ID. Args: @@ -107,26 +134,32 @@ def add(self, model_id: str, record: ModelRecord) -> None: if len(current_ids) >= self._per_token_model_limit: raise RuntimeError(f'Model limit exceeded: ' f'{len(current_ids)}/{self._per_token_model_limit} models') + # Persist to backend + await super().add(model_id, record) + # Update in-memory indexes self._token_models.setdefault(token, set()).add(model_id) if record.replica_id is not None: self._replica_models.setdefault(record.replica_id, set()).add(model_id) - self._store[model_id] = record - def remove(self, model_id: str) -> bool: + async def remove(self, model_id: str) -> bool: """Remove a record by ID and clean up token and replica ownership. Returns: True if the record existed and was removed, False otherwise. """ - record = self._store.pop(model_id, None) + # Get the record first for index cleanup + record = await self.get(model_id) if record is None: return False + # Remove from backend + await super().remove(model_id) + # Clean up in-memory indexes self._cleanup_ownership(model_id, record) return True # ----- Cleanup ----- - def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None) -> int: + async def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None, **kwargs) -> int: """Remove models that are older than cutoff_time, or whose owning session has already been expired. @@ -140,9 +173,10 @@ def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | N Number of models removed. """ session_set = set(expired_session_ids or []) + all_records = await self.get_all() expired_ids = [] - for model_id, record in self._store.items(): + for model_id, record in all_records.items(): # Cascade: owner session was expired if record.session_id and record.session_id in session_set: expired_ids.append(model_id) @@ -153,8 +187,7 @@ def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | N expired_ids.append(model_id) for model_id in expired_ids: - record = self._store.pop(model_id) - self._cleanup_ownership(model_id, record) + await self.remove(model_id) return len(expired_ids) diff --git a/src/twinkle/server/state/sampling_manager.py b/src/twinkle/server/state/sampling_manager.py index ff3111a6f..3d3d57ca1 100644 --- a/src/twinkle/server/state/sampling_manager.py +++ b/src/twinkle/server/state/sampling_manager.py @@ -1,21 +1,24 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations +from .backend.base import StateBackend from .base import BaseManager from .models import SamplingSessionRecord class SamplingSessionManager(BaseManager[SamplingSessionRecord]): - """ - Manages sampling sessions. + """Manages sampling sessions. Expiry is based on `created_at`. A sampling session is also considered expired if its owning session has already been removed (cascade expiry). """ + def __init__(self, backend: StateBackend, expiration_timeout: float) -> None: + super().__init__(backend, "sampling::", SamplingSessionRecord, expiration_timeout) + # ----- Cleanup ----- - def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None) -> int: + async def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None, **kwargs) -> int: """Remove sampling sessions that are older than cutoff_time, or whose owning session has already been expired. @@ -29,9 +32,10 @@ def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | N Number of sampling sessions removed. """ session_set = set(expired_session_ids or []) + all_records = await self.get_all() expired_ids = [] - for sampling_id, record in self._store.items(): + for sampling_id, record in all_records.items(): # Cascade: owner session was expired if record.session_id and record.session_id in session_set: expired_ids.append(sampling_id) @@ -42,6 +46,6 @@ def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | N expired_ids.append(sampling_id) for sampling_id in expired_ids: - del self._store[sampling_id] + await self.remove(sampling_id) return len(expired_ids) diff --git a/src/twinkle/server/state/server_state.py b/src/twinkle/server/state/server_state.py index 26b425410..9d96d4f3a 100644 --- a/src/twinkle/server/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -9,9 +9,11 @@ from datetime import datetime from typing import Any +from twinkle.server.exceptions import ConfigMismatchError from twinkle.server.telemetry import MetricsRegistry from twinkle.utils.logger import get_logger -from .backend import MemoryBackend, StateBackend +from .backend import StateBackend +from .backend.factory import PersistenceConfig, create_backend from .config_manager import ConfigManager from .future_manager import FutureManager from .model_manager import ModelManager @@ -39,17 +41,21 @@ class ServerState: def __init__( self, backend: StateBackend | None = None, + persistence_config: PersistenceConfig | None = None, expiration_timeout: float = 86400.0, # 24 hours in seconds cleanup_interval: float = 3600.0, # 1 hour in seconds per_token_model_limit: int = 30, + signature_config: dict[str, Any] | None = None, + signature_policy: str = 'warn', **kwargs) -> None: - # Backend is currently consumed only by ConfigManager; other managers - # will be migrated in subsequent tasks. - self._backend: StateBackend = backend if backend is not None else MemoryBackend() - self._session_mgr = SessionManager(expiration_timeout) - self._model_mgr = ModelManager(expiration_timeout, per_token_model_limit) - self._sampling_mgr = SamplingSessionManager(expiration_timeout) - self._future_mgr = FutureManager(expiration_timeout) + if backend is not None: + self._backend: StateBackend = backend + else: + self._backend = create_backend(persistence_config) + self._session_mgr = SessionManager(self._backend, expiration_timeout) + self._model_mgr = ModelManager(self._backend, expiration_timeout, per_token_model_limit) + self._sampling_mgr = SamplingSessionManager(self._backend, expiration_timeout) + self._future_mgr = FutureManager(self._backend, expiration_timeout) self._config_mgr = ConfigManager(self._backend) self.expiration_timeout = expiration_timeout @@ -57,6 +63,10 @@ def __init__( self._cleanup_task: asyncio.Task | None = None self._cleanup_running = False + # Config signature validation state + self._signature_config = signature_config + self._signature_policy = signature_policy + # Metrics loop state self._metrics_task: asyncio.Task | None = None self._metrics_running = False @@ -82,7 +92,7 @@ async def create_session(self, payload: dict[str, Any]) -> str: user_metadata=payload.get('user_metadata') or {}, sdk_version=payload.get('sdk_version'), ) - self._session_mgr.add(session_id, record) + await self._session_mgr.add(session_id, record) return session_id async def touch_session(self, session_id: str) -> bool: @@ -91,7 +101,7 @@ async def touch_session(self, session_id: str) -> bool: Returns: True if the session exists and was touched, False otherwise. """ - return self._session_mgr.touch(session_id) + return await self._session_mgr.touch(session_id) async def get_session_last_heartbeat(self, session_id: str) -> float | None: """Get the last heartbeat timestamp for a session. @@ -99,7 +109,7 @@ async def get_session_last_heartbeat(self, session_id: str) -> float | None: Returns: Last heartbeat timestamp, or None if the session does not exist. """ - return self._session_mgr.get_last_heartbeat(session_id) + return await self._session_mgr.get_last_heartbeat(session_id) # ----- Model Registration ----- @@ -136,7 +146,7 @@ async def register_model(self, token=token, replica_id=replica_id, ) - self._model_mgr.add(_model_id, record) + await self._model_mgr.add(_model_id, record) return _model_id async def unload_model(self, model_id: str) -> bool: @@ -145,11 +155,11 @@ async def unload_model(self, model_id: str) -> bool: Returns: True if the model was found and removed, False otherwise. """ - return self._model_mgr.remove(model_id) + return await self._model_mgr.remove(model_id) async def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: """Get metadata for a registered model as a plain dict.""" - record = self._model_mgr.get(model_id) + record = await self._model_mgr.get(model_id) return record.model_dump() if record is not None else None # ----- Replica Management ----- @@ -169,7 +179,7 @@ async def unregister_replica(self, replica_id: str) -> None: Args: replica_id: Unique identifier for the replica to remove. """ - self._model_mgr.unregister_replica(replica_id) + await self._model_mgr.unregister_replica(replica_id) async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: """Return candidate replica IDs that have not reached their max_loras limit. @@ -202,19 +212,19 @@ async def create_sampling_session(self, payload: dict[str, Any], sampling_sessio base_model=payload.get('base_model'), model_path=payload.get('model_path'), ) - self._sampling_mgr.add(_sampling_session_id, record) + await self._sampling_mgr.add(_sampling_session_id, record) return _sampling_session_id async def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: """Get a sampling session by ID as a plain dict.""" - record = self._sampling_mgr.get(sampling_session_id) + record = await self._sampling_mgr.get(sampling_session_id) return record.model_dump() if record is not None else None # ----- Future Management ----- async def get_future(self, request_id: str) -> dict[str, Any] | None: """Retrieve a stored future result as a plain dict.""" - record = self._future_mgr.get(request_id) + record = await self._future_mgr.get(request_id) return record.model_dump() if record is not None else None async def store_future_status( @@ -246,7 +256,7 @@ async def store_future_status( queue_state: Optional queue state for tinker client (active/paused_rate_limit/paused_capacity). queue_state_reason: Optional reason for the queue state. """ - self._future_mgr.store_status( + await self._future_mgr.store_status( request_id=request_id, status=status, model_id=model_id, @@ -298,13 +308,13 @@ async def cleanup_expired_resources(self) -> dict[str, int]: cutoff_time = current_time - self.expiration_timeout # Collect expired session IDs first for cascade logic - expired_session_ids = self._session_mgr.get_expired_ids(cutoff_time) + expired_session_ids = await self._session_mgr.get_expired_ids(cutoff_time) # Perform actual cleanup in dependency order - sessions_removed = self._session_mgr.cleanup_expired(cutoff_time) - models_removed = self._model_mgr.cleanup_expired(cutoff_time, expired_session_ids) - samplings_removed = self._sampling_mgr.cleanup_expired(cutoff_time, expired_session_ids) - futures_removed = self._future_mgr.cleanup_expired(cutoff_time) + sessions_removed = await self._session_mgr.cleanup_expired(cutoff_time) + models_removed = await self._model_mgr.cleanup_expired(cutoff_time, expired_session_ids=expired_session_ids) + samplings_removed = await self._sampling_mgr.cleanup_expired(cutoff_time, expired_session_ids=expired_session_ids) + futures_removed = await self._future_mgr.cleanup_expired(cutoff_time) return { 'sessions': sessions_removed, @@ -336,17 +346,17 @@ async def _metrics_loop(self) -> None: """ registry = MetricsRegistry.get() sources = ( - ('active_sessions', self._session_mgr.count), - ('active_models', self._model_mgr.count), - ('active_sampling_sessions', self._sampling_mgr.count), - ('active_futures', self._future_mgr.count), + ('active_sessions', self._session_mgr), + ('active_models', self._model_mgr), + ('active_sampling_sessions', self._sampling_mgr), + ('active_futures', self._future_mgr), ) last_values: dict[str, int] = {name: 0 for name, _ in sources} while self._metrics_running: try: await asyncio.sleep(self._metrics_update_interval) - for name, counter in sources: - current = counter() + for name, mgr in sources: + current = await mgr.count() delta = current - last_values[name] if delta != 0: getattr(registry, name).add(delta) @@ -365,6 +375,8 @@ async def start_cleanup_task(self) -> bool: """ if self._cleanup_running: return False + # Rebuild in-memory indexes from backend data + await self._rebuild_indexes() self._cleanup_running = True self._cleanup_task = asyncio.create_task(self._cleanup_loop()) if not self._metrics_running: @@ -372,6 +384,24 @@ async def start_cleanup_task(self) -> bool: self._metrics_task = asyncio.create_task(self._metrics_loop()) return True + async def _rebuild_indexes(self) -> None: + """Rebuild in-memory indexes from backend data after startup. + + Also validates config signature when ``signature_config`` was provided + at construction time. + """ + # Validate config signature if provided + if self._signature_config is not None: + from twinkle.server.state.config_signature import ( + SignatureMismatchPolicy, + validate_config_signature, + ) + policy = SignatureMismatchPolicy(self._signature_policy) + await validate_config_signature(self._backend, self._signature_config, policy) + + # Rebuild model indexes + await self._model_mgr.rebuild_indexes() + async def stop_cleanup_task(self) -> bool: """Stop the background cleanup task. @@ -401,10 +431,10 @@ async def get_cleanup_stats(self) -> dict[str, Any]: 'cleanup_interval': self.cleanup_interval, 'cleanup_running': self._cleanup_running, 'resource_counts': { - 'sessions': self._session_mgr.count(), - 'models': self._model_mgr.count(), - 'sampling_sessions': self._sampling_mgr.count(), - 'futures': self._future_mgr.count(), + 'sessions': await self._session_mgr.count(), + 'models': await self._model_mgr.count(), + 'sampling_sessions': await self._sampling_mgr.count(), + 'futures': await self._future_mgr.count(), }, } @@ -536,6 +566,9 @@ async def get_cleanup_stats(self) -> dict[str, Any]: def get_server_state( actor_name: str = 'twinkle_server_state', backend: StateBackend | None = None, + persistence_config: PersistenceConfig | None = None, + signature_config: dict[str, Any] | None = None, + signature_policy: str = 'warn', **kwargs) -> ServerStateProxy: """Get or create the ServerState Ray actor. @@ -545,8 +578,10 @@ def get_server_state( Args: actor_name: Name for the Ray actor (default: 'twinkle_server_state'). backend: Optional :class:`StateBackend` injected into the ServerState - actor. When ``None`` the actor falls back to an in-process - :class:`MemoryBackend`. + actor. When ``None`` the actor falls back to ``persistence_config`` + or an in-process :class:`MemoryBackend`. + persistence_config: Optional :class:`PersistenceConfig` used to build + a backend via :func:`create_backend` when ``backend`` is None. **kwargs: Additional keyword arguments passed to ServerState constructor (e.g., expiration_timeout, cleanup_interval). @@ -559,10 +594,18 @@ def get_server_state( try: _ServerState = ray.remote(ServerState) actor = _ServerState.options(name=actor_name, lifetime='detached').remote( - backend=backend, **kwargs) + backend=backend, + persistence_config=persistence_config, + signature_config=signature_config, + signature_policy=signature_policy, + **kwargs) try: ray.get(actor.start_cleanup_task.remote()) except Exception as e: + # Ray wraps remote exceptions - check cause + cause = e.__cause__ if hasattr(e, '__cause__') and e.__cause__ else e + if isinstance(cause, ConfigMismatchError) or 'ConfigMismatchError' in type(e).__name__: + raise logger.debug(f'[ServerState] Warning: Failed to start cleanup task: {e}') except ValueError: actor = ray.get_actor(actor_name) diff --git a/src/twinkle/server/state/session_manager.py b/src/twinkle/server/state/session_manager.py index e7b154cbe..67efc2ddc 100644 --- a/src/twinkle/server/state/session_manager.py +++ b/src/twinkle/server/state/session_manager.py @@ -3,42 +3,46 @@ import time +from .backend.base import StateBackend from .base import BaseManager from .models import SessionRecord class SessionManager(BaseManager[SessionRecord]): - """ - Manages client sessions. + """Manages client sessions. Expiry is based on `last_heartbeat`; falls back to `created_at` if no heartbeat has been recorded yet. """ + def __init__(self, backend: StateBackend, expiration_timeout: float) -> None: + super().__init__(backend, "session::", SessionRecord, expiration_timeout) + # ----- Session-specific operations ----- - def touch(self, session_id: str) -> bool: + async def touch(self, session_id: str) -> bool: """Update the heartbeat timestamp for a session. Returns: True if the session exists and was updated, False otherwise. """ - record = self._store.get(session_id) + record = await self.get(session_id) if record is None: return False record.last_heartbeat = time.time() + await self.add(session_id, record) return True - def get_last_heartbeat(self, session_id: str) -> float | None: + async def get_last_heartbeat(self, session_id: str) -> float | None: """Return the last heartbeat timestamp, or None if the session does not exist.""" - record = self._store.get(session_id) + record = await self.get(session_id) if record is None: return None return record.last_heartbeat # ----- Cleanup ----- - def cleanup_expired(self, cutoff_time: float) -> int: + async def cleanup_expired(self, cutoff_time: float, **kwargs) -> int: """Remove sessions whose last activity is older than cutoff_time. Args: @@ -47,28 +51,29 @@ def cleanup_expired(self, cutoff_time: float) -> int: Returns: Number of sessions removed. """ + all_records = await self.get_all() expired_ids = [] - for session_id, record in self._store.items(): + for session_id, record in all_records.items(): last_activity = record.last_heartbeat if last_activity == 0.0: - # Fallback: parse created_at last_activity = self._parse_timestamp(record.created_at) if last_activity < cutoff_time: expired_ids.append(session_id) for session_id in expired_ids: - del self._store[session_id] + await self.remove(session_id) return len(expired_ids) - def get_expired_ids(self, cutoff_time: float) -> list[str]: + async def get_expired_ids(self, cutoff_time: float) -> list[str]: """Return IDs of sessions that would be removed at the given cutoff. Used by ServerState to cascade-expire dependent resources before actually deleting the sessions. """ + all_records = await self.get_all() expired_ids = [] - for session_id, record in self._store.items(): + for session_id, record in all_records.items(): last_activity = record.last_heartbeat if last_activity == 0.0: last_activity = self._parse_timestamp(record.created_at) diff --git a/tests/server/__init__.py b/tests/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/state/__init__.py b/tests/server/state/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/state/test_config_signature.py b/tests/server/state/test_config_signature.py new file mode 100644 index 000000000..946498cb5 --- /dev/null +++ b/tests/server/state/test_config_signature.py @@ -0,0 +1,153 @@ +"""Tests for config signature validation.""" +from __future__ import annotations + +import pytest + +from twinkle.server.state.backend.memory_backend import MemoryBackend +from twinkle.server.state.config_signature import ( + SignatureMismatchPolicy, + compute_signature, + validate_config_signature, +) + + +# ---- compute_signature ---- + +def test_compute_signature_deterministic(): + """Same input should produce same output.""" + config = {"model": "qwen", "batch_size": 8} + sig1 = compute_signature(config) + sig2 = compute_signature(config) + assert sig1 == sig2 + + +def test_compute_signature_different_inputs(): + """Different inputs should produce different outputs.""" + config_a = {"model": "qwen", "batch_size": 8} + config_b = {"model": "llama", "batch_size": 8} + assert compute_signature(config_a) != compute_signature(config_b) + + +def test_compute_signature_key_order_independent(): + """Key order should not affect the signature (sort_keys=True).""" + config_a = {"b": 2, "a": 1} + config_b = {"a": 1, "b": 2} + assert compute_signature(config_a) == compute_signature(config_b) + + +def test_compute_signature_is_hex_string(): + """Signature should be a valid hex SHA256 string.""" + sig = compute_signature({"key": "value"}) + assert len(sig) == 64 # SHA256 hex = 64 chars + assert all(c in "0123456789abcdef" for c in sig) + + +# ---- validate_config_signature ---- + +@pytest.mark.asyncio +async def test_first_run_stores_signature(): + """First run with no stored sig should store it and return True.""" + backend = MemoryBackend() + config = {"model": "test"} + result = await validate_config_signature(backend, config) + assert result is True + # Signature should be stored + stored = await backend.get("_meta::config_signature") + assert stored == compute_signature(config) + + +@pytest.mark.asyncio +async def test_same_config_passes(): + """Same config on second run should pass validation.""" + backend = MemoryBackend() + config = {"model": "test", "lr": 0.001} + # First run + await validate_config_signature(backend, config) + # Second run same config + result = await validate_config_signature(backend, config) + assert result is True + + +@pytest.mark.asyncio +async def test_different_config_warn_policy(): + """Different config with WARN policy should return False and update sig.""" + backend = MemoryBackend() + config_v1 = {"model": "v1"} + config_v2 = {"model": "v2"} + + await validate_config_signature(backend, config_v1) + result = await validate_config_signature( + backend, config_v2, policy=SignatureMismatchPolicy.WARN + ) + assert result is False + # Signature should be updated to v2 + stored = await backend.get("_meta::config_signature") + assert stored == compute_signature(config_v2) + + +@pytest.mark.asyncio +async def test_different_config_clear_policy(): + """CLEAR policy should clear non-meta data, preserve _meta, return False.""" + backend = MemoryBackend() + config_v1 = {"model": "v1"} + config_v2 = {"model": "v2"} + + # Store initial config + await validate_config_signature(backend, config_v1) + # Add some user data + await backend.set("session::abc", {"data": 123}) + await backend.set("model::xyz", {"data": 456}) + await backend.set("_meta::other", "keep_this") + + result = await validate_config_signature( + backend, config_v2, policy=SignatureMismatchPolicy.CLEAR + ) + assert result is False + + # User data should be cleared + assert await backend.get("session::abc") is None + assert await backend.get("model::xyz") is None + + # _meta keys should be preserved + assert await backend.get("_meta::other") == "keep_this" + # Signature should be updated + stored = await backend.get("_meta::config_signature") + assert stored == compute_signature(config_v2) + + +@pytest.mark.asyncio +async def test_different_config_abort_policy(): + """ABORT policy should raise ConfigMismatchError.""" + from twinkle.server.exceptions import ConfigMismatchError + + backend = MemoryBackend() + config_v1 = {"model": "v1"} + config_v2 = {"model": "v2"} + + await validate_config_signature(backend, config_v1) + + with pytest.raises(ConfigMismatchError): + await validate_config_signature( + backend, config_v2, policy=SignatureMismatchPolicy.ABORT + ) + + +@pytest.mark.asyncio +async def test_abort_policy_does_not_update_signature(): + """ABORT policy should NOT update the stored signature.""" + from twinkle.server.exceptions import ConfigMismatchError + + backend = MemoryBackend() + config_v1 = {"model": "v1"} + config_v2 = {"model": "v2"} + + await validate_config_signature(backend, config_v1) + + with pytest.raises(ConfigMismatchError): + await validate_config_signature( + backend, config_v2, policy=SignatureMismatchPolicy.ABORT + ) + + # Signature should still be v1 + stored = await backend.get("_meta::config_signature") + assert stored == compute_signature(config_v1) diff --git a/tests/server/state/test_factory.py b/tests/server/state/test_factory.py new file mode 100644 index 000000000..12568b775 --- /dev/null +++ b/tests/server/state/test_factory.py @@ -0,0 +1,85 @@ +"""Tests for backend factory - create_backend function.""" +from __future__ import annotations + +import os +import tempfile + +import pytest + +from twinkle.server.state.backend.factory import PersistenceConfig, create_backend +from twinkle.server.state.backend.file_backend import FileBackend +from twinkle.server.state.backend.memory_backend import MemoryBackend + + +# ---- Memory Mode ---- + +def test_create_backend_none_returns_memory(): + """Passing None should return MemoryBackend (default mode).""" + backend = create_backend(None) + assert isinstance(backend, MemoryBackend) + + +def test_create_backend_memory_mode(): + """Explicit memory mode should return MemoryBackend.""" + config = PersistenceConfig(mode="memory") + backend = create_backend(config) + assert isinstance(backend, MemoryBackend) + + +# ---- File Mode ---- + +def test_create_backend_file_mode(): + """File mode with file_path should return FileBackend.""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + path = f.name + try: + os.unlink(path) # Let FileBackend create it + config = PersistenceConfig(mode="file", file_path=path) + backend = create_backend(config) + assert isinstance(backend, FileBackend) + finally: + if os.path.exists(path): + os.unlink(path) + + +def test_create_backend_file_mode_missing_path(): + """File mode without file_path should raise ValueError.""" + config = PersistenceConfig(mode="file") + with pytest.raises(ValueError, match="file_path"): + create_backend(config) + + +# ---- Redis Mode ---- + +def test_create_backend_redis_mode(): + """Redis mode with redis_url should return RedisBackend (if redis available).""" + try: + import redis # noqa: F401 + except ImportError: + pytest.skip("redis package not available") + + from unittest.mock import patch, MagicMock + from twinkle.server.state.backend.redis_backend import RedisBackend + + with patch("redis.asyncio.from_url", return_value=MagicMock()): + config = PersistenceConfig(mode="redis", redis_url="redis://localhost:6379") + backend = create_backend(config) + assert isinstance(backend, RedisBackend) + + +def test_create_backend_redis_mode_missing_url(): + """Redis mode without redis_url should raise ValueError.""" + config = PersistenceConfig(mode="redis") + with pytest.raises(ValueError, match="redis_url"): + create_backend(config) + + +# ---- PersistenceConfig Defaults ---- + +def test_persistence_config_defaults(): + """PersistenceConfig should have sensible defaults.""" + config = PersistenceConfig() + assert config.mode == "memory" + assert config.file_path is None + assert config.redis_url is None + assert config.key_prefix == "" diff --git a/tests/server/state/test_file_backend.py b/tests/server/state/test_file_backend.py new file mode 100644 index 000000000..5ede0cf0e --- /dev/null +++ b/tests/server/state/test_file_backend.py @@ -0,0 +1,211 @@ +"""Tests for FileBackend - JSON file-based state backend.""" +from __future__ import annotations + +import asyncio +import json +import os +import tempfile +import time + +import pytest + +from twinkle.server.state.backend.file_backend import FileBackend + + +@pytest.fixture +def tmp_file(): + """Provide a temporary file path, deleted before use so FileBackend creates fresh.""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + path = f.name + os.unlink(path) + yield path + if os.path.exists(path): + os.unlink(path) + + +# ---- Basic CRUD ---- + +@pytest.mark.asyncio +async def test_set_and_get(tmp_file): + backend = FileBackend(tmp_file) + await backend.set("key1", {"hello": "world"}) + result = await backend.get("key1") + assert result == {"hello": "world"} + + +@pytest.mark.asyncio +async def test_get_nonexistent_key(tmp_file): + backend = FileBackend(tmp_file) + result = await backend.get("nonexistent") + assert result is None + + +@pytest.mark.asyncio +async def test_delete(tmp_file): + backend = FileBackend(tmp_file) + await backend.set("key1", "value1") + await backend.delete("key1") + result = await backend.get("key1") + assert result is None + + +@pytest.mark.asyncio +async def test_delete_nonexistent_key(tmp_file): + """Deleting a key that doesn't exist should not raise.""" + backend = FileBackend(tmp_file) + await backend.delete("nonexistent") # Should not raise + + +@pytest.mark.asyncio +async def test_exists(tmp_file): + backend = FileBackend(tmp_file) + await backend.set("key1", "value1") + assert await backend.exists("key1") is True + assert await backend.exists("key2") is False + + +# ---- TTL Expiry ---- + +@pytest.mark.asyncio +async def test_ttl_expiry(tmp_file): + backend = FileBackend(tmp_file) + await backend.set("ephemeral", "data", ttl=1) + # Immediately should exist + assert await backend.get("ephemeral") == "data" + assert await backend.exists("ephemeral") is True + # Wait for expiry + time.sleep(1.1) + assert await backend.get("ephemeral") is None + assert await backend.exists("ephemeral") is False + + +@pytest.mark.asyncio +async def test_ttl_none_means_no_expiry(tmp_file): + backend = FileBackend(tmp_file) + await backend.set("permanent", "data", ttl=None) + time.sleep(0.1) + assert await backend.get("permanent") == "data" + + +# ---- Keys Pattern Matching ---- + +@pytest.mark.asyncio +async def test_keys_wildcard(tmp_file): + backend = FileBackend(tmp_file) + await backend.set("session::abc", "s1") + await backend.set("session::def", "s2") + await backend.set("model::xyz", "m1") + + session_keys = await backend.keys("session::*") + assert sorted(session_keys) == ["session::abc", "session::def"] + + model_keys = await backend.keys("model::*") + assert model_keys == ["model::xyz"] + + all_keys = await backend.keys("*") + assert len(all_keys) == 3 + + +@pytest.mark.asyncio +async def test_keys_excludes_expired(tmp_file): + backend = FileBackend(tmp_file) + await backend.set("alive", "yes") + await backend.set("dying", "soon", ttl=1) + time.sleep(1.1) + keys = await backend.keys("*") + assert keys == ["alive"] + + +# ---- Count ---- + +@pytest.mark.asyncio +async def test_count(tmp_file): + backend = FileBackend(tmp_file) + await backend.set("a::1", "v") + await backend.set("a::2", "v") + await backend.set("b::1", "v") + assert await backend.count("a::*") == 2 + assert await backend.count("b::*") == 1 + assert await backend.count("*") == 3 + + +# ---- set_nx ---- + +@pytest.mark.asyncio +async def test_set_nx_new_key(tmp_file): + backend = FileBackend(tmp_file) + result = await backend.set_nx("new_key", "value") + assert result is True + assert await backend.get("new_key") == "value" + + +@pytest.mark.asyncio +async def test_set_nx_existing_key(tmp_file): + backend = FileBackend(tmp_file) + await backend.set("existing", "original") + result = await backend.set_nx("existing", "new_value") + assert result is False + # Value should not change + assert await backend.get("existing") == "original" + + +@pytest.mark.asyncio +async def test_set_nx_expired_key(tmp_file): + backend = FileBackend(tmp_file) + await backend.set("expired_key", "old", ttl=1) + time.sleep(1.1) + # Key is expired, set_nx should succeed + result = await backend.set_nx("expired_key", "new_value") + assert result is True + assert await backend.get("expired_key") == "new_value" + + +# ---- Health Check ---- + +@pytest.mark.asyncio +async def test_health_check(tmp_file): + backend = FileBackend(tmp_file) + assert await backend.health_check() is True + + +# ---- Auto-create File ---- + +@pytest.mark.asyncio +async def test_auto_create_file(): + """FileBackend should create the file if it doesn't exist.""" + with tempfile.TemporaryDirectory() as tmp_dir: + path = os.path.join(tmp_dir, "subdir", "state.json") + backend = FileBackend(path) + assert os.path.exists(path) + # File should be valid JSON + with open(path, 'r') as f: + data = json.load(f) + assert data == {} + + +# ---- Atomic Write Integrity ---- + +@pytest.mark.asyncio +async def test_atomic_write_integrity(tmp_file): + """After write, reading from file should give consistent data.""" + backend = FileBackend(tmp_file) + await backend.set("k1", {"nested": [1, 2, 3]}) + await backend.set("k2", "simple_string") + + # Read raw file to verify structure + with open(tmp_file, 'r', encoding='utf-8') as f: + raw = json.load(f) + assert "k1" in raw + assert raw["k1"]["value"] == {"nested": [1, 2, 3]} + assert "k2" in raw + assert raw["k2"]["value"] == "simple_string" + + +# ---- Overwrite Value ---- + +@pytest.mark.asyncio +async def test_overwrite_value(tmp_file): + backend = FileBackend(tmp_file) + await backend.set("key", "v1") + await backend.set("key", "v2") + assert await backend.get("key") == "v2" diff --git a/tests/server/state/test_managers.py b/tests/server/state/test_managers.py new file mode 100644 index 000000000..bea3f15f3 --- /dev/null +++ b/tests/server/state/test_managers.py @@ -0,0 +1,383 @@ +"""Tests for state managers using MemoryBackend as integration backend.""" +from __future__ import annotations + +import time +from datetime import datetime, timezone + +import pytest + +from twinkle.server.state.backend.memory_backend import MemoryBackend +from twinkle.server.state.future_manager import FutureManager +from twinkle.server.state.model_manager import ModelManager +from twinkle.server.state.models import ( + FutureRecord, + ModelRecord, + SamplingSessionRecord, + SessionRecord, +) +from twinkle.server.state.sampling_manager import SamplingSessionManager +from twinkle.server.state.session_manager import SessionManager + + +# ============================================================ +# SessionManager Tests +# ============================================================ + +class TestSessionManager: + + @pytest.fixture + def backend(self): + return MemoryBackend() + + @pytest.fixture + def manager(self, backend): + return SessionManager(backend=backend, expiration_timeout=300.0) + + @pytest.mark.asyncio + async def test_add_and_get(self, manager): + record = SessionRecord(tags=["test"], sdk_version="1.0") + await manager.add("sess1", record) + result = await manager.get("sess1") + assert result is not None + assert result.tags == ["test"] + assert result.sdk_version == "1.0" + + @pytest.mark.asyncio + async def test_get_nonexistent(self, manager): + result = await manager.get("nonexistent") + assert result is None + + @pytest.mark.asyncio + async def test_remove(self, manager): + record = SessionRecord() + await manager.add("sess1", record) + removed = await manager.remove("sess1") + assert removed is True + assert await manager.get("sess1") is None + + @pytest.mark.asyncio + async def test_remove_nonexistent(self, manager): + removed = await manager.remove("nonexistent") + assert removed is False + + @pytest.mark.asyncio + async def test_count(self, manager): + await manager.add("s1", SessionRecord()) + await manager.add("s2", SessionRecord()) + assert await manager.count() == 2 + + @pytest.mark.asyncio + async def test_touch_updates_heartbeat(self, manager): + record = SessionRecord(last_heartbeat=1000.0) + await manager.add("sess1", record) + before = time.time() + result = await manager.touch("sess1") + after = time.time() + assert result is True + updated = await manager.get("sess1") + assert before <= updated.last_heartbeat <= after + + @pytest.mark.asyncio + async def test_touch_nonexistent(self, manager): + result = await manager.touch("nonexistent") + assert result is False + + @pytest.mark.asyncio + async def test_get_last_heartbeat(self, manager): + record = SessionRecord(last_heartbeat=12345.0) + await manager.add("sess1", record) + hb = await manager.get_last_heartbeat("sess1") + assert hb == 12345.0 + + @pytest.mark.asyncio + async def test_cleanup_expired(self, manager): + now = time.time() + # Old session + old_record = SessionRecord(last_heartbeat=now - 1000) + await manager.add("old_sess", old_record) + # Recent session + new_record = SessionRecord(last_heartbeat=now) + await manager.add("new_sess", new_record) + + cutoff = now - 500 + removed_count = await manager.cleanup_expired(cutoff) + assert removed_count == 1 + assert await manager.get("old_sess") is None + assert await manager.get("new_sess") is not None + + @pytest.mark.asyncio + async def test_cleanup_expired_uses_created_at_fallback(self, manager): + """When last_heartbeat is 0, should use created_at for expiry check.""" + old_time = datetime(2020, 1, 1, tzinfo=timezone.utc).isoformat() + record = SessionRecord(last_heartbeat=0.0, created_at=old_time) + await manager.add("old_sess", record) + + cutoff = time.time() - 100 + removed_count = await manager.cleanup_expired(cutoff) + assert removed_count == 1 + + +# ============================================================ +# ModelManager Tests +# ============================================================ + +class TestModelManager: + + @pytest.fixture + def backend(self): + return MemoryBackend() + + @pytest.fixture + def manager(self, backend): + return ModelManager(backend=backend, expiration_timeout=300.0, per_token_model_limit=3) + + @pytest.mark.asyncio + async def test_add_and_get(self, manager): + record = ModelRecord(token="tok1", session_id="sess1", base_model="qwen") + await manager.add("model1", record) + result = await manager.get("model1") + assert result is not None + assert result.token == "tok1" + assert result.base_model == "qwen" + + @pytest.mark.asyncio + async def test_remove(self, manager): + record = ModelRecord(token="tok1") + await manager.add("model1", record) + removed = await manager.remove("model1") + assert removed is True + assert await manager.get("model1") is None + + @pytest.mark.asyncio + async def test_token_limit_enforced(self, manager): + """Adding more models than per_token_model_limit should raise RuntimeError.""" + for i in range(3): + await manager.add(f"m{i}", ModelRecord(token="tok1")) + + with pytest.raises(RuntimeError, match="Model limit exceeded"): + await manager.add("m3", ModelRecord(token="tok1")) + + @pytest.mark.asyncio + async def test_token_limit_per_token(self, manager): + """Limit is per-token, different tokens have separate limits.""" + for i in range(3): + await manager.add(f"a{i}", ModelRecord(token="tokenA")) + # Different token should work + await manager.add("b0", ModelRecord(token="tokenB")) + assert await manager.get("b0") is not None + + @pytest.mark.asyncio + async def test_replica_registration(self, manager): + manager.register_replica("replica1", max_loras=5) + info = manager.get_capacity_info() + assert info["max_loras"] == 5 + assert info["used_loras"] == 0 + assert info["free_loras"] == 5 + + @pytest.mark.asyncio + async def test_capacity_info_after_add(self, manager): + manager.register_replica("r1", max_loras=3) + record = ModelRecord(token="tok1", replica_id="r1") + await manager.add("m1", record) + info = manager.get_capacity_info() + assert info["used_loras"] == 1 + assert info["free_loras"] == 2 + + @pytest.mark.asyncio + async def test_rebuild_indexes(self, manager): + """rebuild_indexes should reconstruct token and replica indexes from backend.""" + record1 = ModelRecord(token="tok1", replica_id="r1") + record2 = ModelRecord(token="tok1", replica_id="r2") + await manager.add("m1", record1) + await manager.add("m2", record2) + + # Clear in-memory indexes + manager._token_models.clear() + manager._replica_models.clear() + + await manager.rebuild_indexes() + assert "m1" in manager._token_models.get("tok1", set()) + assert "m2" in manager._token_models.get("tok1", set()) + assert "m1" in manager._replica_models.get("r1", set()) + assert "m2" in manager._replica_models.get("r2", set()) + + @pytest.mark.asyncio + async def test_cascade_cleanup_by_session(self, manager): + """Models owned by expired sessions should be cleaned up.""" + now = time.time() + record = ModelRecord( + token="tok1", + session_id="expired_sess", + created_at=datetime.now(timezone.utc).isoformat(), + ) + await manager.add("m1", record) + + # Cleanup with cascade + removed = await manager.cleanup_expired( + cutoff_time=now - 10000, # cutoff is old, so age-based wouldn't trigger + expired_session_ids=["expired_sess"], + ) + assert removed == 1 + assert await manager.get("m1") is None + + @pytest.mark.asyncio + async def test_get_available_replica_ids(self, manager): + manager.register_replica("r1", max_loras=2) + manager.register_replica("r2", max_loras=1) + # Fill r2 + await manager.add("m1", ModelRecord(token="t", replica_id="r2")) + + available = manager.get_available_replica_ids(["r1", "r2", "r3_unknown"]) + # r1 has capacity, r2 is full, r3 unknown (conservative include) + assert "r1" in available + assert "r2" not in available + assert "r3_unknown" in available + + +# ============================================================ +# SamplingSessionManager Tests +# ============================================================ + +class TestSamplingSessionManager: + + @pytest.fixture + def backend(self): + return MemoryBackend() + + @pytest.fixture + def manager(self, backend): + return SamplingSessionManager(backend=backend, expiration_timeout=300.0) + + @pytest.mark.asyncio + async def test_add_and_get(self, manager): + record = SamplingSessionRecord(session_id="sess1", base_model="qwen") + await manager.add("samp1", record) + result = await manager.get("samp1") + assert result is not None + assert result.session_id == "sess1" + assert result.base_model == "qwen" + + @pytest.mark.asyncio + async def test_cleanup_expired_by_age(self, manager): + old_time = datetime(2020, 1, 1, tzinfo=timezone.utc).isoformat() + record = SamplingSessionRecord(session_id="sess1", created_at=old_time) + await manager.add("samp_old", record) + + # Recent + record2 = SamplingSessionRecord(session_id="sess2") + await manager.add("samp_new", record2) + + cutoff = time.time() - 100 + removed = await manager.cleanup_expired(cutoff) + assert removed == 1 + assert await manager.get("samp_old") is None + assert await manager.get("samp_new") is not None + + @pytest.mark.asyncio + async def test_cleanup_expired_cascade(self, manager): + """Sampling sessions should be cleaned when their parent session expires.""" + record = SamplingSessionRecord(session_id="expired_sess") + await manager.add("samp1", record) + + removed = await manager.cleanup_expired( + cutoff_time=0.0, # Won't catch by age + expired_session_ids=["expired_sess"], + ) + assert removed == 1 + assert await manager.get("samp1") is None + + +# ============================================================ +# FutureManager Tests +# ============================================================ + +class TestFutureManager: + + @pytest.fixture + def backend(self): + return MemoryBackend() + + @pytest.fixture + def manager(self, backend): + return FutureManager(backend=backend, expiration_timeout=300.0) + + @pytest.mark.asyncio + async def test_store_status_creates_new(self, manager): + await manager.store_status( + request_id="req1", + status="pending", + model_id="model1", + ) + result = await manager.get("req1") + assert result is not None + assert result.status == "pending" + assert result.model_id == "model1" + + @pytest.mark.asyncio + async def test_store_status_updates_existing(self, manager): + await manager.store_status( + request_id="req1", status="pending", model_id="m1" + ) + await manager.store_status( + request_id="req1", status="completed", model_id="m1", result={"output": "done"} + ) + result = await manager.get("req1") + assert result.status == "completed" + assert result.result == {"output": "done"} + + @pytest.mark.asyncio + async def test_store_status_with_pydantic_result(self, manager): + """Pydantic models should be serialized via model_dump.""" + from pydantic import BaseModel + + class MockResult(BaseModel): + score: float = 0.95 + + await manager.store_status( + request_id="req1", status="completed", model_id="m1", + result=MockResult() + ) + result = await manager.get("req1") + assert result.result == {"score": 0.95} + + @pytest.mark.asyncio + async def test_store_status_preserves_reason(self, manager): + await manager.store_status( + request_id="req1", status="rate_limited", model_id=None, + reason="Too many requests" + ) + result = await manager.get("req1") + assert result.reason == "Too many requests" + + @pytest.mark.asyncio + async def test_cleanup_expired(self, manager): + old_time = datetime(2020, 1, 1, tzinfo=timezone.utc).isoformat() + old_record = FutureRecord( + status="completed", created_at=old_time, updated_at=old_time + ) + await manager.add("old_req", old_record) + + new_record = FutureRecord(status="pending") + await manager.add("new_req", new_record) + + cutoff = time.time() - 100 + removed = await manager.cleanup_expired(cutoff) + assert removed == 1 + assert await manager.get("old_req") is None + assert await manager.get("new_req") is not None + + @pytest.mark.asyncio + async def test_get_nonexistent(self, manager): + result = await manager.get("nonexistent") + assert result is None + + @pytest.mark.asyncio + async def test_store_status_queue_state(self, manager): + await manager.store_status( + request_id="req1", status="queued", model_id="m1", + queue_state="paused_rate_limit", + queue_state_reason="Rate limit hit" + ) + result = await manager.get("req1") + assert result.queue_state == "paused_rate_limit" + assert result.queue_state_reason == "Rate limit hit" diff --git a/tests/server/state/test_redis_backend.py b/tests/server/state/test_redis_backend.py new file mode 100644 index 000000000..dde5c1af7 --- /dev/null +++ b/tests/server/state/test_redis_backend.py @@ -0,0 +1,198 @@ +"""Tests for RedisBackend - using mocks since no real Redis is available.""" +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Skip entire module if redis package not available +redis = pytest.importorskip("redis") + +from twinkle.server.state.backend.redis_backend import RedisBackend + + +@pytest.fixture +def mock_redis_client(): + """Create a mock redis.asyncio client.""" + client = AsyncMock() + client.set = AsyncMock(return_value=True) + client.get = AsyncMock(return_value=None) + client.delete = AsyncMock(return_value=1) + client.exists = AsyncMock(return_value=1) + client.keys = AsyncMock(return_value=[]) + client.ping = AsyncMock(return_value=True) + client.aclose = AsyncMock() + return client + + +@pytest.fixture +def backend_no_prefix(mock_redis_client): + """RedisBackend with no key prefix.""" + with patch("redis.asyncio.from_url", return_value=mock_redis_client): + backend = RedisBackend("redis://localhost:6379") + return backend + + +@pytest.fixture +def backend_with_prefix(mock_redis_client): + """RedisBackend with key prefix.""" + with patch("redis.asyncio.from_url", return_value=mock_redis_client): + backend = RedisBackend("redis://localhost:6379", key_prefix="twinkle:") + return backend + + +# ---- SET ---- + +@pytest.mark.asyncio +async def test_set_without_ttl(backend_no_prefix, mock_redis_client): + await backend_no_prefix.set("mykey", {"data": 123}) + mock_redis_client.set.assert_called_once_with("mykey", json.dumps({"data": 123})) + + +@pytest.mark.asyncio +async def test_set_with_ttl(backend_no_prefix, mock_redis_client): + await backend_no_prefix.set("mykey", "value", ttl=60) + mock_redis_client.set.assert_called_once_with("mykey", json.dumps("value"), ex=60) + + +@pytest.mark.asyncio +async def test_set_with_prefix(backend_with_prefix, mock_redis_client): + await backend_with_prefix.set("mykey", "val") + mock_redis_client.set.assert_called_once_with("twinkle:mykey", json.dumps("val")) + + +# ---- GET ---- + +@pytest.mark.asyncio +async def test_get_existing_key(backend_no_prefix, mock_redis_client): + mock_redis_client.get.return_value = json.dumps({"hello": "world"}) + result = await backend_no_prefix.get("mykey") + mock_redis_client.get.assert_called_once_with("mykey") + assert result == {"hello": "world"} + + +@pytest.mark.asyncio +async def test_get_nonexistent_key(backend_no_prefix, mock_redis_client): + mock_redis_client.get.return_value = None + result = await backend_no_prefix.get("missing") + assert result is None + + +@pytest.mark.asyncio +async def test_get_with_prefix(backend_with_prefix, mock_redis_client): + mock_redis_client.get.return_value = json.dumps("data") + result = await backend_with_prefix.get("mykey") + mock_redis_client.get.assert_called_once_with("twinkle:mykey") + assert result == "data" + + +# ---- DELETE ---- + +@pytest.mark.asyncio +async def test_delete(backend_no_prefix, mock_redis_client): + await backend_no_prefix.delete("mykey") + mock_redis_client.delete.assert_called_once_with("mykey") + + +@pytest.mark.asyncio +async def test_delete_with_prefix(backend_with_prefix, mock_redis_client): + await backend_with_prefix.delete("mykey") + mock_redis_client.delete.assert_called_once_with("twinkle:mykey") + + +# ---- EXISTS ---- + +@pytest.mark.asyncio +async def test_exists_true(backend_no_prefix, mock_redis_client): + mock_redis_client.exists.return_value = 1 + result = await backend_no_prefix.exists("mykey") + assert result is True + mock_redis_client.exists.assert_called_once_with("mykey") + + +@pytest.mark.asyncio +async def test_exists_false(backend_no_prefix, mock_redis_client): + mock_redis_client.exists.return_value = 0 + result = await backend_no_prefix.exists("mykey") + assert result is False + + +# ---- KEYS ---- + +@pytest.mark.asyncio +async def test_keys_pattern(backend_no_prefix, mock_redis_client): + mock_redis_client.keys.return_value = ["session::a", "session::b"] + result = await backend_no_prefix.keys("session::*") + mock_redis_client.keys.assert_called_once_with("session::*") + assert result == ["session::a", "session::b"] + + +@pytest.mark.asyncio +async def test_keys_with_prefix(backend_with_prefix, mock_redis_client): + mock_redis_client.keys.return_value = ["twinkle:session::a", "twinkle:session::b"] + result = await backend_with_prefix.keys("session::*") + mock_redis_client.keys.assert_called_once_with("twinkle:session::*") + # Result should have prefix stripped + assert result == ["session::a", "session::b"] + + +# ---- COUNT ---- + +@pytest.mark.asyncio +async def test_count(backend_no_prefix, mock_redis_client): + mock_redis_client.keys.return_value = ["a", "b", "c"] + result = await backend_no_prefix.count("*") + assert result == 3 + + +# ---- SET_NX ---- + +@pytest.mark.asyncio +async def test_set_nx_success(backend_no_prefix, mock_redis_client): + mock_redis_client.set.return_value = True # nx succeeded + result = await backend_no_prefix.set_nx("newkey", {"value": 1}) + mock_redis_client.set.assert_called_once_with("newkey", json.dumps({"value": 1}), nx=True) + assert result is True + + +@pytest.mark.asyncio +async def test_set_nx_failure(backend_no_prefix, mock_redis_client): + mock_redis_client.set.return_value = None # nx failed (key exists) + result = await backend_no_prefix.set_nx("existing", "val") + assert result is False + + +# ---- HEALTH CHECK ---- + +@pytest.mark.asyncio +async def test_health_check_healthy(backend_no_prefix, mock_redis_client): + mock_redis_client.ping.return_value = True + result = await backend_no_prefix.health_check() + assert result is True + + +@pytest.mark.asyncio +async def test_health_check_unhealthy(backend_no_prefix, mock_redis_client): + mock_redis_client.ping.side_effect = ConnectionError("offline") + result = await backend_no_prefix.health_check() + assert result is False + + +# ---- CLOSE ---- + +@pytest.mark.asyncio +async def test_close(backend_no_prefix, mock_redis_client): + await backend_no_prefix.close() + mock_redis_client.aclose.assert_called_once() + + +# ---- JSON Serialization ---- + +@pytest.mark.asyncio +async def test_json_serialization_complex_types(backend_no_prefix, mock_redis_client): + """Values should be JSON-serialized before storage.""" + complex_value = {"list": [1, 2, 3], "nested": {"key": "val"}, "null": None} + await backend_no_prefix.set("complex", complex_value) + expected_json = json.dumps(complex_value) + mock_redis_client.set.assert_called_once_with("complex", expected_json) From b8a27b3e4d25951ca03a081166032888615beeda Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 29 May 2026 17:09:31 +0800 Subject: [PATCH 05/34] feat(server): integrate telemetry into Ray Serve workers and fix persistence config parsing --- .../server/transformer/server_config.yaml | 9 ++++ .../twinkle/self_host/self_cognition.py | 14 ++++- src/twinkle/server/exceptions.py | 8 +-- src/twinkle/server/gateway/server.py | 12 ++++- src/twinkle/server/launcher.py | 52 +++++++++++++++++++ src/twinkle/server/model/app.py | 12 +++++ src/twinkle/server/processor/app.py | 12 +++++ src/twinkle/server/sampler/app.py | 12 +++++ src/twinkle/server/state/backend/base.py | 22 ++++---- .../server/state/backend/file_backend.py | 40 +++++++------- .../server/state/backend/memory_backend.py | 28 +++++----- .../server/state/backend/redis_backend.py | 30 +++++------ src/twinkle/server/state/server_state.py | 4 ++ src/twinkle/server/telemetry/__init__.py | 2 + src/twinkle/server/telemetry/metrics.py | 14 ++--- src/twinkle/server/telemetry/provider.py | 31 ++++++++++- src/twinkle/server/telemetry/tracing.py | 8 +-- src/twinkle/server/telemetry/worker_init.py | 49 +++++++++++++++++ 18 files changed, 280 insertions(+), 79 deletions(-) create mode 100644 src/twinkle/server/telemetry/worker_init.py diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 0a97a8896..1c12415c1 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -9,6 +9,12 @@ http_options: host: 0.0.0.0 # Listen on all network interfaces port: 8000 # Port number for the server +# Telemetry configuration for observability (OpenTelemetry-based) +telemetry_config: + enabled: true + debug: true + service_name: twinkle-server + # Applications: each entry defines a service component deployed on the server applications: @@ -20,6 +26,9 @@ applications: args: server_config: per_token_model_limit: 3 # Maximum number of models (adapters) per token (server-globally enforced) + persistence_config: + mode: file + file_path: /tmp/twinkle_state.json supported_models: - Qwen/Qwen3.5-4B deployments: diff --git a/cookbook/client/twinkle/self_host/self_cognition.py b/cookbook/client/twinkle/self_host/self_cognition.py index c5b771aac..997a77322 100644 --- a/cookbook/client/twinkle/self_host/self_cognition.py +++ b/cookbook/client/twinkle/self_host/self_cognition.py @@ -24,7 +24,7 @@ base_model = 'Qwen/Qwen3.5-4B' base_url = 'http://localhost:8000' api_key = 'EMPTY_API_KEY' -save_dir = '/model' +save_dir = '/tmp/twinkle_sft_output' # Step 2: Initialize the Twinkle client to communicate with the remote server. @@ -108,8 +108,10 @@ def train(): start_step = progress['cur_step'] # Step 7: Run the training loop + max_steps = 10 # Limit to 10 steps for quick verification logger.info(model.get_train_configs().model_dump()) + global_step = 0 for epoch in range(3): logger.info(f'Starting epoch {epoch}') for cur_step, batch in enumerate(dataloader, start=start_step + 1): @@ -128,12 +130,22 @@ def train(): # # Advance the learning rate scheduler by one step # model.lr_step() + global_step += 1 + # Log the loss every 2 steps (aligned with gradient accumulation) if cur_step % 2 == 0: # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric.result}') + # Stop after max_steps + if global_step >= max_steps: + logger.info(f'Reached max_steps={max_steps}, stopping training.') + break + + if global_step >= max_steps: + break + # Step 8: Save the trained checkpoint twinkle_path = model.save( name=f'twinkle-epoch-{epoch}', diff --git a/src/twinkle/server/exceptions.py b/src/twinkle/server/exceptions.py index f96d42500..649c40b59 100644 --- a/src/twinkle/server/exceptions.py +++ b/src/twinkle/server/exceptions.py @@ -4,20 +4,20 @@ class TwinkleServerError(Exception): - """所有 Twinkle Server 异常的基类""" + """Base class for all Twinkle Server exceptions.""" pass class StateBackendError(TwinkleServerError): - """状态后端操作失败(连接断开、超时、数据序列化错误等)""" + """State backend operation failed (connection lost, timeout, data serialization error, etc.).""" pass class ConfigMismatchError(TwinkleServerError): - """配置签名不匹配 — 重启后检测到配置变更,持久化数据可能与当前配置不兼容""" + """Configuration signature mismatch — config changes detected after restart, persisted data may be incompatible with current configuration.""" pass class ResourceExhaustedError(TwinkleServerError): - """资源耗尽 — 队列满、内存不足、连接池耗尽等""" + """Resource exhausted — queue full, insufficient memory, connection pool exhausted, etc.""" pass diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index 92897af6e..45eace9fa 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -90,7 +90,6 @@ def build_server_app(deploy_options: dict[str, Any], Returns: Configured Ray Serve deployment bound with options """ - def get_self() -> GatewayServer: return serve.get_replica_context().servable_object @@ -104,6 +103,17 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) + @app.on_event('startup') + async def _init_telemetry_and_instrument(): + """Initialize telemetry and instrument app in worker process (after deserialization).""" + from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized + ensure_telemetry_initialized() + try: + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + FastAPIInstrumentor.instrument_app(app) + except ImportError: + pass # OTEL instrumentation not installed + @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index e2a6179ac..ed4880ffb 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -69,6 +69,24 @@ def __init__( self._ray_initialized = False self._serve_started = False + # Telemetry env var keys that need to be propagated to Ray worker processes + _TELEMETRY_ENV_KEYS: tuple[str, ...] = ( + 'TWINKLE_TELEMETRY_ENABLED', + 'TWINKLE_TELEMETRY_DEBUG', + 'TWINKLE_TELEMETRY_SERVICE', + 'TWINKLE_TELEMETRY_ENDPOINT', + 'TWINKLE_TELEMETRY_INTERVAL', + ) + + def _build_telemetry_env_vars(self) -> dict[str, str]: + """Collect telemetry env vars from os.environ for propagation to Ray workers. + + These vars are read by ``ensure_telemetry_initialized()`` inside the + FastAPI startup hook running in each worker process. + """ + import os + return {k: os.environ[k] for k in self._TELEMETRY_ENV_KEYS if k in os.environ} + def _get_builders(self) -> dict[str, Callable]: """Get the builder functions for all app types.""" if self._builders: @@ -129,6 +147,12 @@ def _init_ray(self) -> None: # Use runtime_env to apply patches in worker processes # This is required because Ray Serve's ProxyActor runs in separate processes runtime_env = get_runtime_env_for_patches() + # Propagate telemetry env vars to all Ray workers spawned in this job + telemetry_env_vars = self._build_telemetry_env_vars() + if telemetry_env_vars: + merged_env_vars = dict(runtime_env.get('env_vars') or {}) + merged_env_vars.update(telemetry_env_vars) + runtime_env['env_vars'] = merged_env_vars # Connect to existing cluster if available, otherwise start local instance ray.init( address='auto', @@ -193,6 +217,20 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: if isinstance(deploy_config, dict): deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'} + # Inject telemetry env vars into the deployment's runtime_env so that + # Ray Serve replicas (worker processes) can initialize telemetry. + # User-specified env_vars take precedence to avoid overriding existing config. + telemetry_env_vars = self._build_telemetry_env_vars() + if telemetry_env_vars: + ray_actor_options = dict(deploy_options.get('ray_actor_options') or {}) + runtime_env = dict(ray_actor_options.get('runtime_env') or {}) + env_vars = dict(runtime_env.get('env_vars') or {}) + for k, v in telemetry_env_vars.items(): + env_vars.setdefault(k, v) + runtime_env['env_vars'] = env_vars + ray_actor_options['runtime_env'] = runtime_env + deploy_options['ray_actor_options'] = ray_actor_options + # Pass http_options to server apps for internal proxy routing http_options = self.config.get('http_options', {}) if import_path == 'server' and http_options: @@ -213,6 +251,20 @@ def launch(self) -> None: # Apply Ray Serve patches before initializing Ray apply_ray_serve_patches() + # Initialize telemetry if configured + telemetry_config = self.config.get('telemetry_config', {}) + if telemetry_config: + from twinkle.server.telemetry import TelemetryConfig, init_telemetry + config = TelemetryConfig(**telemetry_config) + init_telemetry(config) + # Export config to env vars for Ray worker processes + import os + os.environ['TWINKLE_TELEMETRY_ENABLED'] = '1' + os.environ['TWINKLE_TELEMETRY_DEBUG'] = '1' if config.debug else '0' + os.environ['TWINKLE_TELEMETRY_SERVICE'] = config.service_name + os.environ['TWINKLE_TELEMETRY_ENDPOINT'] = config.otlp_endpoint + os.environ['TWINKLE_TELEMETRY_INTERVAL'] = str(config.export_interval_ms) + self._init_ray() self._start_serve() diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index e8da0078e..5c01b7f71 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -158,6 +158,7 @@ def build_model_app(model_id: str, # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). + def get_self() -> ModelManagement: return serve.get_replica_context().servable_object @@ -175,6 +176,17 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) + @app.on_event('startup') + async def _init_telemetry_and_instrument(): + """Initialize telemetry and instrument app in worker process (after deserialization).""" + from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized + ensure_telemetry_initialized() + try: + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + FastAPIInstrumentor.instrument_app(app) + except ImportError: + pass # OTEL instrumentation not installed + @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 1db7c59dc..743eb7517 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -119,8 +119,20 @@ def build_processor_app(ncpu_proc_per_node: int, """ # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). + app = FastAPI() + @app.on_event('startup') + async def _init_telemetry_and_instrument(): + """Initialize telemetry and instrument app in worker process (after deserialization).""" + from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized + ensure_telemetry_initialized() + try: + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + FastAPIInstrumentor.instrument_app(app) + except ImportError: + pass # OTEL instrumentation not installed + @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index c1a283bd6..be77754ef 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -117,11 +117,23 @@ def build_sampler_app(model_id: str, """ # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). + app = FastAPI( title='Unified Sampler', description='REST API for distributed text generation inference (Tinker + Twinkle)', version='1.0.0') + @app.on_event('startup') + async def _init_telemetry_and_instrument(): + """Initialize telemetry and instrument app in worker process (after deserialization).""" + from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized + ensure_telemetry_initialized() + try: + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + FastAPIInstrumentor.instrument_app(app) + except ImportError: + pass # OTEL instrumentation not installed + @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) diff --git a/src/twinkle/server/state/backend/base.py b/src/twinkle/server/state/backend/base.py index 6a44d7944..2a247ed22 100644 --- a/src/twinkle/server/state/backend/base.py +++ b/src/twinkle/server/state/backend/base.py @@ -5,52 +5,52 @@ class StateBackend(ABC): - """状态存储后端的统一接口。 + """Unified interface for state storage backends. - 所有状态管理操作通过此接口进行,支持多种后端实现(内存、文件、Redis)。 + All state management operations go through this interface, supporting multiple backend implementations (memory, file, Redis). """ @abstractmethod async def set(self, key: str, value: Any, ttl: int | None = None) -> None: - """存储键值对,可选 TTL(秒)""" + """Store key-value pair with optional TTL in seconds.""" ... @abstractmethod async def get(self, key: str) -> Any | None: - """获取值,不存在或已过期返回 None""" + """Retrieve value, return None if not found or expired.""" ... @abstractmethod async def delete(self, key: str) -> None: - """删除键,不存在时静默忽略""" + """Delete key, silently ignore if not found.""" ... @abstractmethod async def exists(self, key: str) -> bool: - """检查键是否存在且未过期""" + """Check if key exists and is not expired.""" ... @abstractmethod async def keys(self, pattern: str) -> list[str]: - """按模式匹配返回所有键名。pattern 支持 * 通配符(如 'session::*')""" + """Return all key names matching the pattern. Supports * wildcard (e.g. 'session::*').""" ... @abstractmethod async def count(self, pattern: str) -> int: - """按模式匹配计数""" + """Count keys matching the pattern.""" ... @abstractmethod async def set_nx(self, key: str, value: Any) -> bool: - """Set if not exists. 返回 True 如果成功设置,False 如果键已存在""" + """Set if not exists. Return True if successfully set, False if key already exists.""" ... @abstractmethod async def close(self) -> None: - """关闭后端连接/释放资源""" + """Close backend connection / release resources.""" ... @abstractmethod async def health_check(self) -> bool: - """检查后端是否健康可用""" + """Check if backend is healthy and available.""" ... diff --git a/src/twinkle/server/state/backend/file_backend.py b/src/twinkle/server/state/backend/file_backend.py index 0ddb80324..5e306832e 100644 --- a/src/twinkle/server/state/backend/file_backend.py +++ b/src/twinkle/server/state/backend/file_backend.py @@ -13,11 +13,11 @@ class FileBackend(StateBackend): - """基于本地 JSON 文件的持久化状态后端实现。 + """Local JSON file-based persistent state backend implementation. - 存储格式为单个 JSON 文件:``{key: {"value": ..., "expire_at": float|null}}``。 - 文件读写通过 ``asyncio.to_thread`` 包装,避免阻塞事件循环。 - 写入使用临时文件 + ``os.replace`` 原子替换,并通过 ``fcntl.flock`` 防止多进程并发写入。 + Storage format is a single JSON file: ``{key: {"value": ..., "expire_at": float|null}}``. + File I/O is wrapped with ``asyncio.to_thread`` to avoid blocking the event loop. + Writes use temp file + ``os.replace`` for atomic replacement, protected by ``fcntl.flock`` against multi-process concurrent writes. """ def __init__(self, file_path: str) -> None: @@ -25,7 +25,7 @@ def __init__(self, file_path: str) -> None: self._init_file() def _init_file(self) -> None: - """如果文件或目录不存在,自动创建。""" + """Auto-create file or directory if not exists.""" dir_path = os.path.dirname(self._file_path) if dir_path and not os.path.exists(dir_path): os.makedirs(dir_path, exist_ok=True) @@ -34,7 +34,7 @@ def _init_file(self) -> None: json.dump({}, f) def _load_sync(self) -> dict[str, dict[str, Any]]: - """同步读取 JSON 文件,返回完整 data dict。""" + """Synchronously read JSON file, return complete data dict.""" try: with open(self._file_path, 'r', encoding='utf-8') as f: data = json.load(f) @@ -43,8 +43,8 @@ def _load_sync(self) -> dict[str, dict[str, Any]]: return data def _save_sync(self, data: dict[str, dict[str, Any]]) -> None: - """同步写入:清理过期键 → 写临时文件 → flock → os.replace。""" - # 写入前清理过期键 + """Synchronous write: clean expired keys -> write temp file -> flock -> os.replace.""" + # Clean expired keys before writing now = time.time() data = { k: v for k, v in data.items() @@ -65,13 +65,13 @@ def _save_sync(self, data: dict[str, dict[str, Any]]) -> None: os.fsync(fd.fileno()) fd.close() - # 对临时文件加排他锁后原子替换 + # Apply exclusive lock to temp file then atomic replace with open(fd.name, 'r') as lock_f: fcntl.flock(lock_f.fileno(), fcntl.LOCK_EX) os.replace(fd.name, self._file_path) fcntl.flock(lock_f.fileno(), fcntl.LOCK_UN) except BaseException: - # 清理临时文件 + # Clean up temp file if os.path.exists(fd.name): os.unlink(fd.name) raise @@ -83,19 +83,19 @@ async def _save(self, data: dict[str, dict[str, Any]]) -> None: await asyncio.to_thread(self._save_sync, data) def _is_expired(self, entry: dict[str, Any]) -> bool: - """检查条目是否已过期。""" + """Check if entry is expired.""" expire_at = entry.get('expire_at') return expire_at is not None and time.time() >= expire_at async def set(self, key: str, value: Any, ttl: int | None = None) -> None: - """存储键值对,可选 TTL(秒)""" + """Store key-value pair with optional TTL in seconds.""" expire_at = (time.time() + ttl) if ttl is not None else None data = await self._load() data[key] = {'value': value, 'expire_at': expire_at} await self._save(data) async def get(self, key: str) -> Any | None: - """获取值,不存在或已过期返回 None""" + """Retrieve value, return None if not found or expired.""" data = await self._load() entry = data.get(key) if entry is None: @@ -107,14 +107,14 @@ async def get(self, key: str) -> Any | None: return entry['value'] async def delete(self, key: str) -> None: - """删除键,不存在时静默忽略""" + """Delete key, silently ignore if not found.""" data = await self._load() if key in data: del data[key] await self._save(data) async def exists(self, key: str) -> bool: - """检查键是否存在且未过期""" + """Check if key exists and is not expired.""" data = await self._load() entry = data.get(key) if entry is None: @@ -126,7 +126,7 @@ async def exists(self, key: str) -> bool: return True async def keys(self, pattern: str) -> list[str]: - """按模式匹配返回所有键名。pattern 支持 * 通配符。""" + """Return all key names matching the pattern. Supports * wildcard.""" data = await self._load() result: list[str] = [] expired_keys: list[str] = [] @@ -143,11 +143,11 @@ async def keys(self, pattern: str) -> list[str]: return result async def count(self, pattern: str) -> int: - """按模式匹配计数""" + """Count keys matching the pattern.""" return len(await self.keys(pattern)) async def set_nx(self, key: str, value: Any) -> bool: - """Set if not exists. 返回 True 如果成功设置,False 如果键已存在且未过期。""" + """Set if not exists. Return True if successfully set, False if key already exists and is not expired.""" data = await self._load() entry = data.get(key) if entry is not None and not self._is_expired(entry): @@ -157,11 +157,11 @@ async def set_nx(self, key: str, value: Any) -> bool: return True async def close(self) -> None: - """关闭后端,文件后端无需持久连接,no-op。""" + """Close backend. File backend requires no persistent connection, no-op.""" pass async def health_check(self) -> bool: - """检查文件路径是否可写""" + """Check if file path is writable.""" try: return os.access(self._file_path, os.W_OK) except OSError: diff --git a/src/twinkle/server/state/backend/memory_backend.py b/src/twinkle/server/state/backend/memory_backend.py index a960f1e78..3bbde4d15 100644 --- a/src/twinkle/server/state/backend/memory_backend.py +++ b/src/twinkle/server/state/backend/memory_backend.py @@ -8,17 +8,17 @@ class MemoryBackend(StateBackend): - """基于内存字典的状态后端实现。 + """In-memory dictionary-based state backend implementation. - 使用 ``dict[str, tuple[Any, float | None]]`` 存储 (value, expire_at)。 - 过期检查在 get/exists 时进行(惰性过期),适用于 Ray Actor 单线程模型。 + Uses ``dict[str, tuple[Any, float | None]]`` to store (value, expire_at). + Expiration is checked lazily during get/exists, suitable for Ray Actor single-threaded model. """ def __init__(self) -> None: self._store: dict[str, tuple[Any, float | None]] = {} def _is_expired(self, key: str) -> bool: - """检查键是否已过期。如已过期则删除并返回 True。""" + """Check if key is expired. If expired, delete and return True.""" entry = self._store.get(key) if entry is None: return True @@ -29,29 +29,29 @@ def _is_expired(self, key: str) -> bool: return False async def set(self, key: str, value: Any, ttl: int | None = None) -> None: - """存储键值对,可选 TTL(秒)""" + """Store key-value pair with optional TTL in seconds.""" expire_at = (time.time() + ttl) if ttl is not None else None self._store[key] = (value, expire_at) async def get(self, key: str) -> Any | None: - """获取值,不存在或已过期返回 None""" + """Retrieve value, return None if not found or expired.""" if self._is_expired(key): return None value, _ = self._store[key] return value async def delete(self, key: str) -> None: - """删除键,不存在时静默忽略""" + """Delete key, silently ignore if not found.""" self._store.pop(key, None) async def exists(self, key: str) -> bool: - """检查键是否存在且未过期""" + """Check if key exists and is not expired.""" return not self._is_expired(key) async def keys(self, pattern: str) -> list[str]: - """按模式匹配返回所有键名。pattern 支持 * 通配符。""" + """Return all key names matching the pattern. Supports * wildcard.""" result: list[str] = [] - # 遍历时收集过期键,避免在迭代中修改字典 + # Collect expired keys during iteration to avoid modifying dict while iterating expired_keys: list[str] = [] for key, (_, expire_at) in self._store.items(): if expire_at is not None and time.time() >= expire_at: @@ -64,20 +64,20 @@ async def keys(self, pattern: str) -> list[str]: return result async def count(self, pattern: str) -> int: - """按模式匹配计数""" + """Count keys matching the pattern.""" return len(await self.keys(pattern)) async def set_nx(self, key: str, value: Any) -> bool: - """Set if not exists. 返回 True 如果成功设置,False 如果键已存在。""" + """Set if not exists. Return True if successfully set, False if key already exists.""" if not self._is_expired(key): return False self._store[key] = (value, None) return True async def close(self) -> None: - """关闭后端,清空存储""" + """Close backend, clear storage.""" self._store.clear() async def health_check(self) -> bool: - """检查后端是否健康可用,内存后端始终返回 True""" + """Check if backend is healthy and available. Memory backend always returns True.""" return True diff --git a/src/twinkle/server/state/backend/redis_backend.py b/src/twinkle/server/state/backend/redis_backend.py index bcdc1a136..a571f8421 100644 --- a/src/twinkle/server/state/backend/redis_backend.py +++ b/src/twinkle/server/state/backend/redis_backend.py @@ -14,10 +14,10 @@ class RedisBackend(StateBackend): - """基于 Redis 的持久化状态后端实现。 + """Redis-based persistent state backend implementation. - 使用 ``redis.asyncio`` 客户端,值通过 JSON 序列化存储为 Redis string。 - TTL 由 Redis 原生 EXPIRE 机制管理。 + Uses ``redis.asyncio`` client, values are stored as Redis strings via JSON serialization. + TTL is managed by Redis native EXPIRE mechanism. """ def __init__(self, redis_url: str, key_prefix: str = "") -> None: @@ -29,15 +29,15 @@ def __init__(self, redis_url: str, key_prefix: str = "") -> None: self._prefix = key_prefix def _make_key(self, key: str) -> str: - """为 key 添加命名空间前缀。""" + """Add namespace prefix to key.""" return f"{self._prefix}{key}" if self._prefix else key def _strip_prefix(self, key: str) -> str: - """从完整 key 中移除命名空间前缀。""" + """Remove namespace prefix from full key.""" return key[len(self._prefix):] if self._prefix else key async def set(self, key: str, value: Any, ttl: int | None = None) -> None: - """存储键值对,可选 TTL(秒)""" + """Store key-value pair with optional TTL in seconds.""" real_key = self._make_key(key) data = json.dumps(value) if ttl is not None: @@ -46,7 +46,7 @@ async def set(self, key: str, value: Any, ttl: int | None = None) -> None: await self._client.set(real_key, data) async def get(self, key: str) -> Any | None: - """获取值,不存在或已过期返回 None""" + """Retrieve value, return None if not found or expired.""" real_key = self._make_key(key) raw = await self._client.get(real_key) if raw is None: @@ -54,30 +54,30 @@ async def get(self, key: str) -> Any | None: return json.loads(raw) async def delete(self, key: str) -> None: - """删除键,不存在时静默忽略""" + """Delete key, silently ignore if not found.""" real_key = self._make_key(key) await self._client.delete(real_key) async def exists(self, key: str) -> bool: - """检查键是否存在且未过期""" + """Check if key exists and is not expired.""" real_key = self._make_key(key) return bool(await self._client.exists(real_key)) async def keys(self, pattern: str) -> list[str]: - """按模式匹配返回所有键名。pattern 支持 * 通配符。 + """Return all key names matching the pattern. Supports * wildcard. - 注意:生产环境高 key 数量时建议改用 SCAN 以避免阻塞。 + Note: For high key volumes in production, consider using SCAN to avoid blocking. """ real_pattern = self._make_key(pattern) raw_keys = await self._client.keys(real_pattern) return [self._strip_prefix(k) for k in raw_keys] async def count(self, pattern: str) -> int: - """按模式匹配计数""" + """Count keys matching the pattern.""" return len(await self.keys(pattern)) async def set_nx(self, key: str, value: Any, ttl: int | None = None) -> bool: - """Set if not exists. 返回 True 如果成功设置,False 如果键已存在。""" + """Set if not exists. Return True if successfully set, False if key already exists.""" real_key = self._make_key(key) data = json.dumps(value) if ttl is not None: @@ -87,11 +87,11 @@ async def set_nx(self, key: str, value: Any, ttl: int | None = None) -> bool: return result is not None async def close(self) -> None: - """关闭 Redis 连接""" + """Close Redis connection.""" await self._client.aclose() async def health_check(self) -> bool: - """检查 Redis 是否健康可用""" + """Check if Redis is healthy and available.""" try: return await self._client.ping() except Exception: diff --git a/src/twinkle/server/state/server_state.py b/src/twinkle/server/state/server_state.py index 9d96d4f3a..1b1525376 100644 --- a/src/twinkle/server/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -588,6 +588,10 @@ def get_server_state( Returns: A ServerStateProxy for interacting with the actor. """ + # Support passing persistence_config as raw dict from YAML + if isinstance(persistence_config, dict): + persistence_config = PersistenceConfig(**persistence_config) + try: actor = ray.get_actor(actor_name) except ValueError: diff --git a/src/twinkle/server/telemetry/__init__.py b/src/twinkle/server/telemetry/__init__.py index dd2c3bcfc..8da61164f 100644 --- a/src/twinkle/server/telemetry/__init__.py +++ b/src/twinkle/server/telemetry/__init__.py @@ -11,6 +11,7 @@ extract_context, get_current_span, ) +from .worker_init import ensure_telemetry_initialized __all__ = [ "MetricsRegistry", @@ -22,4 +23,5 @@ "inject_context", "extract_context", "get_current_span", + "ensure_telemetry_initialized", ] diff --git a/src/twinkle/server/telemetry/metrics.py b/src/twinkle/server/telemetry/metrics.py index e95b8242d..d9979b8f1 100644 --- a/src/twinkle/server/telemetry/metrics.py +++ b/src/twinkle/server/telemetry/metrics.py @@ -6,9 +6,9 @@ class MetricsRegistry: - """集中声明所有指标。业务代码通过 MetricsRegistry.get() 获取单例使用。 + """Centrally declares all metrics. Business code retrieves singleton via MetricsRegistry.get(). - 当 telemetry 未初始化时,OTEL 返回 NoOp meter,所有记录操作自动静默。 + When telemetry is not initialized, OTEL returns a NoOp meter and all recording operations are silently no-op. """ _instance: MetricsRegistry | None = None @@ -16,7 +16,7 @@ class MetricsRegistry: def __init__(self) -> None: meter = get_meter("twinkle-server") - # === HTTP 请求 === + # === HTTP Requests === self.requests_total = meter.create_counter( "twinkle.http.requests.total", description="Total HTTP requests received", @@ -27,7 +27,7 @@ def __init__(self) -> None: unit="s", ) - # === 任务队列 === + # === Task Queue === self.queue_depth = meter.create_up_down_counter( "twinkle.queue.depth", description="Current task queue depth", @@ -55,7 +55,7 @@ def __init__(self) -> None: description="Number of tokens currently tracked by the rate limiter", ) - # === 资源 === + # === Resources === self.active_sessions = meter.create_up_down_counter( "twinkle.sessions.active", description="Number of active client sessions", @@ -75,12 +75,12 @@ def __init__(self) -> None: @classmethod def get(cls) -> MetricsRegistry: - """获取全局 MetricsRegistry 单例。首次调用时创建。""" + """Retrieve global MetricsRegistry singleton. Created on first call.""" if cls._instance is None: cls._instance = cls() return cls._instance @classmethod def reset(cls) -> None: - """重置单例(用于测试或 telemetry 重新初始化)""" + """Reset singleton (for testing or telemetry re-initialization).""" cls._instance = None diff --git a/src/twinkle/server/telemetry/provider.py b/src/twinkle/server/telemetry/provider.py index 71bcc3734..d029d85c1 100644 --- a/src/twinkle/server/telemetry/provider.py +++ b/src/twinkle/server/telemetry/provider.py @@ -84,6 +84,29 @@ _initialized: bool = False +class _LoggingWriter: + """IO adapter that routes writes to Python logging. + + Used to redirect ConsoleExporter output through the logging system + so that Ray Serve worker logs are properly captured. Without this, + ConsoleSpanExporter / ConsoleMetricExporter write directly to + ``sys.stdout`` which Ray reroutes into its internal log files, + making telemetry output invisible to the standard logging pipeline. + """ + + def __init__(self, logger_name: str = 'twinkle.server.telemetry.export'): + self._logger = logging.getLogger(logger_name) + + def write(self, text: str) -> int: + text = text.strip() + if text: + self._logger.info(text) + return len(text) + + def flush(self) -> None: + pass + + class TelemetryConfig(BaseModel): """Configuration for the OpenTelemetry pipeline.""" @@ -129,9 +152,13 @@ def init_telemetry(config: TelemetryConfig) -> None: "OTLP exporters not available; falling back to console exporters." ) + # When using console exporters, route their output through the Python + # logging system so that Ray Serve workers actually surface the data. + _console_writer = _LoggingWriter() if use_console else None + # ---- Traces --------------------------------------------------------- if use_console: - span_exporter = ConsoleSpanExporter() + span_exporter = ConsoleSpanExporter(out=_console_writer) else: span_exporter = OTLPSpanExporter(endpoint=config.otlp_endpoint) @@ -142,7 +169,7 @@ def init_telemetry(config: TelemetryConfig) -> None: # ---- Metrics -------------------------------------------------------- if use_console: - metric_exporter = ConsoleMetricExporter() + metric_exporter = ConsoleMetricExporter(out=_console_writer) else: metric_exporter = OTLPMetricExporter(endpoint=config.otlp_endpoint) diff --git a/src/twinkle/server/telemetry/tracing.py b/src/twinkle/server/telemetry/tracing.py index 9982af3cb..8cc36f04b 100644 --- a/src/twinkle/server/telemetry/tracing.py +++ b/src/twinkle/server/telemetry/tracing.py @@ -12,28 +12,28 @@ def get_tracer(name: str = "twinkle-server"): - """获取 tracer 实例。OTEL 未安装时返回 NoOp tracer。""" + """Retrieve tracer instance. Returns NoOp tracer when OTEL is not installed.""" if not _OTEL_AVAILABLE: return _NoopTracer() return trace.get_tracer(name) def inject_context(carrier: dict) -> None: - """将当前 trace context 注入到 carrier。OTEL 未安装时为 noop。""" + """Inject current trace context into carrier. Noop when OTEL is not installed.""" if not _OTEL_AVAILABLE: return inject(carrier) def extract_context(carrier: dict): - """从 carrier 中提取 trace context。OTEL 未安装时返回空 context。""" + """Extract trace context from carrier. Returns empty context when OTEL is not installed.""" if not _OTEL_AVAILABLE: return None return extract(carrier) def get_current_span(): - """获取当前活跃的 span。OTEL 未安装时返回 noop span。""" + """Get current active span. Returns noop span when OTEL is not installed.""" if not _OTEL_AVAILABLE: return _NoopSpan() return trace.get_current_span() diff --git a/src/twinkle/server/telemetry/worker_init.py b/src/twinkle/server/telemetry/worker_init.py new file mode 100644 index 000000000..e07c7562b --- /dev/null +++ b/src/twinkle/server/telemetry/worker_init.py @@ -0,0 +1,49 @@ +"""Telemetry initialization for Ray worker processes. + +Ray Serve deployments run in separate worker processes that do not inherit +the driver process's OTEL global state. This module provides a function +to re-initialize telemetry in each worker process using environment variables +set by the launcher. +""" +from __future__ import annotations + +import logging +import os + +logger = logging.getLogger(__name__) + +_worker_initialized = False + + +def ensure_telemetry_initialized() -> None: + """Initialize telemetry in the current worker process if not already done. + + Reads configuration from environment variables set by the launcher process. + Safe to call multiple times - only initializes once per process. + """ + global _worker_initialized + if _worker_initialized: + return + + _worker_initialized = True + + if os.environ.get('TWINKLE_TELEMETRY_ENABLED') != '1': + return + + try: + from twinkle.server.telemetry import TelemetryConfig, init_telemetry + from twinkle.server.telemetry.metrics import MetricsRegistry + + config = TelemetryConfig( + enabled=True, + debug=os.environ.get('TWINKLE_TELEMETRY_DEBUG', '0') == '1', + service_name=os.environ.get('TWINKLE_TELEMETRY_SERVICE', 'twinkle-server'), + otlp_endpoint=os.environ.get('TWINKLE_TELEMETRY_ENDPOINT', 'http://localhost:4317'), + export_interval_ms=int(os.environ.get('TWINKLE_TELEMETRY_INTERVAL', '30000')), + ) + init_telemetry(config) + # Reset MetricsRegistry singleton so it picks up the real MeterProvider + MetricsRegistry.reset() + logger.info(f'Worker telemetry initialized (service={config.service_name}, debug={config.debug})') + except Exception as e: + logger.warning(f'Failed to initialize worker telemetry: {e}') From 0d021c46f101f8ade9eb3f78124d71fb18719810 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 29 May 2026 18:29:16 +0800 Subject: [PATCH 06/34] fix(server): replace FastAPIInstrumentor with custom tracing middleware for Ray Serve compatibility --- src/twinkle/server/gateway/server.py | 16 +++---- src/twinkle/server/model/app.py | 16 +++---- src/twinkle/server/processor/app.py | 19 ++++---- src/twinkle/server/sampler/app.py | 24 +++++----- src/twinkle/server/telemetry/tracing.py | 59 +++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 44 deletions(-) diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index 45eace9fa..fb6a2a48f 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -14,6 +14,7 @@ from typing import Any import twinkle_client.types as types +from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.metrics import create_metrics_middleware from twinkle.server.state import get_server_state from twinkle.server.utils.validation import verify_request_token @@ -95,6 +96,9 @@ def get_self() -> GatewayServer: @asynccontextmanager async def lifespan(app: FastAPI): + # Initialize telemetry in worker process (after deserialization) + from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized + ensure_telemetry_initialized() yield try: await get_self().proxy.close() @@ -103,22 +107,12 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) - @app.on_event('startup') - async def _init_telemetry_and_instrument(): - """Initialize telemetry and instrument app in worker process (after deserialization).""" - from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized - ensure_telemetry_initialized() - try: - from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor - FastAPIInstrumentor.instrument_app(app) - except ImportError: - pass # OTEL instrumentation not installed - @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) app.middleware('http')(create_metrics_middleware('Gateway')) + app.middleware('http')(create_tracing_middleware('Gateway')) _register_tinker_routes(app, get_self) _register_twinkle_routes(app, get_self) diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 5c01b7f71..349538678 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -16,6 +16,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh from twinkle.server.utils.lifecycle import AdapterManagerMixin +from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.metrics import create_metrics_middleware from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin @@ -164,6 +165,9 @@ def get_self() -> ModelManagement: @asynccontextmanager async def lifespan(app: FastAPI): + # Initialize telemetry in worker process (after deserialization) + from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized + ensure_telemetry_initialized() try: await get_self()._ensure_replica_registered() except Exception as e: @@ -176,22 +180,12 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) - @app.on_event('startup') - async def _init_telemetry_and_instrument(): - """Initialize telemetry and instrument app in worker process (after deserialization).""" - from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized - ensure_telemetry_initialized() - try: - from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor - FastAPIInstrumentor.instrument_app(app) - except ImportError: - pass # OTEL instrumentation not installed - @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) app.middleware('http')(create_metrics_middleware('Model')) + app.middleware('http')(create_tracing_middleware('Model')) _register_tinker_routes(app, get_self) _register_twinkle_routes(app, get_self) diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 743eb7517..1f1e827bf 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -14,6 +14,7 @@ from __future__ import annotations import os +from contextlib import asynccontextmanager from fastapi import FastAPI, Request from ray import serve from typing import Any, Dict, Optional @@ -21,6 +22,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_logger from twinkle.server.utils.lifecycle import ProcessorManagerMixin +from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.metrics import create_metrics_middleware from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.utils.validation import verify_request_token @@ -120,24 +122,21 @@ def build_processor_app(ncpu_proc_per_node: int, # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). - app = FastAPI() - - @app.on_event('startup') - async def _init_telemetry_and_instrument(): - """Initialize telemetry and instrument app in worker process (after deserialization).""" + @asynccontextmanager + async def lifespan(app: FastAPI): + # Initialize telemetry in worker process (after deserialization) from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized ensure_telemetry_initialized() - try: - from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor - FastAPIInstrumentor.instrument_app(app) - except ImportError: - pass # OTEL instrumentation not installed + yield + + app = FastAPI(lifespan=lifespan) @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) app.middleware('http')(create_metrics_middleware('Processor')) + app.middleware('http')(create_tracing_middleware('Processor')) def get_self() -> ProcessorManagement: return serve.get_replica_context().servable_object diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index be77754ef..d8368399b 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -7,12 +7,14 @@ """ from __future__ import annotations +from contextlib import asynccontextmanager from fastapi import FastAPI, Request from ray import serve from typing import Any, Dict, Optional import twinkle from twinkle import DeviceGroup, DeviceMesh +from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.metrics import create_metrics_middleware from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin @@ -118,27 +120,25 @@ def build_sampler_app(model_id: str, # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). + @asynccontextmanager + async def lifespan(app: FastAPI): + # Initialize telemetry in worker process (after deserialization) + from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized + ensure_telemetry_initialized() + yield + app = FastAPI( title='Unified Sampler', description='REST API for distributed text generation inference (Tinker + Twinkle)', - version='1.0.0') - - @app.on_event('startup') - async def _init_telemetry_and_instrument(): - """Initialize telemetry and instrument app in worker process (after deserialization).""" - from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized - ensure_telemetry_initialized() - try: - from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor - FastAPIInstrumentor.instrument_app(app) - except ImportError: - pass # OTEL instrumentation not installed + version='1.0.0', + lifespan=lifespan) @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) app.middleware('http')(create_metrics_middleware('Sampler')) + app.middleware('http')(create_tracing_middleware('Sampler')) def get_self() -> SamplerManagement: return serve.get_replica_context().servable_object diff --git a/src/twinkle/server/telemetry/tracing.py b/src/twinkle/server/telemetry/tracing.py index 8cc36f04b..f10d64730 100644 --- a/src/twinkle/server/telemetry/tracing.py +++ b/src/twinkle/server/telemetry/tracing.py @@ -2,6 +2,8 @@ from __future__ import annotations +from fastapi import Request + try: from opentelemetry import trace from opentelemetry.propagate import inject, extract @@ -55,3 +57,60 @@ def start_as_current_span(self, name, **kwargs): return _NoopSpan() def start_span(self, name, **kwargs): return _NoopSpan() + + +def create_tracing_middleware(service_component: str): + """Create an HTTP tracing middleware that lazily acquires tracer at request time. + + This approach is compatible with Ray Serve's pickle serialization, unlike + ``FastAPIInstrumentor.instrument_app`` which attaches unpicklable references + (e.g. ``_thread.lock``) to the FastAPI app and breaks deployment pickling. + + The returned middleware is a plain async function with no captured state + other than the ``service_component`` string, so it can be safely pickled + along with the app when registered via ``app.middleware('http')`` inside a + Ray Serve build function. + + Args: + service_component: Logical service name used as the tracer name suffix + and recorded as a span attribute (e.g. ``Gateway``, ``Model``, + ``Processor``, ``Sampler``). + + Returns: + An async FastAPI HTTP middleware function. + """ + + async def tracing_middleware(request: Request, call_next): + # Lazy import to avoid holding any unpicklable module-level references + # at the time Ray Serve serializes the build function / app. + from opentelemetry import trace + + tracer = trace.get_tracer(f'twinkle.server.{service_component}') + + method = request.method + path = request.url.path + span_name = f'{method} {path}' + + with tracer.start_as_current_span( + span_name, + kind=trace.SpanKind.SERVER, + attributes={ + 'http.method': method, + 'http.url': str(request.url), + 'http.route': path, + 'http.scheme': request.url.scheme, + 'service.component': service_component, + }, + ) as span: + try: + response = await call_next(request) + span.set_attribute('http.status_code', response.status_code) + if response.status_code >= 400: + span.set_status(trace.Status(trace.StatusCode.ERROR)) + return response + except Exception as exc: + span.set_status(trace.Status(trace.StatusCode.ERROR, str(exc))) + span.record_exception(exc) + raise + + return tracing_middleware From 2d783f45cd139d4275f5a6df88856ca3b9c412cc Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Mon, 1 Jun 2026 11:38:36 +0800 Subject: [PATCH 07/34] update dockerfile --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 8107ebcb3..668b339e0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ RUN pip install flash-linear-attention -U --no-cache-dir RUN pip install numpy==2.2 --no-cache-dir # Install tinker, ray, and other deps -RUN pip install --no-cache-dir tinker==0.16.1 "ray[serve]" transformers peft<=0.18 accelerate -U +RUN pip install --no-cache-dir tinker==0.16.1 "ray[serve]" transformers peft<=0.18 accelerate redis opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp -U # Clone and install twinkle, checkout to latest v-tag RUN git clone https://github.com/modelscope/twinkle.git From e75deeb7401526d768228817caefdb030e32769a Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Mon, 1 Jun 2026 14:27:00 +0800 Subject: [PATCH 08/34] fix(server): harden telemetry/persistence wiring and middleware order - tracing middleware: return passthrough when OpenTelemetry SDK is absent instead of crashing every request on lazy import inside the handler - persistence_config: propagate via TWINKLE_PERSISTENCE_* env vars from the launcher to all Ray workers, so the configured backend is used regardless of which deployment initializes ServerState first; lift the example to top-level YAML - middleware order: register metrics last in all four apps so it wraps the outermost layer and captures full end-to-end latency including tracing - example yaml: telemetry default to enabled=false (optional dependency), document how to opt in --- .../client/server/megatron/server_config.yaml | 24 +++++++++ .../server/transformer/server_config.yaml | 27 +++++++--- src/twinkle/server/gateway/server.py | 6 ++- src/twinkle/server/launcher.py | 52 +++++++++++++++---- src/twinkle/server/model/app.py | 5 +- src/twinkle/server/processor/app.py | 5 +- src/twinkle/server/sampler/app.py | 5 +- src/twinkle/server/state/backend/factory.py | 40 ++++++++++++++ src/twinkle/server/state/server_state.py | 6 +++ src/twinkle/server/telemetry/tracing.py | 23 ++++---- src/twinkle/server/utils/metrics.py | 14 +++-- 11 files changed, 170 insertions(+), 37 deletions(-) diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index 696200200..6c9337402 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -9,6 +9,30 @@ http_options: host: 0.0.0.0 # Listen on all network interfaces port: 9000 # Port number for the server +# Telemetry configuration for observability (OpenTelemetry-based). +# Disabled by default — opentelemetry-* packages are optional dependencies. +# To enable: +# 1. pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp +# 2. set enabled: true (and debug: true to dump exporters to console for local dev, +# or leave debug: false and point otlp_endpoint at an OTLP collector — see +# cookbook/observability/ for a docker-compose example). +# telemetry_config: +# enabled: false +# debug: false +# service_name: twinkle-server +# otlp_endpoint: http://localhost:4317 + +# Persistence configuration for ServerState (sessions, models, futures, ...). +# Top-level placement makes the launcher propagate this to every Ray worker +# via env vars, so the configured backend is used regardless of which +# deployment initializes the ServerState actor first. +# mode: memory | file | redis +# file_path: required for `file` mode +# redis_url / key_prefix: required for `redis` mode +# persistence_config: +# mode: file +# file_path: /tmp/twinkle_state.json + # Applications: each entry defines a service component deployed on the server applications: diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 1c12415c1..cc7f37780 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -9,11 +9,29 @@ http_options: host: 0.0.0.0 # Listen on all network interfaces port: 8000 # Port number for the server -# Telemetry configuration for observability (OpenTelemetry-based) +# Telemetry configuration for observability (OpenTelemetry-based). +# Disabled by default — opentelemetry-* packages are optional dependencies. +# To enable: +# 1. pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp +# 2. set enabled: true (and debug: true to dump exporters to console for local dev, +# or leave debug: false and point otlp_endpoint at an OTLP collector — see +# cookbook/observability/ for a docker-compose example). telemetry_config: - enabled: true - debug: true + enabled: false + debug: false service_name: twinkle-server + otlp_endpoint: http://localhost:4317 + +# Persistence configuration for ServerState (sessions, models, futures, ...). +# Top-level placement makes the launcher propagate this to every Ray worker +# via env vars, so the configured backend is used regardless of which +# deployment initializes the ServerState actor first. +# mode: memory | file | redis +# file_path: required for `file` mode +# redis_url / key_prefix: required for `redis` mode +persistence_config: + mode: file + file_path: /tmp/twinkle_state.json # Applications: each entry defines a service component deployed on the server applications: @@ -26,9 +44,6 @@ applications: args: server_config: per_token_model_limit: 3 # Maximum number of models (adapters) per token (server-globally enforced) - persistence_config: - mode: file - file_path: /tmp/twinkle_state.json supported_models: - Qwen/Qwen3.5-4B deployments: diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index fb6a2a48f..a8478512f 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -111,8 +111,12 @@ async def lifespan(app: FastAPI): async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) - app.middleware('http')(create_metrics_middleware('Gateway')) + # Registration order matters: FastAPI runs middleware in LIFO order, so the + # last-registered wraps the outermost layer. Tracing first → metrics last + # makes metrics the outermost wrapper and capture the full end-to-end + # latency including tracing overhead and auth. app.middleware('http')(create_tracing_middleware('Gateway')) + app.middleware('http')(create_metrics_middleware('Gateway')) _register_tinker_routes(app, get_self) _register_twinkle_routes(app, get_self) diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index ed4880ffb..2a739b32b 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -87,6 +87,24 @@ def _build_telemetry_env_vars(self) -> dict[str, str]: import os return {k: os.environ[k] for k in self._TELEMETRY_ENV_KEYS if k in os.environ} + def _build_persistence_env_vars(self) -> dict[str, str]: + """Collect persistence env vars from os.environ for propagation to Ray workers. + + These vars are read by ``PersistenceConfig.from_env()`` inside any + worker that calls ``get_server_state()`` without an explicit config, + which makes the chosen backend independent of deployment startup order. + """ + import os + from twinkle.server.state.backend.factory import PERSISTENCE_ENV_KEYS + return {k: os.environ[k] for k in PERSISTENCE_ENV_KEYS if k in os.environ} + + def _build_propagated_env_vars(self) -> dict[str, str]: + """Aggregate all env vars that must reach Ray worker processes.""" + merged: dict[str, str] = {} + merged.update(self._build_telemetry_env_vars()) + merged.update(self._build_persistence_env_vars()) + return merged + def _get_builders(self) -> dict[str, Callable]: """Get the builder functions for all app types.""" if self._builders: @@ -147,11 +165,11 @@ def _init_ray(self) -> None: # Use runtime_env to apply patches in worker processes # This is required because Ray Serve's ProxyActor runs in separate processes runtime_env = get_runtime_env_for_patches() - # Propagate telemetry env vars to all Ray workers spawned in this job - telemetry_env_vars = self._build_telemetry_env_vars() - if telemetry_env_vars: + # Propagate telemetry + persistence env vars to all Ray workers + propagated_env_vars = self._build_propagated_env_vars() + if propagated_env_vars: merged_env_vars = dict(runtime_env.get('env_vars') or {}) - merged_env_vars.update(telemetry_env_vars) + merged_env_vars.update(propagated_env_vars) runtime_env['env_vars'] = merged_env_vars # Connect to existing cluster if available, otherwise start local instance ray.init( @@ -217,15 +235,17 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: if isinstance(deploy_config, dict): deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'} - # Inject telemetry env vars into the deployment's runtime_env so that - # Ray Serve replicas (worker processes) can initialize telemetry. - # User-specified env_vars take precedence to avoid overriding existing config. - telemetry_env_vars = self._build_telemetry_env_vars() - if telemetry_env_vars: + # Inject telemetry + persistence env vars into the deployment's + # runtime_env so that Ray Serve replicas (worker processes) can + # initialize telemetry and resolve the configured persistence backend + # regardless of deployment startup order. + # User-specified env_vars take precedence over our defaults. + propagated_env_vars = self._build_propagated_env_vars() + if propagated_env_vars: ray_actor_options = dict(deploy_options.get('ray_actor_options') or {}) runtime_env = dict(ray_actor_options.get('runtime_env') or {}) env_vars = dict(runtime_env.get('env_vars') or {}) - for k, v in telemetry_env_vars.items(): + for k, v in propagated_env_vars.items(): env_vars.setdefault(k, v) runtime_env['env_vars'] = env_vars ray_actor_options['runtime_env'] = runtime_env @@ -265,6 +285,18 @@ def launch(self) -> None: os.environ['TWINKLE_TELEMETRY_ENDPOINT'] = config.otlp_endpoint os.environ['TWINKLE_TELEMETRY_INTERVAL'] = str(config.export_interval_ms) + # Export top-level persistence_config to env vars so any worker + # (not just Gateway) can build the same backend on first call to + # get_server_state(). + persistence_config_dict = self.config.get('persistence_config') + if persistence_config_dict: + import os + from twinkle.server.state.backend.factory import PersistenceConfig + pconfig = PersistenceConfig(**persistence_config_dict) + for k, v in pconfig.to_env_vars().items(): + os.environ[k] = v + logger.info(f'Persistence backend configured: mode={pconfig.mode}') + self._init_ray() self._start_serve() diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 349538678..59f6f00e2 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -184,8 +184,11 @@ async def lifespan(app: FastAPI): async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) - app.middleware('http')(create_metrics_middleware('Model')) + # Registration order: FastAPI runs middleware LIFO. Tracing first → metrics + # last makes metrics the outermost wrapper, so its latency observation + # covers the full request path including tracing overhead. app.middleware('http')(create_tracing_middleware('Model')) + app.middleware('http')(create_metrics_middleware('Model')) _register_tinker_routes(app, get_self) _register_twinkle_routes(app, get_self) diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 1f1e827bf..20976da84 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -135,8 +135,11 @@ async def lifespan(app: FastAPI): async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) - app.middleware('http')(create_metrics_middleware('Processor')) + # Registration order: FastAPI runs middleware LIFO. Tracing first → metrics + # last makes metrics the outermost wrapper, so its latency observation + # covers the full request path including tracing overhead. app.middleware('http')(create_tracing_middleware('Processor')) + app.middleware('http')(create_metrics_middleware('Processor')) def get_self() -> ProcessorManagement: return serve.get_replica_context().servable_object diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index d8368399b..1262ae158 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -137,8 +137,11 @@ async def lifespan(app: FastAPI): async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) - app.middleware('http')(create_metrics_middleware('Sampler')) + # Registration order: FastAPI runs middleware LIFO. Tracing first → metrics + # last makes metrics the outermost wrapper, so its latency observation + # covers the full request path including tracing overhead. app.middleware('http')(create_tracing_middleware('Sampler')) + app.middleware('http')(create_metrics_middleware('Sampler')) def get_self() -> SamplerManagement: return serve.get_replica_context().servable_object diff --git a/src/twinkle/server/state/backend/factory.py b/src/twinkle/server/state/backend/factory.py index 2e03f72d7..4a37ecedc 100644 --- a/src/twinkle/server/state/backend/factory.py +++ b/src/twinkle/server/state/backend/factory.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +import os from typing import Literal from pydantic import BaseModel @@ -12,6 +13,17 @@ logger = logging.getLogger(__name__) +# Env var keys propagated by the launcher so that any Ray worker can rebuild +# the same PersistenceConfig regardless of which deployment initializes the +# ServerState actor first. +PERSISTENCE_ENV_KEYS: tuple[str, ...] = ( + 'TWINKLE_PERSISTENCE_MODE', + 'TWINKLE_PERSISTENCE_FILE_PATH', + 'TWINKLE_PERSISTENCE_REDIS_URL', + 'TWINKLE_PERSISTENCE_KEY_PREFIX', +) + + class PersistenceConfig(BaseModel): """Configuration for state persistence backend.""" mode: Literal['memory', 'file', 'redis'] = 'memory' @@ -19,6 +31,34 @@ class PersistenceConfig(BaseModel): redis_url: str | None = None # required for redis mode key_prefix: str = '' # optional global key prefix + def to_env_vars(self) -> dict[str, str]: + """Serialize this config to env var key/value pairs for worker propagation.""" + env: dict[str, str] = {'TWINKLE_PERSISTENCE_MODE': self.mode} + if self.file_path: + env['TWINKLE_PERSISTENCE_FILE_PATH'] = self.file_path + if self.redis_url: + env['TWINKLE_PERSISTENCE_REDIS_URL'] = self.redis_url + if self.key_prefix: + env['TWINKLE_PERSISTENCE_KEY_PREFIX'] = self.key_prefix + return env + + @classmethod + def from_env(cls) -> PersistenceConfig | None: + """Reconstruct a PersistenceConfig from env vars set by the launcher. + + Returns ``None`` when ``TWINKLE_PERSISTENCE_MODE`` is unset, so callers + can distinguish "no env-configured persistence" from "memory mode". + """ + mode = os.environ.get('TWINKLE_PERSISTENCE_MODE') + if not mode: + return None + return cls( + mode=mode, + file_path=os.environ.get('TWINKLE_PERSISTENCE_FILE_PATH'), + redis_url=os.environ.get('TWINKLE_PERSISTENCE_REDIS_URL'), + key_prefix=os.environ.get('TWINKLE_PERSISTENCE_KEY_PREFIX', ''), + ) + def create_backend(config: PersistenceConfig | None = None) -> StateBackend: """Create a StateBackend instance based on persistence configuration. diff --git a/src/twinkle/server/state/server_state.py b/src/twinkle/server/state/server_state.py index 1b1525376..eea830dd6 100644 --- a/src/twinkle/server/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -592,6 +592,12 @@ def get_server_state( if isinstance(persistence_config, dict): persistence_config = PersistenceConfig(**persistence_config) + # Fall back to env-var-propagated config so any worker (not just Gateway) + # can create the actor with the right backend regardless of deployment + # startup order. Explicit args still take precedence. + if backend is None and persistence_config is None: + persistence_config = PersistenceConfig.from_env() + try: actor = ray.get_actor(actor_name) except ValueError: diff --git a/src/twinkle/server/telemetry/tracing.py b/src/twinkle/server/telemetry/tracing.py index f10d64730..9a12a5b9d 100644 --- a/src/twinkle/server/telemetry/tracing.py +++ b/src/twinkle/server/telemetry/tracing.py @@ -60,16 +60,15 @@ def start_span(self, name, **kwargs): def create_tracing_middleware(service_component: str): - """Create an HTTP tracing middleware that lazily acquires tracer at request time. + """Create an HTTP tracing middleware compatible with Ray Serve pickling. - This approach is compatible with Ray Serve's pickle serialization, unlike - ``FastAPIInstrumentor.instrument_app`` which attaches unpicklable references - (e.g. ``_thread.lock``) to the FastAPI app and breaks deployment pickling. + Unlike ``FastAPIInstrumentor.instrument_app`` which attaches unpicklable + references (e.g. ``_thread.lock``) to the FastAPI app and breaks Ray Serve + deployment pickling, the returned middleware is a plain async function + closing only over the ``service_component`` string. - The returned middleware is a plain async function with no captured state - other than the ``service_component`` string, so it can be safely pickled - along with the app when registered via ``app.middleware('http')`` inside a - Ray Serve build function. + When OpenTelemetry is not installed, returns a passthrough middleware so + the server still works without the optional dependency. Args: service_component: Logical service name used as the tracer name suffix @@ -79,12 +78,12 @@ def create_tracing_middleware(service_component: str): Returns: An async FastAPI HTTP middleware function. """ + if not _OTEL_AVAILABLE: + async def passthrough_middleware(request: Request, call_next): + return await call_next(request) + return passthrough_middleware async def tracing_middleware(request: Request, call_next): - # Lazy import to avoid holding any unpicklable module-level references - # at the time Ray Serve serializes the build function / app. - from opentelemetry import trace - tracer = trace.get_tracer(f'twinkle.server.{service_component}') method = request.method diff --git a/src/twinkle/server/utils/metrics.py b/src/twinkle/server/utils/metrics.py index c02da78e5..aaf1076ed 100644 --- a/src/twinkle/server/utils/metrics.py +++ b/src/twinkle/server/utils/metrics.py @@ -142,12 +142,16 @@ def create_metrics_middleware(deployment: str) -> Callable: Usage inside a ``build_*_app()`` function:: from twinkle.server.utils.metrics import create_metrics_middleware - metrics_mw = create_metrics_middleware("Model") - app.middleware('http')(metrics_mw) + from twinkle.server.telemetry.tracing import create_tracing_middleware - Because FastAPI executes middleware in LIFO order, registering this - **after** ``verify_token`` means it wraps the outermost layer and - captures full end-to-end latency including auth. + app.middleware('http')(verify_token) + app.middleware('http')(create_tracing_middleware("Model")) + app.middleware('http')(create_metrics_middleware("Model")) # outermost + + FastAPI executes middleware in LIFO order, so the **last** middleware + registered is the outermost wrapper. Register metrics last so its + latency observation covers the full request path including tracing + overhead and authentication. """ async def metrics_middleware(request: Any, call_next: Callable) -> Any: From 0819affb4aa8c9765dd7211381dd212ef1a200dc Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Mon, 1 Jun 2026 14:27:08 +0800 Subject: [PATCH 09/34] docs(observability): add LGTM-based docker-compose stack for telemetry cookbook/observability/ provides a one-container OTLP receiver + dashboard for Twinkle, built on grafana/otel-lgtm (bundled OTel Collector + Mimir + Tempo + Loki + Grafana). Users docker compose up, point telemetry_config at localhost:4317, and get a pre-provisioned overview dashboard with HTTP rate / latency, queue depth, task latencies, rate-limit rejections, and active-resource gauges. --- cookbook/observability/README.md | 102 +++++++++++++ cookbook/observability/docker-compose.yaml | 35 +++++ .../grafana/dashboards/twinkle-overview.json | 137 ++++++++++++++++++ 3 files changed, 274 insertions(+) create mode 100644 cookbook/observability/README.md create mode 100644 cookbook/observability/docker-compose.yaml create mode 100644 cookbook/observability/grafana/dashboards/twinkle-overview.json diff --git a/cookbook/observability/README.md b/cookbook/observability/README.md new file mode 100644 index 000000000..a17fd009c --- /dev/null +++ b/cookbook/observability/README.md @@ -0,0 +1,102 @@ +# Twinkle Observability Stack + +A one-container OTLP receiver + dashboard for Twinkle, built on the +[`grafana/otel-lgtm`](https://github.com/grafana/docker-otel-lgtm) image. +That image bundles OTel Collector, Mimir (Prometheus-compatible), Tempo, +Loki, and Grafana with everything pre-wired — no extra config files needed. + +## What you get + +| Surface | URL | Purpose | +|---|---|---| +| Grafana | `http://localhost:3000` | Dashboards + Explore (metrics / traces / logs) | +| OTLP gRPC | `localhost:4317` | Point Twinkle's `otlp_endpoint` here | +| OTLP HTTP | `localhost:4318` | Same, HTTP alternative | + +## Quick start + +```bash +# 1. Start the stack +cd cookbook/observability +docker compose up -d + +# 2. Make sure Twinkle has the OTLP exporter +pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp + +# 3. In your server_config.yaml: +# +# telemetry_config: +# enabled: true +# debug: false # debug=true dumps to console instead of OTLP +# service_name: twinkle-server +# otlp_endpoint: http://localhost:4317 + +# 4. Launch Twinkle as usual +python -m twinkle.server --config server_config.yaml + +# 5. Open Grafana +open http://localhost:3000 +``` + +Anonymous viewer access is on by default; full access is `admin` / `admin`. + +The provisioned **Twinkle / Twinkle Server Overview** dashboard shows: + +- HTTP request rate and P95 latency per deployment (Gateway / Model / Sampler / Processor) +- Active resources (sessions, models, sampling sessions, futures) +- Task queue depth, execution P95, wait-time P95 +- Rate-limit rejections and task completions by status + +For traces, switch the datasource picker in **Explore** to Tempo and search by +service or span name. Twinkle spans are namespaced under +`twinkle.server.` (Gateway / Model / Sampler / Processor). + +## Metric naming reference + +Twinkle emits OpenTelemetry metric names with dot notation. Prometheus's OTLP +ingestion converts dots to underscores and appends `_total` to monotonic +counters where missing: + +| OpenTelemetry name | Prometheus name | +|---|---| +| `twinkle.http.requests.total` | `twinkle_http_requests_total` | +| `twinkle.http.request.duration_seconds` | `twinkle_http_request_duration_seconds_bucket` (and `_sum`, `_count`) | +| `twinkle.queue.depth` | `twinkle_queue_depth` | +| `twinkle.task.execution_seconds` | `twinkle_task_execution_seconds_bucket` | +| `twinkle.task.wait_seconds` | `twinkle_task_wait_seconds_bucket` | +| `twinkle.rate_limit.rejections.total` | `twinkle_rate_limit_rejections_total` | +| `twinkle.tasks.total` | `twinkle_tasks_total` | +| `twinkle.rate_limiter.active_tokens` | `twinkle_rate_limiter_active_tokens` | +| `twinkle.sessions.active` | `twinkle_sessions_active` | +| `twinkle.models.active` | `twinkle_models_active` | +| `twinkle.sampling_sessions.active` | `twinkle_sampling_sessions_active` | +| `twinkle.futures.active` | `twinkle_futures_active` | + +## Tear down + +```bash +docker compose down -v # -v also removes the named volume +``` + +## Production note + +The LGTM all-in-one image is **for local development and demos**. Each backend +runs single-instance and shares one volume. For production, deploy each +component (Mimir / Tempo / Loki / Grafana) separately with proper persistent +storage, replicas, and an OTel Collector tier in front. The OTLP endpoint and +metric names stay the same, so your `server_config.yaml` and dashboards +transfer without changes. + +## Troubleshooting + +- **Grafana shows "No data"** — confirm `telemetry_config.enabled: true` in + your server config and that Twinkle's worker logs show + `Worker telemetry initialized`. With `debug: true` Twinkle dumps spans / + metrics to logs instead of OTLP, so set `debug: false` once verified. +- **Twinkle can't reach the collector** — `otlp_endpoint` must be reachable + from the Twinkle process. If Twinkle runs in another container on the same + Docker network, use `http://twinkle-lgtm:4317` instead of `localhost`. +- **Dashboard panel shows "Datasource not found"** — open the panel, switch + the datasource dropdown to the LGTM-provisioned Prometheus / Tempo and save. + This happens when LGTM versions change the default datasource UID; the + dashboard JSON pins `uid: prometheus`. diff --git a/cookbook/observability/docker-compose.yaml b/cookbook/observability/docker-compose.yaml new file mode 100644 index 000000000..de775c6e0 --- /dev/null +++ b/cookbook/observability/docker-compose.yaml @@ -0,0 +1,35 @@ +# Twinkle observability stack — local-dev edition. +# +# Single container: grafana/otel-lgtm bundles OTel Collector + Mimir +# (Prometheus-compatible) + Tempo + Loki + Grafana with all datasources +# pre-wired. Twinkle pushes OTLP to :4317; you read it back at :3000. +# +# Quick start: +# docker compose up -d +# open http://localhost:3000 # admin / admin (anonymous viewer also enabled) +# +# In your server_config.yaml: +# telemetry_config: +# enabled: true +# debug: false +# service_name: twinkle-server +# otlp_endpoint: http://localhost:4317 + +services: + lgtm: + image: grafana/otel-lgtm:latest + container_name: twinkle-lgtm + ports: + - "3000:3000" # Grafana UI + - "4317:4317" # OTLP gRPC — point telemetry_config.otlp_endpoint here + - "4318:4318" # OTLP HTTP (alternative) + volumes: + # Drop our pre-built Twinkle overview dashboard into the image's + # existing dashboard provisioning folder. Grafana inside the container + # auto-scans this directory on startup. + - ./grafana/dashboards/twinkle-overview.json:/otel-lgtm/grafana/conf/provisioning/dashboards/twinkle-overview.json:ro + # Persist dashboards/data across container restarts (optional) + - lgtm-data:/data + +volumes: + lgtm-data: diff --git a/cookbook/observability/grafana/dashboards/twinkle-overview.json b/cookbook/observability/grafana/dashboards/twinkle-overview.json new file mode 100644 index 000000000..ab609e813 --- /dev/null +++ b/cookbook/observability/grafana/dashboards/twinkle-overview.json @@ -0,0 +1,137 @@ +{ + "annotations": {"list": []}, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 1, + "id": null, + "links": [], + "liveNow": false, + "panels": [ + { + "datasource": {"type": "prometheus", "uid": "prometheus"}, + "fieldConfig": {"defaults": {"color": {"mode": "palette-classic"}, "unit": "reqps"}}, + "gridPos": {"h": 8, "w": 12, "x": 0, "y": 0}, + "id": 1, + "title": "HTTP request rate (per deployment)", + "type": "timeseries", + "targets": [ + { + "expr": "sum by (deployment, status) (rate(twinkle_http_requests_total[1m]))", + "legendFormat": "{{deployment}} {{status}}", + "refId": "A" + } + ] + }, + { + "datasource": {"type": "prometheus", "uid": "prometheus"}, + "fieldConfig": {"defaults": {"unit": "s"}}, + "gridPos": {"h": 8, "w": 12, "x": 12, "y": 0}, + "id": 2, + "title": "HTTP latency P95 (per deployment)", + "type": "timeseries", + "targets": [ + { + "expr": "histogram_quantile(0.95, sum by (le, deployment) (rate(twinkle_http_request_duration_seconds_bucket[5m])))", + "legendFormat": "{{deployment}}", + "refId": "A" + } + ] + }, + { + "datasource": {"type": "prometheus", "uid": "prometheus"}, + "gridPos": {"h": 8, "w": 12, "x": 0, "y": 8}, + "id": 3, + "title": "Active resources", + "type": "timeseries", + "targets": [ + {"expr": "twinkle_sessions_active", "legendFormat": "sessions", "refId": "A"}, + {"expr": "twinkle_models_active", "legendFormat": "models", "refId": "B"}, + {"expr": "twinkle_sampling_sessions_active", "legendFormat": "sampling sessions", "refId": "C"}, + {"expr": "twinkle_futures_active", "legendFormat": "futures", "refId": "D"} + ] + }, + { + "datasource": {"type": "prometheus", "uid": "prometheus"}, + "gridPos": {"h": 8, "w": 12, "x": 12, "y": 8}, + "id": 4, + "title": "Task queue depth (per deployment)", + "type": "timeseries", + "targets": [ + { + "expr": "sum by (deployment) (twinkle_queue_depth)", + "legendFormat": "{{deployment}}", + "refId": "A" + } + ] + }, + { + "datasource": {"type": "prometheus", "uid": "prometheus"}, + "fieldConfig": {"defaults": {"unit": "s"}}, + "gridPos": {"h": 8, "w": 12, "x": 0, "y": 16}, + "id": 5, + "title": "Task execution P95", + "type": "timeseries", + "targets": [ + { + "expr": "histogram_quantile(0.95, sum by (le, deployment) (rate(twinkle_task_execution_seconds_bucket[5m])))", + "legendFormat": "{{deployment}}", + "refId": "A" + } + ] + }, + { + "datasource": {"type": "prometheus", "uid": "prometheus"}, + "fieldConfig": {"defaults": {"unit": "s"}}, + "gridPos": {"h": 8, "w": 12, "x": 12, "y": 16}, + "id": 6, + "title": "Task wait time P95", + "type": "timeseries", + "targets": [ + { + "expr": "histogram_quantile(0.95, sum by (le, deployment) (rate(twinkle_task_wait_seconds_bucket[5m])))", + "legendFormat": "{{deployment}}", + "refId": "A" + } + ] + }, + { + "datasource": {"type": "prometheus", "uid": "prometheus"}, + "gridPos": {"h": 8, "w": 12, "x": 0, "y": 24}, + "id": 7, + "title": "Rate-limit rejections", + "type": "timeseries", + "targets": [ + { + "expr": "sum by (deployment) (rate(twinkle_rate_limit_rejections_total[1m]))", + "legendFormat": "{{deployment}}", + "refId": "A" + } + ] + }, + { + "datasource": {"type": "prometheus", "uid": "prometheus"}, + "gridPos": {"h": 8, "w": 12, "x": 12, "y": 24}, + "id": 8, + "title": "Task completions by status", + "type": "timeseries", + "targets": [ + { + "expr": "sum by (deployment, status) (rate(twinkle_tasks_total[1m]))", + "legendFormat": "{{deployment}} {{status}}", + "refId": "A" + } + ] + } + ], + "refresh": "10s", + "schemaVersion": 39, + "tags": ["twinkle"], + "templating": {"list": []}, + "time": {"from": "now-1h", "to": "now"}, + "timepicker": {}, + "timezone": "", + "title": "Twinkle Server Overview", + "uid": "twinkle-overview", + "version": 1, + "weekStart": "" +} From c8439b40f7d6a4178e8f45ffe3426368bcebf1a1 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Mon, 1 Jun 2026 23:51:54 +0800 Subject: [PATCH 10/34] test(server): add client-API contract harness and baseline snapshot (Phase 0a) Establishes the cross-cutting freeze guard for the Tinker/Twinkle HTTP contract (R20, R18.1) ahead of the server-config + observability refactor. The harness builds each FastAPI app (gateway, model, sampler, processor) by registering its route helpers against a fresh app, then extracts the OpenAPI paths and component schemas. The committed baseline at tests/contract/client_api_baseline.json is what every later phase asserts equality against to catch any drift in route paths, HTTP methods, or request/response schemas. Adds hypothesis to the test extras for the property-based tests later phases will need. --- pyproject.toml | 1 + tests/contract/__init__.py | 0 tests/contract/client_api_baseline.json | 7340 ++++++++++++++++++++ tests/contract/client_api_harness.py | 152 + tests/contract/test_client_api_contract.py | 97 + tests/contract/update_baseline.py | 22 + 6 files changed, 7612 insertions(+) create mode 100644 tests/contract/__init__.py create mode 100644 tests/contract/client_api_baseline.json create mode 100644 tests/contract/client_api_harness.py create mode 100644 tests/contract/test_client_api_contract.py create mode 100644 tests/contract/update_baseline.py diff --git a/pyproject.toml b/pyproject.toml index 964a7548c..b54feaceb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]", "mcore_bridg vllm = ["vllm>=0.11"] ray = ["ray[serve]"] tinker = ["tinker==0.14.0"] +test = ["hypothesis>=6.0"] docs = [ "sphinx>=5.3.0,<6.0.0", "docutils>=0.16.0,<0.17.0", diff --git a/tests/contract/__init__.py b/tests/contract/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/contract/client_api_baseline.json b/tests/contract/client_api_baseline.json new file mode 100644 index 000000000..54f0ec1bf --- /dev/null +++ b/tests/contract/client_api_baseline.json @@ -0,0 +1,7340 @@ +{ + "gateway": { + "paths": { + "/asample": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SampleRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Asample Asample Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/create_model": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateModelRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Create Model Create Model Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/create_sampling_session": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateSamplingSessionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateSamplingSessionResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/create_session": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__create_session_request__CreateSessionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__create_session_response__CreateSessionResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/forward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Forward Forward Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/forward_backward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardBackwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Forward Backward Forward Backward Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/get_info": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetInfoRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Get Info Get Info Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/get_server_capabilities": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__get_server_capabilities_response__GetServerCapabilitiesResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/healthz": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__health_response__HealthResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/load_weights": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/LoadWeightsRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Load Weights Load Weights Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/optim_step": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OptimStepRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Optim Step Optim Step Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/retrieve_future": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/FutureRetrieveRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Retrieve Future Retrieve Future Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/save_weights": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveWeightsRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Save Weights Save Weights Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/save_weights_for_sampler": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveWeightsForSamplerRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Save Weights For Sampler Save Weights For Sampler Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/session_heartbeat": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__session_heartbeat_request__SessionHeartbeatRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__session_heartbeat_response__SessionHeartbeatResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/telemetry": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TelemetrySendRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TelemetryResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/training_runs": { + "GET": { + "parameters": [ + { + "in": "query", + "name": "limit", + "required": false, + "schema": { + "default": 20, + "title": "Limit", + "type": "integer" + } + }, + { + "in": "query", + "name": "offset", + "required": false, + "schema": { + "default": 0, + "title": "Offset", + "type": "integer" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__training_runs_response__TrainingRunsResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/training_runs/{run_id}": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__training_run__TrainingRun" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/training_runs/{run_id}/checkpoints": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__checkpoints_list_response__CheckpointsListResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/training_runs/{run_id}/checkpoints/{checkpoint_id}": { + "DELETE": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + }, + { + "in": "path", + "name": "checkpoint_id", + "required": true, + "schema": { + "title": "Checkpoint Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Delete Run Checkpoint Training Runs Run Id Checkpoints Checkpoint Id Delete" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/training_runs/{run_id}/checkpoints/{checkpoint_id}/publish": { + "POST": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + }, + { + "in": "path", + "name": "checkpoint_id", + "required": true, + "schema": { + "title": "Checkpoint Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/capacity_info": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CapacityInfoResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/twinkle/checkpoint_path/{run_id}/{checkpoint_id}": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + }, + { + "in": "path", + "name": "checkpoint_id", + "required": true, + "schema": { + "title": "Checkpoint Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CheckpointPathResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/create_session": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__session__CreateSessionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__session__CreateSessionResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/get_server_capabilities": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__server__GetServerCapabilitiesResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/twinkle/healthz": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__server__HealthResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/twinkle/session_heartbeat": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__session__SessionHeartbeatRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__session__SessionHeartbeatResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/status": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "additionalProperties": true, + "title": "Response Status Twinkle Status Get", + "type": "object" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/twinkle/training_runs": { + "GET": { + "parameters": [ + { + "in": "query", + "name": "limit", + "required": false, + "schema": { + "default": 20, + "title": "Limit", + "type": "integer" + } + }, + { + "in": "query", + "name": "offset", + "required": false, + "schema": { + "default": 0, + "title": "Offset", + "type": "integer" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__training__TrainingRunsResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/training_runs/{run_id}": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__training__TrainingRun" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/training_runs/{run_id}/checkpoints": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__training__CheckpointsListResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/training_runs/{run_id}/checkpoints/{checkpoint_id}": { + "DELETE": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + }, + { + "in": "path", + "name": "checkpoint_id", + "required": true, + "schema": { + "title": "Checkpoint Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DeleteCheckpointResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/weights_info": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WeightsInfoRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__training__WeightsInfoResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/unload_model": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnloadModelRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Unload Model Unload Model Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/weights_info": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "additionalProperties": true, + "title": "Body", + "type": "object" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__weights_info_response__WeightsInfoResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + } + }, + "schemas": { + "AdamParams": { + "additionalProperties": false, + "properties": { + "beta1": { + "default": 0.9, + "title": "Beta1", + "type": "number" + }, + "beta2": { + "default": 0.95, + "title": "Beta2", + "type": "number" + }, + "eps": { + "default": 1e-12, + "title": "Eps", + "type": "number" + }, + "grad_clip_norm": { + "default": 0.0, + "title": "Grad Clip Norm", + "type": "number" + }, + "learning_rate": { + "default": 0.0001, + "title": "Learning Rate", + "type": "number" + }, + "weight_decay": { + "default": 0.0, + "title": "Weight Decay", + "type": "number" + } + }, + "title": "AdamParams", + "type": "object" + }, + "CapacityInfoResponse": { + "description": "Response body for the /capacity_info endpoint.", + "properties": { + "free_loras": { + "title": "Free Loras", + "type": "integer" + }, + "max_loras": { + "title": "Max Loras", + "type": "integer" + }, + "used_loras": { + "title": "Used Loras", + "type": "integer" + } + }, + "required": [ + "max_loras", + "used_loras", + "free_loras" + ], + "title": "CapacityInfoResponse", + "type": "object" + }, + "CheckpointPathResponse": { + "description": "Response body for the /checkpoint_path endpoint.", + "properties": { + "path": { + "title": "Path", + "type": "string" + }, + "twinkle_path": { + "title": "Twinkle Path", + "type": "string" + } + }, + "required": [ + "path", + "twinkle_path" + ], + "title": "CheckpointPathResponse", + "type": "object" + }, + "CreateModelRequest": { + "additionalProperties": false, + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "lora_config": { + "anyOf": [ + { + "$ref": "#/components/schemas/LoraConfig" + }, + { + "type": "null" + } + ] + }, + "model_seq_id": { + "title": "Model Seq Id", + "type": "integer" + }, + "session_id": { + "title": "Session Id", + "type": "string" + }, + "type": { + "const": "create_model", + "default": "create_model", + "title": "Type", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "session_id", + "model_seq_id", + "base_model" + ], + "title": "CreateModelRequest", + "type": "object" + }, + "CreateSamplingSessionRequest": { + "additionalProperties": false, + "properties": { + "base_model": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Base Model" + }, + "model_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Path" + }, + "sampling_session_seq_id": { + "title": "Sampling Session Seq Id", + "type": "integer" + }, + "session_id": { + "title": "Session Id", + "type": "string" + }, + "type": { + "const": "create_sampling_session", + "default": "create_sampling_session", + "title": "Type", + "type": "string" + } + }, + "required": [ + "session_id", + "sampling_session_seq_id" + ], + "title": "CreateSamplingSessionRequest", + "type": "object" + }, + "CreateSamplingSessionResponse": { + "properties": { + "sampling_session_id": { + "title": "Sampling Session Id", + "type": "string" + }, + "type": { + "const": "create_sampling_session", + "default": "create_sampling_session", + "title": "Type", + "type": "string" + } + }, + "required": [ + "sampling_session_id" + ], + "title": "CreateSamplingSessionResponse", + "type": "object" + }, + "Datum": { + "additionalProperties": false, + "properties": { + "loss_fn_inputs": { + "additionalProperties": { + "$ref": "#/components/schemas/TensorData" + }, + "title": "Loss Fn Inputs", + "type": "object" + }, + "model_input": { + "$ref": "#/components/schemas/ModelInput" + } + }, + "required": [ + "loss_fn_inputs", + "model_input" + ], + "title": "Datum", + "type": "object" + }, + "DeleteCheckpointResponse": { + "properties": { + "message": { + "title": "Message", + "type": "string" + }, + "success": { + "title": "Success", + "type": "boolean" + } + }, + "required": [ + "success", + "message" + ], + "title": "DeleteCheckpointResponse", + "type": "object" + }, + "EncodedTextChunk": { + "additionalProperties": false, + "properties": { + "tokens": { + "items": { + "type": "integer" + }, + "title": "Tokens", + "type": "array" + }, + "type": { + "const": "encoded_text", + "default": "encoded_text", + "title": "Type", + "type": "string" + } + }, + "required": [ + "tokens" + ], + "title": "EncodedTextChunk", + "type": "object" + }, + "ForwardBackwardInput": { + "additionalProperties": false, + "properties": { + "data": { + "items": { + "$ref": "#/components/schemas/Datum" + }, + "title": "Data", + "type": "array" + }, + "loss_fn": { + "enum": [ + "cross_entropy", + "importance_sampling", + "ppo", + "cispo", + "dro" + ], + "title": "Loss Fn", + "type": "string" + }, + "loss_fn_config": { + "anyOf": [ + { + "additionalProperties": { + "type": "number" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Loss Fn Config" + } + }, + "required": [ + "data", + "loss_fn" + ], + "title": "ForwardBackwardInput", + "type": "object" + }, + "ForwardBackwardRequest": { + "additionalProperties": false, + "properties": { + "forward_backward_input": { + "$ref": "#/components/schemas/ForwardBackwardInput" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + } + }, + "required": [ + "forward_backward_input", + "model_id" + ], + "title": "ForwardBackwardRequest", + "type": "object" + }, + "ForwardRequest": { + "additionalProperties": false, + "properties": { + "forward_input": { + "$ref": "#/components/schemas/ForwardBackwardInput" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + } + }, + "required": [ + "forward_input", + "model_id" + ], + "title": "ForwardRequest", + "type": "object" + }, + "FutureRetrieveRequest": { + "additionalProperties": false, + "properties": { + "allow_metadata_only": { + "default": false, + "title": "Allow Metadata Only", + "type": "boolean" + }, + "request_id": { + "title": "Request Id", + "type": "string" + } + }, + "required": [ + "request_id" + ], + "title": "FutureRetrieveRequest", + "type": "object" + }, + "GenericEvent": { + "properties": { + "event": { + "enum": [ + "SESSION_START", + "SESSION_END", + "UNHANDLED_EXCEPTION", + "GENERIC_EVENT" + ], + "title": "Event", + "type": "string" + }, + "event_data": { + "additionalProperties": true, + "default": {}, + "title": "Event Data", + "type": "object" + }, + "event_id": { + "title": "Event Id", + "type": "string" + }, + "event_name": { + "title": "Event Name", + "type": "string" + }, + "event_session_index": { + "title": "Event Session Index", + "type": "integer" + }, + "severity": { + "enum": [ + "DEBUG", + "INFO", + "WARNING", + "ERROR", + "CRITICAL" + ], + "title": "Severity", + "type": "string" + }, + "timestamp": { + "format": "date-time", + "title": "Timestamp", + "type": "string" + } + }, + "required": [ + "event", + "event_id", + "event_name", + "event_session_index", + "severity", + "timestamp" + ], + "title": "GenericEvent", + "type": "object" + }, + "GetInfoRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "type": { + "const": "get_info", + "default": "get_info", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "GetInfoRequest", + "type": "object" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "title": "Detail", + "type": "array" + } + }, + "title": "HTTPValidationError", + "type": "object" + }, + "ImageAssetPointerChunk": { + "additionalProperties": false, + "properties": { + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "location": { + "title": "Location", + "type": "string" + }, + "type": { + "const": "image_asset_pointer", + "default": "image_asset_pointer", + "title": "Type", + "type": "string" + } + }, + "required": [ + "format", + "location" + ], + "title": "ImageAssetPointerChunk", + "type": "object" + }, + "ImageChunk": { + "additionalProperties": false, + "properties": { + "data": { + "contentMediaType": "application/octet-stream", + "title": "Data", + "type": "string" + }, + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "type": { + "const": "image", + "default": "image", + "title": "Type", + "type": "string" + } + }, + "required": [ + "data", + "format" + ], + "title": "ImageChunk", + "type": "object" + }, + "LoadWeightsRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "optimizer": { + "title": "Optimizer", + "type": "boolean" + }, + "path": { + "title": "Path", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "type": { + "const": "load_weights", + "default": "load_weights", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id", + "path", + "optimizer" + ], + "title": "LoadWeightsRequest", + "type": "object" + }, + "LoraConfig": { + "additionalProperties": false, + "properties": { + "rank": { + "title": "Rank", + "type": "integer" + }, + "seed": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seed" + }, + "train_attn": { + "default": true, + "title": "Train Attn", + "type": "boolean" + }, + "train_mlp": { + "default": true, + "title": "Train Mlp", + "type": "boolean" + }, + "train_unembed": { + "default": true, + "title": "Train Unembed", + "type": "boolean" + } + }, + "required": [ + "rank" + ], + "title": "LoraConfig", + "type": "object" + }, + "ModelInput": { + "additionalProperties": false, + "properties": { + "chunks": { + "items": { + "anyOf": [ + { + "$ref": "#/components/schemas/EncodedTextChunk" + }, + { + "$ref": "#/components/schemas/ImageAssetPointerChunk" + }, + { + "$ref": "#/components/schemas/ImageChunk" + } + ] + }, + "title": "Chunks", + "type": "array" + } + }, + "required": [ + "chunks" + ], + "title": "ModelInput", + "type": "object" + }, + "OptimStepRequest": { + "additionalProperties": false, + "properties": { + "adam_params": { + "$ref": "#/components/schemas/AdamParams" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "type": { + "const": "optim_step", + "default": "optim_step", + "title": "Type", + "type": "string" + } + }, + "required": [ + "adam_params", + "model_id" + ], + "title": "OptimStepRequest", + "type": "object" + }, + "SampleRequest": { + "additionalProperties": false, + "properties": { + "base_model": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Base Model" + }, + "model_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Path" + }, + "num_samples": { + "default": 1, + "title": "Num Samples", + "type": "integer" + }, + "prompt": { + "$ref": "#/components/schemas/ModelInput" + }, + "prompt_logprobs": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Prompt Logprobs" + }, + "sampling_params": { + "$ref": "#/components/schemas/SamplingParams" + }, + "sampling_session_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Sampling Session Id" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "topk_prompt_logprobs": { + "default": 0, + "title": "Topk Prompt Logprobs", + "type": "integer" + }, + "type": { + "const": "sample", + "default": "sample", + "title": "Type", + "type": "string" + } + }, + "required": [ + "prompt", + "sampling_params" + ], + "title": "SampleRequest", + "type": "object" + }, + "SamplingParams": { + "properties": { + "max_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Max Tokens" + }, + "seed": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seed" + }, + "stop": { + "anyOf": [ + { + "type": "string" + }, + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Stop" + }, + "temperature": { + "default": 1, + "title": "Temperature", + "type": "number" + }, + "top_k": { + "default": -1, + "title": "Top K", + "type": "integer" + }, + "top_p": { + "default": 1, + "title": "Top P", + "type": "number" + } + }, + "title": "SamplingParams", + "type": "object" + }, + "SaveWeightsForSamplerRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Path" + }, + "sampling_session_seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Sampling Session Seq Id" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "ttl_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Ttl Seconds" + }, + "type": { + "const": "save_weights_for_sampler", + "default": "save_weights_for_sampler", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "SaveWeightsForSamplerRequest", + "type": "object" + }, + "SaveWeightsRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Path" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "ttl_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Ttl Seconds" + }, + "type": { + "const": "save_weights", + "default": "save_weights", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "SaveWeightsRequest", + "type": "object" + }, + "SessionEndEvent": { + "properties": { + "duration": { + "title": "Duration", + "type": "string" + }, + "event": { + "enum": [ + "SESSION_START", + "SESSION_END", + "UNHANDLED_EXCEPTION", + "GENERIC_EVENT" + ], + "title": "Event", + "type": "string" + }, + "event_id": { + "title": "Event Id", + "type": "string" + }, + "event_session_index": { + "title": "Event Session Index", + "type": "integer" + }, + "severity": { + "enum": [ + "DEBUG", + "INFO", + "WARNING", + "ERROR", + "CRITICAL" + ], + "title": "Severity", + "type": "string" + }, + "timestamp": { + "format": "date-time", + "title": "Timestamp", + "type": "string" + } + }, + "required": [ + "duration", + "event", + "event_id", + "event_session_index", + "severity", + "timestamp" + ], + "title": "SessionEndEvent", + "type": "object" + }, + "SessionStartEvent": { + "properties": { + "event": { + "enum": [ + "SESSION_START", + "SESSION_END", + "UNHANDLED_EXCEPTION", + "GENERIC_EVENT" + ], + "title": "Event", + "type": "string" + }, + "event_id": { + "title": "Event Id", + "type": "string" + }, + "event_session_index": { + "title": "Event Session Index", + "type": "integer" + }, + "severity": { + "enum": [ + "DEBUG", + "INFO", + "WARNING", + "ERROR", + "CRITICAL" + ], + "title": "Severity", + "type": "string" + }, + "timestamp": { + "format": "date-time", + "title": "Timestamp", + "type": "string" + } + }, + "required": [ + "event", + "event_id", + "event_session_index", + "severity", + "timestamp" + ], + "title": "SessionStartEvent", + "type": "object" + }, + "TelemetryResponse": { + "properties": { + "status": { + "const": "accepted", + "title": "Status", + "type": "string" + } + }, + "required": [ + "status" + ], + "title": "TelemetryResponse", + "type": "object" + }, + "TelemetrySendRequest": { + "additionalProperties": false, + "properties": { + "events": { + "items": { + "anyOf": [ + { + "$ref": "#/components/schemas/SessionStartEvent" + }, + { + "$ref": "#/components/schemas/SessionEndEvent" + }, + { + "$ref": "#/components/schemas/UnhandledExceptionEvent" + }, + { + "$ref": "#/components/schemas/GenericEvent" + } + ] + }, + "title": "Events", + "type": "array" + }, + "platform": { + "title": "Platform", + "type": "string" + }, + "sdk_version": { + "title": "Sdk Version", + "type": "string" + }, + "session_id": { + "title": "Session Id", + "type": "string" + } + }, + "required": [ + "events", + "platform", + "sdk_version", + "session_id" + ], + "title": "TelemetrySendRequest", + "type": "object" + }, + "TensorData": { + "additionalProperties": false, + "properties": { + "data": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "items": { + "type": "number" + }, + "type": "array" + } + ], + "title": "Data" + }, + "dtype": { + "enum": [ + "int64", + "float32" + ], + "title": "Dtype", + "type": "string" + }, + "shape": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Shape" + } + }, + "required": [ + "data", + "dtype" + ], + "title": "TensorData", + "type": "object" + }, + "UnhandledExceptionEvent": { + "properties": { + "error_message": { + "title": "Error Message", + "type": "string" + }, + "error_type": { + "title": "Error Type", + "type": "string" + }, + "event": { + "enum": [ + "SESSION_START", + "SESSION_END", + "UNHANDLED_EXCEPTION", + "GENERIC_EVENT" + ], + "title": "Event", + "type": "string" + }, + "event_id": { + "title": "Event Id", + "type": "string" + }, + "event_session_index": { + "title": "Event Session Index", + "type": "integer" + }, + "severity": { + "enum": [ + "DEBUG", + "INFO", + "WARNING", + "ERROR", + "CRITICAL" + ], + "title": "Severity", + "type": "string" + }, + "timestamp": { + "format": "date-time", + "title": "Timestamp", + "type": "string" + }, + "traceback": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Traceback" + } + }, + "required": [ + "error_message", + "error_type", + "event", + "event_id", + "event_session_index", + "severity", + "timestamp" + ], + "title": "UnhandledExceptionEvent", + "type": "object" + }, + "UnloadModelRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "type": { + "const": "unload_model", + "default": "unload_model", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "UnloadModelRequest", + "type": "object" + }, + "ValidationError": { + "properties": { + "ctx": { + "title": "Context", + "type": "object" + }, + "input": { + "title": "Input" + }, + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "title": "Location", + "type": "array" + }, + "msg": { + "title": "Message", + "type": "string" + }, + "type": { + "title": "Error Type", + "type": "string" + } + }, + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError", + "type": "object" + }, + "WeightsInfoRequest": { + "properties": { + "twinkle_path": { + "title": "Twinkle Path", + "type": "string" + } + }, + "required": [ + "twinkle_path" + ], + "title": "WeightsInfoRequest", + "type": "object" + }, + "tinker__types__checkpoint__Checkpoint": { + "properties": { + "checkpoint_id": { + "title": "Checkpoint Id", + "type": "string" + }, + "checkpoint_type": { + "enum": [ + "training", + "sampler" + ], + "title": "Checkpoint Type", + "type": "string" + }, + "expires_at": { + "anyOf": [ + { + "format": "date-time", + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Expires At" + }, + "public": { + "default": false, + "title": "Public", + "type": "boolean" + }, + "size_bytes": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Size Bytes" + }, + "time": { + "format": "date-time", + "title": "Time", + "type": "string" + }, + "tinker_path": { + "title": "Tinker Path", + "type": "string" + } + }, + "required": [ + "checkpoint_id", + "checkpoint_type", + "time", + "tinker_path" + ], + "title": "Checkpoint", + "type": "object" + }, + "tinker__types__checkpoints_list_response__CheckpointsListResponse": { + "properties": { + "checkpoints": { + "items": { + "$ref": "#/components/schemas/tinker__types__checkpoint__Checkpoint" + }, + "title": "Checkpoints", + "type": "array" + }, + "cursor": { + "anyOf": [ + { + "$ref": "#/components/schemas/tinker__types__cursor__Cursor" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "checkpoints" + ], + "title": "CheckpointsListResponse", + "type": "object" + }, + "tinker__types__create_session_request__CreateSessionRequest": { + "additionalProperties": false, + "properties": { + "sdk_version": { + "title": "Sdk Version", + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "title": "Tags", + "type": "array" + }, + "type": { + "const": "create_session", + "default": "create_session", + "title": "Type", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "tags", + "user_metadata", + "sdk_version" + ], + "title": "CreateSessionRequest", + "type": "object" + }, + "tinker__types__create_session_response__CreateSessionResponse": { + "properties": { + "error_message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error Message" + }, + "info_message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Info Message" + }, + "session_id": { + "title": "Session Id", + "type": "string" + }, + "type": { + "const": "create_session", + "default": "create_session", + "title": "Type", + "type": "string" + }, + "warning_message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Warning Message" + } + }, + "required": [ + "session_id" + ], + "title": "CreateSessionResponse", + "type": "object" + }, + "tinker__types__cursor__Cursor": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "offset": { + "title": "Offset", + "type": "integer" + }, + "total_count": { + "title": "Total Count", + "type": "integer" + } + }, + "required": [ + "offset", + "limit", + "total_count" + ], + "title": "Cursor", + "type": "object" + }, + "tinker__types__get_server_capabilities_response__GetServerCapabilitiesResponse": { + "description": "Response containing the server's supported models and capabilities.", + "properties": { + "supported_models": { + "items": { + "$ref": "#/components/schemas/tinker__types__get_server_capabilities_response__SupportedModel" + }, + "title": "Supported Models", + "type": "array" + } + }, + "required": [ + "supported_models" + ], + "title": "GetServerCapabilitiesResponse", + "type": "object" + }, + "tinker__types__get_server_capabilities_response__SupportedModel": { + "description": "Information about a model supported by the server.", + "properties": { + "model_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Name" + } + }, + "title": "SupportedModel", + "type": "object" + }, + "tinker__types__health_response__HealthResponse": { + "properties": { + "status": { + "const": "ok", + "title": "Status", + "type": "string" + } + }, + "required": [ + "status" + ], + "title": "HealthResponse", + "type": "object" + }, + "tinker__types__session_heartbeat_request__SessionHeartbeatRequest": { + "additionalProperties": false, + "properties": { + "session_id": { + "title": "Session Id", + "type": "string" + }, + "type": { + "const": "session_heartbeat", + "default": "session_heartbeat", + "title": "Type", + "type": "string" + } + }, + "required": [ + "session_id" + ], + "title": "SessionHeartbeatRequest", + "type": "object" + }, + "tinker__types__session_heartbeat_response__SessionHeartbeatResponse": { + "properties": { + "type": { + "const": "session_heartbeat", + "default": "session_heartbeat", + "title": "Type", + "type": "string" + } + }, + "title": "SessionHeartbeatResponse", + "type": "object" + }, + "tinker__types__training_run__TrainingRun": { + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "corrupted": { + "default": false, + "title": "Corrupted", + "type": "boolean" + }, + "is_lora": { + "title": "Is Lora", + "type": "boolean" + }, + "last_checkpoint": { + "anyOf": [ + { + "$ref": "#/components/schemas/tinker__types__checkpoint__Checkpoint" + }, + { + "type": "null" + } + ] + }, + "last_request_time": { + "format": "date-time", + "title": "Last Request Time", + "type": "string" + }, + "last_sampler_checkpoint": { + "anyOf": [ + { + "$ref": "#/components/schemas/tinker__types__checkpoint__Checkpoint" + }, + { + "type": "null" + } + ] + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "model_owner": { + "title": "Model Owner", + "type": "string" + }, + "training_run_id": { + "title": "Training Run Id", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "training_run_id", + "base_model", + "model_owner", + "is_lora", + "last_request_time" + ], + "title": "TrainingRun", + "type": "object" + }, + "tinker__types__training_runs_response__TrainingRunsResponse": { + "properties": { + "cursor": { + "$ref": "#/components/schemas/tinker__types__cursor__Cursor" + }, + "training_runs": { + "items": { + "$ref": "#/components/schemas/tinker__types__training_run__TrainingRun" + }, + "title": "Training Runs", + "type": "array" + } + }, + "required": [ + "training_runs", + "cursor" + ], + "title": "TrainingRunsResponse", + "type": "object" + }, + "tinker__types__weights_info_response__WeightsInfoResponse": { + "description": "Minimal information for loading public checkpoints.", + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "is_lora": { + "title": "Is Lora", + "type": "boolean" + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "train_attn": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Attn" + }, + "train_mlp": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Mlp" + }, + "train_unembed": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Unembed" + } + }, + "required": [ + "base_model", + "is_lora" + ], + "title": "WeightsInfoResponse", + "type": "object" + }, + "twinkle_client__types__server__GetServerCapabilitiesResponse": { + "description": "Response body for the /get_server_capabilities endpoint.", + "properties": { + "supported_models": { + "items": { + "$ref": "#/components/schemas/twinkle_client__types__server__SupportedModel" + }, + "title": "Supported Models", + "type": "array" + } + }, + "required": [ + "supported_models" + ], + "title": "GetServerCapabilitiesResponse", + "type": "object" + }, + "twinkle_client__types__server__HealthResponse": { + "properties": { + "status": { + "title": "Status", + "type": "string" + } + }, + "required": [ + "status" + ], + "title": "HealthResponse", + "type": "object" + }, + "twinkle_client__types__server__SupportedModel": { + "description": "Information about a supported model.", + "properties": { + "model_name": { + "title": "Model Name", + "type": "string" + } + }, + "required": [ + "model_name" + ], + "title": "SupportedModel", + "type": "object" + }, + "twinkle_client__types__session__CreateSessionRequest": { + "description": "Request body for POST /twinkle/create_session.", + "properties": { + "metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Metadata" + } + }, + "title": "CreateSessionRequest", + "type": "object" + }, + "twinkle_client__types__session__CreateSessionResponse": { + "description": "Response body for POST /twinkle/create_session.", + "properties": { + "session_id": { + "title": "Session Id", + "type": "string" + } + }, + "required": [ + "session_id" + ], + "title": "CreateSessionResponse", + "type": "object" + }, + "twinkle_client__types__session__SessionHeartbeatRequest": { + "description": "Request body for POST /twinkle/session_heartbeat.", + "properties": { + "session_id": { + "title": "Session Id", + "type": "string" + } + }, + "required": [ + "session_id" + ], + "title": "SessionHeartbeatRequest", + "type": "object" + }, + "twinkle_client__types__session__SessionHeartbeatResponse": { + "description": "Response body for POST /twinkle/session_heartbeat.", + "properties": {}, + "title": "SessionHeartbeatResponse", + "type": "object" + }, + "twinkle_client__types__training__Checkpoint": { + "description": "Twinkle checkpoint model.", + "properties": { + "base_model": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Base Model" + }, + "checkpoint_id": { + "title": "Checkpoint Id", + "type": "string" + }, + "checkpoint_type": { + "title": "Checkpoint Type", + "type": "string" + }, + "is_lora": { + "default": false, + "title": "Is Lora", + "type": "boolean" + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "public": { + "default": false, + "title": "Public", + "type": "boolean" + }, + "size_bytes": { + "title": "Size Bytes", + "type": "integer" + }, + "time": { + "format": "date-time", + "title": "Time", + "type": "string" + }, + "train_attn": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Attn" + }, + "train_mlp": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Mlp" + }, + "train_unembed": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Unembed" + }, + "twinkle_path": { + "title": "Twinkle Path", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "checkpoint_id", + "checkpoint_type", + "time", + "size_bytes", + "twinkle_path" + ], + "title": "Checkpoint", + "type": "object" + }, + "twinkle_client__types__training__CheckpointsListResponse": { + "properties": { + "checkpoints": { + "items": { + "$ref": "#/components/schemas/twinkle_client__types__training__Checkpoint" + }, + "title": "Checkpoints", + "type": "array" + }, + "cursor": { + "anyOf": [ + { + "$ref": "#/components/schemas/twinkle_client__types__training__Cursor" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "checkpoints" + ], + "title": "CheckpointsListResponse", + "type": "object" + }, + "twinkle_client__types__training__Cursor": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "offset": { + "title": "Offset", + "type": "integer" + }, + "total_count": { + "title": "Total Count", + "type": "integer" + } + }, + "required": [ + "limit", + "offset", + "total_count" + ], + "title": "Cursor", + "type": "object" + }, + "twinkle_client__types__training__TrainingRun": { + "description": "Twinkle training run model.", + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "corrupted": { + "default": false, + "title": "Corrupted", + "type": "boolean" + }, + "is_lora": { + "default": false, + "title": "Is Lora", + "type": "boolean" + }, + "last_checkpoint": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Last Checkpoint" + }, + "last_request_time": { + "anyOf": [ + { + "format": "date-time", + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Last Request Time" + }, + "last_sampler_checkpoint": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Last Sampler Checkpoint" + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "model_owner": { + "title": "Model Owner", + "type": "string" + }, + "save_dir": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Save Dir" + }, + "training_run_id": { + "title": "Training Run Id", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "training_run_id", + "base_model", + "model_owner" + ], + "title": "TrainingRun", + "type": "object" + }, + "twinkle_client__types__training__TrainingRunsResponse": { + "properties": { + "cursor": { + "$ref": "#/components/schemas/twinkle_client__types__training__Cursor" + }, + "training_runs": { + "items": { + "$ref": "#/components/schemas/twinkle_client__types__training__TrainingRun" + }, + "title": "Training Runs", + "type": "array" + } + }, + "required": [ + "training_runs", + "cursor" + ], + "title": "TrainingRunsResponse", + "type": "object" + }, + "twinkle_client__types__training__WeightsInfoResponse": { + "description": "Twinkle weights info response.", + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "is_lora": { + "default": false, + "title": "Is Lora", + "type": "boolean" + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "model_owner": { + "title": "Model Owner", + "type": "string" + }, + "training_run_id": { + "title": "Training Run Id", + "type": "string" + } + }, + "required": [ + "training_run_id", + "base_model", + "model_owner" + ], + "title": "WeightsInfoResponse", + "type": "object" + } + } + }, + "model": { + "paths": { + "/tinker/create_model": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateModelRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/forward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__forward_request__ForwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/forward_backward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardBackwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/get_info": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetInfoRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetInfoResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/load_weights": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/LoadWeightsRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/optim_step": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OptimStepRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/save_weights": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveWeightsRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/save_weights_for_sampler": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveWeightsForSamplerRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/unload_model": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnloadModelRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/add_adapter_to_model": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddAdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddAdapterResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/add_metric": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddMetricRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/apply_patch": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ApplyPatchRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/backward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/calculate_loss": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CalculateLossResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/calculate_metric": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CalculateMetricRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CalculateMetricResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/clip_grad_and_step": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ClipGradAndStepRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/clip_grad_norm": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ClipGradNormResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/create": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/forward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__model__ForwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/forward_backward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__model__ForwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardBackwardResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/forward_only": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardOnlyRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/get_state_dict": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetStateDictRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetStateDictResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/get_train_configs": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetTrainConfigsResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/load": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/LoadRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/lr_step": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/resume_from_checkpoint": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ResumeFromCheckpointRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TrainingProgressResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/save": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_loss": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetLossRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_lr_scheduler": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetLrSchedulerRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_optimizer": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetOptimizerRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_processor": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetProcessorRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_template": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetTemplateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/step": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/upload_status/{request_id}": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "request_id", + "required": true, + "schema": { + "title": "Request Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UploadStatusResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/upload_to_hub": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UploadToHubRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UploadToHubResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/zero_grad": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + } + }, + "schemas": { + "AdamParams": { + "additionalProperties": false, + "properties": { + "beta1": { + "default": 0.9, + "title": "Beta1", + "type": "number" + }, + "beta2": { + "default": 0.95, + "title": "Beta2", + "type": "number" + }, + "eps": { + "default": 1e-12, + "title": "Eps", + "type": "number" + }, + "grad_clip_norm": { + "default": 0.0, + "title": "Grad Clip Norm", + "type": "number" + }, + "learning_rate": { + "default": 0.0001, + "title": "Learning Rate", + "type": "number" + }, + "weight_decay": { + "default": 0.0, + "title": "Weight Decay", + "type": "number" + } + }, + "title": "AdamParams", + "type": "object" + }, + "AdapterRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + } + }, + "required": [ + "adapter_name" + ], + "title": "AdapterRequest", + "type": "object" + }, + "AddAdapterRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "config": { + "title": "Config", + "type": "string" + }, + "save_dir": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Save Dir" + } + }, + "required": [ + "adapter_name", + "config" + ], + "title": "AddAdapterRequest", + "type": "object" + }, + "AddAdapterResponse": { + "description": "Response body for the /add_adapter_to_sampler endpoint.", + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "status": { + "default": "ok", + "title": "Status", + "type": "string" + } + }, + "required": [ + "adapter_name" + ], + "title": "AddAdapterResponse", + "type": "object" + }, + "AddMetricRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "is_training": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Training" + }, + "metric_cls": { + "title": "Metric Cls", + "type": "string" + } + }, + "required": [ + "metric_cls", + "adapter_name" + ], + "title": "AddMetricRequest", + "type": "object" + }, + "ApplyPatchRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "patch_cls": { + "title": "Patch Cls", + "type": "string" + } + }, + "required": [ + "patch_cls", + "adapter_name" + ], + "title": "ApplyPatchRequest", + "type": "object" + }, + "CalculateLossResponse": { + "description": "Response for /calculate_loss endpoint (returns float).", + "properties": { + "result": { + "title": "Result", + "type": "number" + } + }, + "required": [ + "result" + ], + "title": "CalculateLossResponse", + "type": "object" + }, + "CalculateMetricRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "is_training": { + "default": true, + "title": "Is Training", + "type": "boolean" + } + }, + "required": [ + "adapter_name" + ], + "title": "CalculateMetricRequest", + "type": "object" + }, + "CalculateMetricResponse": { + "description": "Response for /calculate_metric endpoint (returns Dict).", + "properties": { + "result": { + "additionalProperties": true, + "title": "Result", + "type": "object" + } + }, + "required": [ + "result" + ], + "title": "CalculateMetricResponse", + "type": "object" + }, + "ClipGradAndStepRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "max_grad_norm": { + "default": 1.0, + "title": "Max Grad Norm", + "type": "number" + }, + "norm_type": { + "default": 2, + "title": "Norm Type", + "type": "integer" + } + }, + "required": [ + "adapter_name" + ], + "title": "ClipGradAndStepRequest", + "type": "object" + }, + "ClipGradNormResponse": { + "description": "Response for /clip_grad_norm endpoint (returns float as str).", + "properties": { + "result": { + "title": "Result", + "type": "string" + } + }, + "required": [ + "result" + ], + "title": "ClipGradNormResponse", + "type": "object" + }, + "CreateModelRequest": { + "additionalProperties": false, + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "lora_config": { + "anyOf": [ + { + "$ref": "#/components/schemas/LoraConfig" + }, + { + "type": "null" + } + ] + }, + "model_seq_id": { + "title": "Model Seq Id", + "type": "integer" + }, + "session_id": { + "title": "Session Id", + "type": "string" + }, + "type": { + "const": "create_model", + "default": "create_model", + "title": "Type", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "session_id", + "model_seq_id", + "base_model" + ], + "title": "CreateModelRequest", + "type": "object" + }, + "CreateRequest": { + "additionalProperties": true, + "properties": {}, + "title": "CreateRequest", + "type": "object" + }, + "CreateResponse": { + "description": "Response for /create endpoint.", + "properties": { + "status": { + "default": "ok", + "title": "Status", + "type": "string" + } + }, + "title": "CreateResponse", + "type": "object" + }, + "Datum": { + "additionalProperties": false, + "properties": { + "loss_fn_inputs": { + "additionalProperties": { + "$ref": "#/components/schemas/TensorData" + }, + "title": "Loss Fn Inputs", + "type": "object" + }, + "model_input": { + "$ref": "#/components/schemas/ModelInput" + } + }, + "required": [ + "loss_fn_inputs", + "model_input" + ], + "title": "Datum", + "type": "object" + }, + "EncodedTextChunk": { + "additionalProperties": false, + "properties": { + "tokens": { + "items": { + "type": "integer" + }, + "title": "Tokens", + "type": "array" + }, + "type": { + "const": "encoded_text", + "default": "encoded_text", + "title": "Type", + "type": "string" + } + }, + "required": [ + "tokens" + ], + "title": "EncodedTextChunk", + "type": "object" + }, + "ForwardBackwardInput": { + "additionalProperties": false, + "properties": { + "data": { + "items": { + "$ref": "#/components/schemas/Datum" + }, + "title": "Data", + "type": "array" + }, + "loss_fn": { + "enum": [ + "cross_entropy", + "importance_sampling", + "ppo", + "cispo", + "dro" + ], + "title": "Loss Fn", + "type": "string" + }, + "loss_fn_config": { + "anyOf": [ + { + "additionalProperties": { + "type": "number" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Loss Fn Config" + } + }, + "required": [ + "data", + "loss_fn" + ], + "title": "ForwardBackwardInput", + "type": "object" + }, + "ForwardBackwardRequest": { + "additionalProperties": false, + "properties": { + "forward_backward_input": { + "$ref": "#/components/schemas/ForwardBackwardInput" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + } + }, + "required": [ + "forward_backward_input", + "model_id" + ], + "title": "ForwardBackwardRequest", + "type": "object" + }, + "ForwardBackwardResponse": { + "description": "Response for /forward_backward endpoint (returns ModelOutput).", + "properties": { + "result": { + "title": "Result" + } + }, + "required": [ + "result" + ], + "title": "ForwardBackwardResponse", + "type": "object" + }, + "ForwardOnlyRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Adapter Name" + }, + "inputs": { + "title": "Inputs" + } + }, + "required": [ + "inputs" + ], + "title": "ForwardOnlyRequest", + "type": "object" + }, + "ForwardResponse": { + "description": "Response for /forward and /forward_only endpoints (returns ModelOutput).", + "properties": { + "result": { + "title": "Result" + } + }, + "required": [ + "result" + ], + "title": "ForwardResponse", + "type": "object" + }, + "GetInfoRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "type": { + "const": "get_info", + "default": "get_info", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "GetInfoRequest", + "type": "object" + }, + "GetInfoResponse": { + "description": "Response containing information about a training client's model.", + "properties": { + "is_lora": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Lora" + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "model_data": { + "$ref": "#/components/schemas/ModelData" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "model_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Name" + }, + "type": { + "anyOf": [ + { + "const": "get_info", + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Type" + } + }, + "required": [ + "model_data", + "model_id" + ], + "title": "GetInfoResponse", + "type": "object" + }, + "GetStateDictRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + } + }, + "required": [ + "adapter_name" + ], + "title": "GetStateDictRequest", + "type": "object" + }, + "GetStateDictResponse": { + "description": "Response for /get_state_dict endpoint (returns Dict).", + "properties": { + "result": { + "additionalProperties": true, + "title": "Result", + "type": "object" + } + }, + "required": [ + "result" + ], + "title": "GetStateDictResponse", + "type": "object" + }, + "GetTrainConfigsResponse": { + "description": "Response for /get_train_configs endpoint (returns str).", + "properties": { + "result": { + "title": "Result", + "type": "string" + } + }, + "required": [ + "result" + ], + "title": "GetTrainConfigsResponse", + "type": "object" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "title": "Detail", + "type": "array" + } + }, + "title": "HTTPValidationError", + "type": "object" + }, + "ImageAssetPointerChunk": { + "additionalProperties": false, + "properties": { + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "location": { + "title": "Location", + "type": "string" + }, + "type": { + "const": "image_asset_pointer", + "default": "image_asset_pointer", + "title": "Type", + "type": "string" + } + }, + "required": [ + "format", + "location" + ], + "title": "ImageAssetPointerChunk", + "type": "object" + }, + "ImageChunk": { + "additionalProperties": false, + "properties": { + "data": { + "contentMediaType": "application/octet-stream", + "title": "Data", + "type": "string" + }, + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "type": { + "const": "image", + "default": "image", + "title": "Type", + "type": "string" + } + }, + "required": [ + "data", + "format" + ], + "title": "ImageChunk", + "type": "object" + }, + "LoadRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "load_optimizer": { + "default": false, + "title": "Load Optimizer", + "type": "boolean" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "adapter_name", + "name" + ], + "title": "LoadRequest", + "type": "object" + }, + "LoadWeightsRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "optimizer": { + "title": "Optimizer", + "type": "boolean" + }, + "path": { + "title": "Path", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "type": { + "const": "load_weights", + "default": "load_weights", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id", + "path", + "optimizer" + ], + "title": "LoadWeightsRequest", + "type": "object" + }, + "LoraConfig": { + "additionalProperties": false, + "properties": { + "rank": { + "title": "Rank", + "type": "integer" + }, + "seed": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seed" + }, + "train_attn": { + "default": true, + "title": "Train Attn", + "type": "boolean" + }, + "train_mlp": { + "default": true, + "title": "Train Mlp", + "type": "boolean" + }, + "train_unembed": { + "default": true, + "title": "Train Unembed", + "type": "boolean" + } + }, + "required": [ + "rank" + ], + "title": "LoraConfig", + "type": "object" + }, + "ModelData": { + "description": "Metadata about a model's architecture and configuration.", + "properties": { + "arch": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Arch" + }, + "model_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Name" + }, + "tokenizer_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Tokenizer Id" + } + }, + "title": "ModelData", + "type": "object" + }, + "ModelInput": { + "additionalProperties": false, + "properties": { + "chunks": { + "items": { + "anyOf": [ + { + "$ref": "#/components/schemas/EncodedTextChunk" + }, + { + "$ref": "#/components/schemas/ImageAssetPointerChunk" + }, + { + "$ref": "#/components/schemas/ImageChunk" + } + ] + }, + "title": "Chunks", + "type": "array" + } + }, + "required": [ + "chunks" + ], + "title": "ModelInput", + "type": "object" + }, + "OptimStepRequest": { + "additionalProperties": false, + "properties": { + "adam_params": { + "$ref": "#/components/schemas/AdamParams" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "type": { + "const": "optim_step", + "default": "optim_step", + "title": "Type", + "type": "string" + } + }, + "required": [ + "adam_params", + "model_id" + ], + "title": "OptimStepRequest", + "type": "object" + }, + "ResumeFromCheckpointRequest": { + "additionalProperties": true, + "description": "Request for /resume_from_checkpoint endpoint.", + "properties": { + "adapter_name": { + "default": "", + "title": "Adapter Name", + "type": "string" + }, + "name": { + "title": "Name", + "type": "string" + }, + "resume_only_model": { + "default": false, + "title": "Resume Only Model", + "type": "boolean" + } + }, + "required": [ + "name" + ], + "title": "ResumeFromCheckpointRequest", + "type": "object" + }, + "SaveRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "is_sampler": { + "default": false, + "title": "Is Sampler", + "type": "boolean" + }, + "name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Name" + }, + "save_optimizer": { + "default": false, + "title": "Save Optimizer", + "type": "boolean" + } + }, + "required": [ + "adapter_name" + ], + "title": "SaveRequest", + "type": "object" + }, + "SaveResponse": { + "description": "Response for /save endpoint (returns twinkle path + checkpoint dir).", + "properties": { + "checkpoint_dir": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Checkpoint Dir" + }, + "twinkle_path": { + "title": "Twinkle Path", + "type": "string" + } + }, + "required": [ + "twinkle_path" + ], + "title": "SaveResponse", + "type": "object" + }, + "SaveWeightsForSamplerRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Path" + }, + "sampling_session_seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Sampling Session Seq Id" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "ttl_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Ttl Seconds" + }, + "type": { + "const": "save_weights_for_sampler", + "default": "save_weights_for_sampler", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "SaveWeightsForSamplerRequest", + "type": "object" + }, + "SaveWeightsRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Path" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "ttl_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Ttl Seconds" + }, + "type": { + "const": "save_weights", + "default": "save_weights", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "SaveWeightsRequest", + "type": "object" + }, + "SetLossRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "loss_cls": { + "title": "Loss Cls", + "type": "string" + } + }, + "required": [ + "loss_cls", + "adapter_name" + ], + "title": "SetLossRequest", + "type": "object" + }, + "SetLrSchedulerRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "scheduler_cls": { + "title": "Scheduler Cls", + "type": "string" + } + }, + "required": [ + "scheduler_cls", + "adapter_name" + ], + "title": "SetLrSchedulerRequest", + "type": "object" + }, + "SetOptimizerRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "optimizer_cls": { + "title": "Optimizer Cls", + "type": "string" + } + }, + "required": [ + "optimizer_cls", + "adapter_name" + ], + "title": "SetOptimizerRequest", + "type": "object" + }, + "SetProcessorRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "processor_cls": { + "title": "Processor Cls", + "type": "string" + } + }, + "required": [ + "processor_cls", + "adapter_name" + ], + "title": "SetProcessorRequest", + "type": "object" + }, + "SetTemplateRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "template_cls": { + "title": "Template Cls", + "type": "string" + } + }, + "required": [ + "template_cls", + "adapter_name" + ], + "title": "SetTemplateRequest", + "type": "object" + }, + "TensorData": { + "additionalProperties": false, + "properties": { + "data": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "items": { + "type": "number" + }, + "type": "array" + } + ], + "title": "Data" + }, + "dtype": { + "enum": [ + "int64", + "float32" + ], + "title": "Dtype", + "type": "string" + }, + "shape": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Shape" + } + }, + "required": [ + "data", + "dtype" + ], + "title": "TensorData", + "type": "object" + }, + "TrainingProgressResponse": { + "description": "Response for /resume_from_checkpoint endpoint.", + "properties": { + "result": { + "additionalProperties": true, + "title": "Result", + "type": "object" + } + }, + "required": [ + "result" + ], + "title": "TrainingProgressResponse", + "type": "object" + }, + "UnloadModelRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "type": { + "const": "unload_model", + "default": "unload_model", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "UnloadModelRequest", + "type": "object" + }, + "UntypedAPIFuture": { + "properties": { + "model_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Id" + }, + "request_id": { + "title": "Request Id", + "type": "string" + } + }, + "required": [ + "request_id" + ], + "title": "UntypedAPIFuture", + "type": "object" + }, + "UploadStatusResponse": { + "description": "Response for /upload_status/{request_id} endpoint.", + "properties": { + "error": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error" + }, + "request_id": { + "title": "Request Id", + "type": "string" + }, + "status": { + "title": "Status", + "type": "string" + } + }, + "required": [ + "request_id", + "status" + ], + "title": "UploadStatusResponse", + "type": "object" + }, + "UploadToHubRequest": { + "additionalProperties": true, + "properties": { + "async_upload": { + "default": false, + "title": "Async Upload", + "type": "boolean" + }, + "checkpoint_dir": { + "anyOf": [ + { + "type": "string" + }, + { + "additionalProperties": true, + "type": "object" + } + ], + "title": "Checkpoint Dir" + }, + "hub_model_id": { + "title": "Hub Model Id", + "type": "string" + }, + "hub_token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Hub Token" + } + }, + "required": [ + "checkpoint_dir", + "hub_model_id" + ], + "title": "UploadToHubRequest", + "type": "object" + }, + "UploadToHubResponse": { + "description": "Response for /upload_to_hub endpoint.", + "properties": { + "request_id": { + "title": "Request Id", + "type": "string" + } + }, + "required": [ + "request_id" + ], + "title": "UploadToHubResponse", + "type": "object" + }, + "ValidationError": { + "properties": { + "ctx": { + "title": "Context", + "type": "object" + }, + "input": { + "title": "Input" + }, + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "title": "Location", + "type": "array" + }, + "msg": { + "title": "Message", + "type": "string" + }, + "type": { + "title": "Error Type", + "type": "string" + } + }, + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError", + "type": "object" + }, + "tinker__types__forward_request__ForwardRequest": { + "additionalProperties": false, + "properties": { + "forward_input": { + "$ref": "#/components/schemas/ForwardBackwardInput" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + } + }, + "required": [ + "forward_input", + "model_id" + ], + "title": "ForwardRequest", + "type": "object" + }, + "twinkle_client__types__model__ForwardRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "inputs": { + "title": "Inputs" + } + }, + "required": [ + "inputs", + "adapter_name" + ], + "title": "ForwardRequest", + "type": "object" + } + } + }, + "processor": { + "paths": { + "/twinkle/call": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProcessorCallRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProcessorCallResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/create": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProcessorCreateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProcessorCreateResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + } + }, + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "title": "Detail", + "type": "array" + } + }, + "title": "HTTPValidationError", + "type": "object" + }, + "ProcessorCallRequest": { + "additionalProperties": true, + "properties": { + "function": { + "title": "Function", + "type": "string" + }, + "processor_id": { + "title": "Processor Id", + "type": "string" + } + }, + "required": [ + "processor_id", + "function" + ], + "title": "ProcessorCallRequest", + "type": "object" + }, + "ProcessorCallResponse": { + "description": "Response body for the /call endpoint.", + "properties": { + "result": { + "title": "Result" + } + }, + "required": [ + "result" + ], + "title": "ProcessorCallResponse", + "type": "object" + }, + "ProcessorCreateRequest": { + "additionalProperties": true, + "properties": { + "class_type": { + "title": "Class Type", + "type": "string" + }, + "processor_type": { + "title": "Processor Type", + "type": "string" + } + }, + "required": [ + "processor_type", + "class_type" + ], + "title": "ProcessorCreateRequest", + "type": "object" + }, + "ProcessorCreateResponse": { + "description": "Response body for the /create endpoint.", + "properties": { + "processor_id": { + "title": "Processor Id", + "type": "string" + } + }, + "required": [ + "processor_id" + ], + "title": "ProcessorCreateResponse", + "type": "object" + }, + "ValidationError": { + "properties": { + "ctx": { + "title": "Context", + "type": "object" + }, + "input": { + "title": "Input" + }, + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "title": "Location", + "type": "array" + }, + "msg": { + "title": "Message", + "type": "string" + }, + "type": { + "title": "Error Type", + "type": "string" + } + }, + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError", + "type": "object" + } + } + }, + "sampler": { + "paths": { + "/tinker/asample": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__sample_request__SampleRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/add_adapter_to_sampler": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddAdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddAdapterResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/apply_patch": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ApplyPatchRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/create": { + "POST": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/twinkle/sample": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__sampler__SampleRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SampleResponseModelList" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_template": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetTemplateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetTemplateResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + } + }, + "schemas": { + "AddAdapterRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "config": { + "title": "Config", + "type": "string" + }, + "save_dir": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Save Dir" + } + }, + "required": [ + "adapter_name", + "config" + ], + "title": "AddAdapterRequest", + "type": "object" + }, + "AddAdapterResponse": { + "description": "Response body for the /add_adapter_to_sampler endpoint.", + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "status": { + "default": "ok", + "title": "Status", + "type": "string" + } + }, + "required": [ + "adapter_name" + ], + "title": "AddAdapterResponse", + "type": "object" + }, + "ApplyPatchRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "patch_cls": { + "title": "Patch Cls", + "type": "string" + } + }, + "required": [ + "patch_cls", + "adapter_name" + ], + "title": "ApplyPatchRequest", + "type": "object" + }, + "CreateResponse": { + "description": "Response for /create endpoint.", + "properties": { + "status": { + "default": "ok", + "title": "Status", + "type": "string" + } + }, + "title": "CreateResponse", + "type": "object" + }, + "EncodedTextChunk": { + "additionalProperties": false, + "properties": { + "tokens": { + "items": { + "type": "integer" + }, + "title": "Tokens", + "type": "array" + }, + "type": { + "const": "encoded_text", + "default": "encoded_text", + "title": "Type", + "type": "string" + } + }, + "required": [ + "tokens" + ], + "title": "EncodedTextChunk", + "type": "object" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "title": "Detail", + "type": "array" + } + }, + "title": "HTTPValidationError", + "type": "object" + }, + "ImageAssetPointerChunk": { + "additionalProperties": false, + "properties": { + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "location": { + "title": "Location", + "type": "string" + }, + "type": { + "const": "image_asset_pointer", + "default": "image_asset_pointer", + "title": "Type", + "type": "string" + } + }, + "required": [ + "format", + "location" + ], + "title": "ImageAssetPointerChunk", + "type": "object" + }, + "ImageChunk": { + "additionalProperties": false, + "properties": { + "data": { + "contentMediaType": "application/octet-stream", + "title": "Data", + "type": "string" + }, + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "type": { + "const": "image", + "default": "image", + "title": "Type", + "type": "string" + } + }, + "required": [ + "data", + "format" + ], + "title": "ImageChunk", + "type": "object" + }, + "ModelInput": { + "additionalProperties": false, + "properties": { + "chunks": { + "items": { + "anyOf": [ + { + "$ref": "#/components/schemas/EncodedTextChunk" + }, + { + "$ref": "#/components/schemas/ImageAssetPointerChunk" + }, + { + "$ref": "#/components/schemas/ImageChunk" + } + ] + }, + "title": "Chunks", + "type": "array" + } + }, + "required": [ + "chunks" + ], + "title": "ModelInput", + "type": "object" + }, + "SampleResponseModel": { + "description": "Mirroring twinkle.data_format.SampleResponse.", + "properties": { + "prompt_logprobs": { + "anyOf": [ + { + "items": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ] + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Prompt Logprobs" + }, + "sequences": { + "description": "List of sampled sequences", + "items": { + "$ref": "#/components/schemas/SampledSequenceModel" + }, + "title": "Sequences", + "type": "array" + }, + "topk_prompt_logprobs": { + "anyOf": [ + { + "items": { + "anyOf": [ + { + "items": { + "maxItems": 2, + "minItems": 2, + "prefixItems": [ + { + "type": "integer" + }, + { + "type": "number" + } + ], + "type": "array" + }, + "type": "array" + }, + { + "type": "null" + } + ] + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Topk Prompt Logprobs" + } + }, + "required": [ + "sequences" + ], + "title": "SampleResponseModel", + "type": "object" + }, + "SampleResponseModelList": { + "description": "Response body for the /sample endpoint", + "properties": { + "samples": { + "description": "List of sample responses", + "items": { + "$ref": "#/components/schemas/SampleResponseModel" + }, + "title": "Samples", + "type": "array" + } + }, + "required": [ + "samples" + ], + "title": "SampleResponseModelList", + "type": "object" + }, + "SampledSequenceModel": { + "description": "A single sampled sequence, mirroring twinkle.data_format.SampledSequence.", + "properties": { + "decoded": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "description": "Decoded text of the sampled sequence", + "title": "Decoded" + }, + "logprobs": { + "anyOf": [ + { + "items": { + "anyOf": [ + { + "items": { + "maxItems": 2, + "minItems": 2, + "prefixItems": [ + { + "type": "integer" + }, + { + "type": "number" + } + ], + "type": "array" + }, + "type": "array" + }, + { + "type": "null" + } + ] + }, + "type": "array" + }, + { + "type": "null" + } + ], + "description": "Per-token log-probabilities", + "title": "Logprobs" + }, + "new_input_feature": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "description": "Updated InputFeature after sampling (input_ids, labels, etc.)", + "title": "New Input Feature" + }, + "stop_reason": { + "description": "Stop reason: 'length' or 'stop'", + "enum": [ + "length", + "stop" + ], + "title": "Stop Reason", + "type": "string" + }, + "tokens": { + "description": "Token IDs of the sampled sequence", + "items": { + "type": "integer" + }, + "title": "Tokens", + "type": "array" + } + }, + "required": [ + "stop_reason", + "tokens" + ], + "title": "SampledSequenceModel", + "type": "object" + }, + "SamplingParams": { + "properties": { + "max_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Max Tokens" + }, + "seed": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seed" + }, + "stop": { + "anyOf": [ + { + "type": "string" + }, + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Stop" + }, + "temperature": { + "default": 1, + "title": "Temperature", + "type": "number" + }, + "top_k": { + "default": -1, + "title": "Top K", + "type": "integer" + }, + "top_p": { + "default": 1, + "title": "Top P", + "type": "number" + } + }, + "title": "SamplingParams", + "type": "object" + }, + "SetTemplateRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "template_cls": { + "title": "Template Cls", + "type": "string" + } + }, + "required": [ + "template_cls", + "adapter_name" + ], + "title": "SetTemplateRequest", + "type": "object" + }, + "SetTemplateResponse": { + "description": "Response for /set_template endpoint.", + "properties": { + "status": { + "default": "ok", + "title": "Status", + "type": "string" + } + }, + "title": "SetTemplateResponse", + "type": "object" + }, + "UntypedAPIFuture": { + "properties": { + "model_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Id" + }, + "request_id": { + "title": "Request Id", + "type": "string" + } + }, + "required": [ + "request_id" + ], + "title": "UntypedAPIFuture", + "type": "object" + }, + "ValidationError": { + "properties": { + "ctx": { + "title": "Context", + "type": "object" + }, + "input": { + "title": "Input" + }, + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "title": "Location", + "type": "array" + }, + "msg": { + "title": "Message", + "type": "string" + }, + "type": { + "title": "Error Type", + "type": "string" + } + }, + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError", + "type": "object" + }, + "tinker__types__sample_request__SampleRequest": { + "additionalProperties": false, + "properties": { + "base_model": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Base Model" + }, + "model_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Path" + }, + "num_samples": { + "default": 1, + "title": "Num Samples", + "type": "integer" + }, + "prompt": { + "$ref": "#/components/schemas/ModelInput" + }, + "prompt_logprobs": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Prompt Logprobs" + }, + "sampling_params": { + "$ref": "#/components/schemas/SamplingParams" + }, + "sampling_session_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Sampling Session Id" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "topk_prompt_logprobs": { + "default": 0, + "title": "Topk Prompt Logprobs", + "type": "integer" + }, + "type": { + "const": "sample", + "default": "sample", + "title": "Type", + "type": "string" + } + }, + "required": [ + "prompt", + "sampling_params" + ], + "title": "SampleRequest", + "type": "object" + }, + "twinkle_client__types__sampler__SampleRequest": { + "description": "Request body for the /sample endpoint.", + "properties": { + "adapter_name": { + "default": "", + "description": "Adapter name for LoRA inference", + "title": "Adapter Name", + "type": "string" + }, + "adapter_uri": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "description": "Adapter URI (twinkle:// path or local path) for LoRA inference", + "title": "Adapter Uri" + }, + "inputs": { + "description": "List of Trajectory or InputFeature dicts", + "title": "Inputs" + }, + "sampling_params": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "description": "Sampling parameters (max_tokens, temperature, num_samples, etc.)", + "title": "Sampling Params" + } + }, + "required": [ + "inputs" + ], + "title": "SampleRequest", + "type": "object" + } + } + } +} diff --git a/tests/contract/client_api_harness.py b/tests/contract/client_api_harness.py new file mode 100644 index 000000000..c8d30a385 --- /dev/null +++ b/tests/contract/client_api_harness.py @@ -0,0 +1,152 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Client-API contract harness. + +Builds the four FastAPI apps used by the Ray Serve deployments (Gateway, Model, +Sampler, Processor) by registering their route-registration helpers against a +fresh FastAPI instance, then extracts the client-facing surface (route paths, +HTTP methods, and request/response schemas) as a stable JSON dict. + +Used to: +- snapshot the current surface into ``client_api_baseline.json`` before the + refactor begins, and +- assert post-refactor equality after each phase (cross-cutting freeze guard + for R20 / R18.1). + +Notes: +- The handler factories accept ``(app, self_fn)``; we pass a no-op ``self_fn`` + because route registration only inspects the app object — the closures are + never invoked here. +- We restrict the surface to the client-facing endpoints. The Tinker-public + surface on the Gateway is at ``/*`` (flat, by design), and the Twinkle + surface is at ``/twinkle/*`` everywhere. Internal ``/tinker/*`` routes + registered on Model and Sampler are also captured because the Gateway proxy + forwards Tinker compute requests to them — their request/response schemas + are part of the externally observed Tinker contract. +""" +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Callable + +from fastapi import FastAPI +from fastapi.openapi.utils import get_openapi + + +# ----- App build helpers --------------------------------------------------- # + + +def _noop_self() -> None: + return None + + +def build_gateway_app() -> FastAPI: + from twinkle.server.gateway.tinker_gateway_handlers import _register_tinker_routes + from twinkle.server.gateway.twinkle_gateway_handlers import _register_twinkle_routes + + app = FastAPI() + _register_tinker_routes(app, _noop_self) + _register_twinkle_routes(app, _noop_self) + return app + + +def build_model_app() -> FastAPI: + from twinkle.server.model.tinker_handlers import _register_tinker_routes + from twinkle.server.model.twinkle_handlers import _register_twinkle_routes + + app = FastAPI() + _register_tinker_routes(app, _noop_self) + _register_twinkle_routes(app, _noop_self) + return app + + +def build_sampler_app() -> FastAPI: + from twinkle.server.sampler.tinker_handlers import _register_tinker_sampler_routes + from twinkle.server.sampler.twinkle_handlers import _register_twinkle_sampler_routes + + app = FastAPI() + _register_tinker_sampler_routes(app, _noop_self) + _register_twinkle_sampler_routes(app, _noop_self) + return app + + +def build_processor_app() -> FastAPI: + from twinkle.server.processor.twinkle_handlers import _register_processor_routes + + app = FastAPI() + _register_processor_routes(app, _noop_self) + return app + + +APP_BUILDERS: dict[str, Callable[[], FastAPI]] = { + 'gateway': build_gateway_app, + 'model': build_model_app, + 'sampler': build_sampler_app, + 'processor': build_processor_app, +} + + +# ----- Surface extraction -------------------------------------------------- # + + +_HTTP_METHODS = {'GET', 'POST', 'PUT', 'PATCH', 'DELETE'} + + +def _extract_app_surface(app: FastAPI) -> dict[str, Any]: + """Return the OpenAPI ``paths`` and ``components.schemas`` of ``app``. + + The output is a stable, JSON-serializable view restricted to standard HTTP + methods. Per-operation metadata is reduced to fields that affect the + client contract: ``requestBody``, ``responses``, and ``parameters``. + """ + spec = get_openapi( + title='contract', + version='0.0.0', + routes=app.routes, + ) + + paths: dict[str, dict[str, Any]] = {} + for path, ops in (spec.get('paths') or {}).items(): + clean_ops: dict[str, Any] = {} + for method, op in ops.items(): + if method.upper() not in _HTTP_METHODS: + continue + clean_ops[method.upper()] = { + 'parameters': op.get('parameters', []), + 'requestBody': op.get('requestBody'), + 'responses': op.get('responses', {}), + } + if clean_ops: + paths[path] = clean_ops + + components = (spec.get('components') or {}).get('schemas', {}) + return {'paths': paths, 'schemas': components} + + +def extract_full_surface() -> dict[str, Any]: + """Build all four apps and return a per-app contract surface dict.""" + surface: dict[str, Any] = {} + for name, builder in APP_BUILDERS.items(): + app = builder() + surface[name] = _extract_app_surface(app) + return surface + + +# ----- Baseline I/O -------------------------------------------------------- # + + +BASELINE_PATH = Path(__file__).parent / 'client_api_baseline.json' + + +def write_baseline(path: Path | None = None) -> Path: + """Snapshot the current client-API surface to ``client_api_baseline.json``.""" + p = Path(path) if path is not None else BASELINE_PATH + surface = extract_full_surface() + p.write_text(json.dumps(surface, indent=2, sort_keys=True) + '\n') + return p + + +def load_baseline(path: Path | None = None) -> dict[str, Any]: + p = Path(path) if path is not None else BASELINE_PATH + return json.loads(p.read_text()) diff --git a/tests/contract/test_client_api_contract.py b/tests/contract/test_client_api_contract.py new file mode 100644 index 000000000..ebeca298d --- /dev/null +++ b/tests/contract/test_client_api_contract.py @@ -0,0 +1,97 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Client-facing API contract regression test (R20.3, R20.4, R18.1). + +# Feature: server-config-observability-refactor, Property 28: Client-facing API contract invariance + +The refactor freezes the client-facing HTTP surface: the route paths, HTTP +methods, and request/response schemas of the Tinker (`/*`, `/tinker/*`) and +Twinkle (`/twinkle/*`) endpoints MUST be identical before and after each +phase. This test rebuilds the FastAPI apps from the current sources, extracts +their OpenAPI surface, and asserts equality with the committed baseline at +``tests/contract/client_api_baseline.json``. + +Updating the baseline is intentionally a manual step: run +``python -m tests.contract.update_baseline`` (or call +``client_api_harness.write_baseline()``) only when an API change has been +explicitly approved. +""" +from __future__ import annotations + +import json + +import pytest + +from tests.contract.client_api_harness import ( + APP_BUILDERS, + BASELINE_PATH, + extract_full_surface, + load_baseline, +) + + +def test_baseline_file_exists() -> None: + assert BASELINE_PATH.exists(), ( + f'Baseline {BASELINE_PATH} missing. Generate it with ' + '`python -m tests.contract.update_baseline` after confirming the ' + 'current client-facing surface is correct.' + ) + + +@pytest.mark.parametrize('app_name', sorted(APP_BUILDERS.keys())) +def test_app_surface_matches_baseline(app_name: str) -> None: + """Per-app surface equals the snapshot — narrows failure scope per app.""" + baseline = load_baseline() + current = extract_full_surface() + + expected = baseline.get(app_name) + actual = current.get(app_name) + assert expected is not None, f'baseline is missing app {app_name!r}' + assert actual is not None, f'current surface is missing app {app_name!r}' + + if actual != expected: + diff = _surface_diff(expected, actual) + pytest.fail( + f'Client-API surface for {app_name!r} drifted from the baseline.\n' + f'{diff}\n' + f'If the change is intentional, regenerate the baseline with ' + f'`python -m tests.contract.update_baseline`.' + ) + + +def test_full_surface_matches_baseline() -> None: + """Whole-surface equality — the cross-cutting freeze guard.""" + baseline = load_baseline() + current = extract_full_surface() + assert current == baseline, ( + 'Full client-API surface drifted from the baseline. ' + 'See per-app failures for details.' + ) + + +def _surface_diff(expected: dict, actual: dict) -> str: + exp_paths = set((expected.get('paths') or {}).keys()) + act_paths = set((actual.get('paths') or {}).keys()) + added = sorted(act_paths - exp_paths) + removed = sorted(exp_paths - act_paths) + changed = [] + for p in sorted(exp_paths & act_paths): + if expected['paths'][p] != actual['paths'][p]: + exp_methods = set(expected['paths'][p].keys()) + act_methods = set(actual['paths'][p].keys()) + method_added = sorted(act_methods - exp_methods) + method_removed = sorted(exp_methods - act_methods) + method_diff = '' + if method_added or method_removed: + method_diff = f' methods +{method_added} -{method_removed}' + changed.append(f' {p}{method_diff}') + parts = [] + if added: + parts.append(f' added paths: {added}') + if removed: + parts.append(f' removed paths: {removed}') + if changed: + parts.append(' changed paths:\n' + '\n'.join(changed)) + return '\n'.join(parts) if parts else json.dumps( + {'expected_keys': sorted(expected.keys()), 'actual_keys': sorted(actual.keys())}, + indent=2, + ) diff --git a/tests/contract/update_baseline.py b/tests/contract/update_baseline.py new file mode 100644 index 000000000..685938a7b --- /dev/null +++ b/tests/contract/update_baseline.py @@ -0,0 +1,22 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Regenerate the client-API contract baseline. + +Run with:: + + python -m tests.contract.update_baseline + +Only invoke after confirming that the current client-facing surface has been +intentionally changed and approved as part of this refactor. +""" +from __future__ import annotations + +from tests.contract.client_api_harness import write_baseline + + +def main() -> None: + path = write_baseline() + print(f'Wrote baseline: {path}') + + +if __name__ == '__main__': + main() From 5c33080778aa0dddcc8c6788ef668e8304e914a8 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Mon, 1 Jun 2026 23:58:41 +0800 Subject: [PATCH 11/34] refactor(server): convert TaskQueueConfig to Pydantic with field constraints (Phase 0b) Replaces the dataclass with a Pydantic BaseModel so invalid rate-limit and timeout values are rejected at construction instead of leaking into the running deployment. Constrains rps_limit/tps_limit/queue_timeout/ token_cleanup_interval to >= 0, window_seconds to > 0, and max_input_tokens to int >= 1, matching R9.2-9.5/9.7. Sets extra='forbid' so unknown YAML keys surface immediately. The from_dict(config_dict=None) factory is preserved for the existing call sites in model/sampler/processor apps and now delegates to model_validate({}) when no input is given. Adds property tests (Hypothesis, max_examples=100) for constraint enforcement, from_dict equivalence with model_validate, and the documented defaulting behaviour. The Phase 0a client-API contract baseline is re-run green as the cross-cutting freeze guard. --- src/twinkle/server/utils/task_queue/config.py | 84 +++----- tests/server/utils/__init__.py | 0 tests/server/utils/task_queue/__init__.py | 0 tests/server/utils/task_queue/test_config.py | 179 ++++++++++++++++++ 4 files changed, 207 insertions(+), 56 deletions(-) create mode 100644 tests/server/utils/__init__.py create mode 100644 tests/server/utils/task_queue/__init__.py create mode 100644 tests/server/utils/task_queue/test_config.py diff --git a/src/twinkle/server/utils/task_queue/config.py b/src/twinkle/server/utils/task_queue/config.py index a8b6437b6..1bd26971a 100644 --- a/src/twinkle/server/utils/task_queue/config.py +++ b/src/twinkle/server/utils/task_queue/config.py @@ -2,78 +2,50 @@ """ Task queue configuration. -Provides TaskQueueConfig for controlling rate limits, timeouts, -and queue behavior. +Provides TaskQueueConfig (Pydantic) for controlling rate limits, timeouts, +and queue behavior. Constraints are validated at construction time so an +invalid YAML/dict value is rejected before the deployment reaches a ready +state. """ from __future__ import annotations -from dataclasses import dataclass from typing import Any +from pydantic import BaseModel, ConfigDict, Field -@dataclass -class TaskQueueConfig: + +class TaskQueueConfig(BaseModel): """Configuration for task queue and rate limiting. Attributes: - rps_limit: Maximum requests per second per user token. - tps_limit: Maximum input tokens per second per user token. - window_seconds: Time window for rate limiting calculations. + rps_limit: Maximum requests per second per user token. ``0`` disables. + tps_limit: Maximum input tokens per second per user token. ``0`` disables. + window_seconds: Sliding window for rate-limit calculations. Must be > 0. queue_timeout: Maximum time a task can wait in queue (seconds). execution_timeout: Maximum time a task can execute (seconds). 0 means no limit. enabled: Whether rate limiting is enabled. token_cleanup_multiplier: Multiplier for token cleanup threshold. token_cleanup_interval: How often to run cleanup task (seconds). - max_input_tokens: Maximum allowed input tokens per request (default 16000). + max_input_tokens: Maximum allowed input tokens per request. """ - rps_limit: float = 100.0 # 100 requests per second - tps_limit: float = 16000.0 # 16000 input tokens per second - window_seconds: float = 1.0 # 1 second sliding window - queue_timeout: float = 300.0 # 5 minutes queue timeout - execution_timeout: float = 120.0 # 120 seconds execution timeout (0 to disable) - enabled: bool = True # Rate limiting enabled by default - # Remove tokens after 10x window inactivity - token_cleanup_multiplier: float = 10.0 - token_cleanup_interval: float = 60.0 # Run cleanup every 60 seconds - max_input_tokens: int = 16000 # Maximum input tokens per request - @classmethod - def from_dict(cls, config_dict: dict[str, Any] | None = None) -> TaskQueueConfig: - """Create TaskQueueConfig from a dictionary. + model_config = ConfigDict(extra='forbid') - Args: - config_dict: Dictionary with configuration values. Supports keys: - - rps_limit: requests per second limit - - tps_limit: input tokens per second limit - - window_seconds: sliding window duration - - queue_timeout: queue timeout in seconds - - execution_timeout: task execution timeout in seconds (0 to disable) - - enabled: whether rate limiting is enabled - - token_cleanup_multiplier: multiplier for token cleanup threshold - - token_cleanup_interval: cleanup task interval in seconds - - max_input_tokens: maximum input tokens per request + rps_limit: float = Field(default=100.0, ge=0) + tps_limit: float = Field(default=16000.0, ge=0) + window_seconds: float = Field(default=1.0, gt=0) + queue_timeout: float = Field(default=300.0, ge=0) + execution_timeout: float = Field(default=120.0, ge=0) + enabled: bool = True + token_cleanup_multiplier: float = Field(default=10.0, ge=0) + token_cleanup_interval: float = Field(default=60.0, ge=0) + max_input_tokens: int = Field(default=16000, ge=1) + + @classmethod + def from_dict(cls, config_dict: dict[str, Any] | None = None) -> 'TaskQueueConfig': + """Validate ``config_dict`` (or ``{}``) into a ``TaskQueueConfig``. - Returns: - TaskQueueConfig instance with values from dict merged with defaults. + Equivalent to ``cls.model_validate(config_dict or {})`` — kept for + call-site compatibility and to make the entry point explicit. """ - config = cls() - if config_dict: - if 'rps_limit' in config_dict: - config.rps_limit = float(config_dict['rps_limit']) - if 'tps_limit' in config_dict: - config.tps_limit = float(config_dict['tps_limit']) - if 'window_seconds' in config_dict: - config.window_seconds = float(config_dict['window_seconds']) - if 'queue_timeout' in config_dict: - config.queue_timeout = float(config_dict['queue_timeout']) - if 'execution_timeout' in config_dict: - config.execution_timeout = float(config_dict['execution_timeout']) - if 'enabled' in config_dict: - config.enabled = bool(config_dict['enabled']) - if 'token_cleanup_multiplier' in config_dict: - config.token_cleanup_multiplier = float(config_dict['token_cleanup_multiplier']) - if 'token_cleanup_interval' in config_dict: - config.token_cleanup_interval = float(config_dict['token_cleanup_interval']) - if 'max_input_tokens' in config_dict: - config.max_input_tokens = int(config_dict['max_input_tokens']) - return config + return cls.model_validate(config_dict or {}) diff --git a/tests/server/utils/__init__.py b/tests/server/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/utils/task_queue/__init__.py b/tests/server/utils/task_queue/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/utils/task_queue/test_config.py b/tests/server/utils/task_queue/test_config.py new file mode 100644 index 000000000..c76096a1a --- /dev/null +++ b/tests/server/utils/task_queue/test_config.py @@ -0,0 +1,179 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Property + unit tests for the Pydantic ``TaskQueueConfig`` (R9). + +Covers: +- # Feature: server-config-observability-refactor, Property 16: TaskQueueConfig constraint enforcement +- # Feature: server-config-observability-refactor, Property 17: from_dict equivalence +- # Feature: server-config-observability-refactor, Property 18: TaskQueueConfig defaulting +- Unit checks that the call sites in model/sampler/processor apps still construct the config + through ``TaskQueueConfig.from_dict``. +""" +from __future__ import annotations + +import pytest +from hypothesis import given, settings, strategies as st +from pydantic import ValidationError + +from twinkle.server.utils.task_queue.config import TaskQueueConfig + + +# ---------- defaults snapshot used by Property 18 (R9.8) ------------------- # + +DEFAULTS = { + 'rps_limit': 100.0, + 'tps_limit': 16000.0, + 'window_seconds': 1.0, + 'queue_timeout': 300.0, + 'token_cleanup_interval': 60.0, + 'max_input_tokens': 16000, +} + + +# ---------- Property 16: constraint enforcement (R9.2-9.5, 9.7) ------------ # + + +_CONSTRAINED_GE0_FLOATS = ['rps_limit', 'tps_limit', 'queue_timeout', 'token_cleanup_interval'] + + +@settings(max_examples=100) +@given( + field=st.sampled_from(_CONSTRAINED_GE0_FLOATS), + bad_value=st.floats(max_value=-1e-6, min_value=-1e6, allow_nan=False, allow_infinity=False), +) +def test_property_16_ge0_floats_reject_negative(field: str, bad_value: float) -> None: + """Non-negative float fields reject any negative input.""" + with pytest.raises(ValidationError) as exc: + TaskQueueConfig(**{field: bad_value}) + assert any(field in err['loc'] for err in exc.value.errors()) + + +@settings(max_examples=100) +@given(bad_value=st.floats(max_value=0.0, min_value=-1e6, allow_nan=False, allow_infinity=False)) +def test_property_16_window_seconds_rejects_zero_and_negative(bad_value: float) -> None: + """``window_seconds`` must be strictly > 0.""" + with pytest.raises(ValidationError) as exc: + TaskQueueConfig(window_seconds=bad_value) + assert any('window_seconds' in err['loc'] for err in exc.value.errors()) + + +@settings(max_examples=100) +@given(bad_value=st.integers(max_value=0, min_value=-1_000_000)) +def test_property_16_max_input_tokens_rejects_lt_1(bad_value: int) -> None: + """``max_input_tokens`` must be an integer ≥ 1.""" + with pytest.raises(ValidationError) as exc: + TaskQueueConfig(max_input_tokens=bad_value) + assert any('max_input_tokens' in err['loc'] for err in exc.value.errors()) + + +@settings(max_examples=100) +@given( + rps=st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + tps=st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + win=st.floats(min_value=1e-6, max_value=1e6, allow_nan=False, allow_infinity=False), + qt=st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + cleanup=st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + mit=st.integers(min_value=1, max_value=10_000_000), +) +def test_property_16_valid_values_accepted( + rps: float, tps: float, win: float, qt: float, cleanup: float, mit: int +) -> None: + """Any value satisfying the constraints constructs successfully.""" + cfg = TaskQueueConfig( + rps_limit=rps, + tps_limit=tps, + window_seconds=win, + queue_timeout=qt, + token_cleanup_interval=cleanup, + max_input_tokens=mit, + ) + assert cfg.rps_limit == rps + assert cfg.window_seconds == win + assert cfg.max_input_tokens == mit + + +# ---------- Property 17: from_dict equivalence (R9.6) ---------------------- # + + +_INPUT_DICT_STRATEGY = st.fixed_dictionaries( + {}, + optional={ + 'rps_limit': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'tps_limit': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'window_seconds': st.floats(min_value=1e-6, max_value=1e6, allow_nan=False, allow_infinity=False), + 'queue_timeout': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'token_cleanup_interval': st.floats( + min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'max_input_tokens': st.integers(min_value=1, max_value=10_000_000), + 'enabled': st.booleans(), + 'execution_timeout': st.floats( + min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'token_cleanup_multiplier': st.floats( + min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + }, +) + + +@settings(max_examples=100) +@given(payload=_INPUT_DICT_STRATEGY) +def test_property_17_from_dict_equivalence(payload: dict) -> None: + """``from_dict`` returns the same instance as ``model_validate``.""" + via_factory = TaskQueueConfig.from_dict(payload) + via_validate = TaskQueueConfig.model_validate(payload) + assert via_factory.model_dump() == via_validate.model_dump() + + +# ---------- Property 18: defaulting (R9.8) --------------------------------- # + + +def test_property_18_from_dict_with_no_argument() -> None: + """``from_dict()`` with no argument returns the documented defaults.""" + cfg = TaskQueueConfig.from_dict() + for field, value in DEFAULTS.items(): + assert getattr(cfg, field) == value, field + + +def test_property_18_from_dict_with_none() -> None: + cfg = TaskQueueConfig.from_dict(None) + for field, value in DEFAULTS.items(): + assert getattr(cfg, field) == value, field + + +def test_property_18_from_dict_with_empty_dict() -> None: + cfg = TaskQueueConfig.from_dict({}) + for field, value in DEFAULTS.items(): + assert getattr(cfg, field) == value, field + + +@settings(max_examples=100) +@given(present=st.sets(st.sampled_from(sorted(DEFAULTS.keys())), max_size=len(DEFAULTS))) +def test_property_18_omitted_fields_take_defaults(present: set) -> None: + """Fields absent from the dict adopt their documented defaults.""" + payload = {f: DEFAULTS[f] for f in present} + cfg = TaskQueueConfig.from_dict(payload) + for field, value in DEFAULTS.items(): + assert getattr(cfg, field) == value, field + + +# ---------- Unit: extra=forbid + call-site usage --------------------------- # + + +def test_extra_field_rejected() -> None: + """``extra='forbid'`` rejects unknown keys (defends R8.2 scoped to this model).""" + with pytest.raises(ValidationError): + TaskQueueConfig.from_dict({'unknown_field': 1}) + + +def test_call_site_imports_resolve() -> None: + """``TaskQueueConfig.from_dict`` is what the apps call — keep that import alive. + + We don't instantiate the deployments (those need Ray Serve runtime); we just + confirm the call sites import the same name and the factory still produces + valid configs for the dicts they pass. + """ + from twinkle.server.utils.task_queue import TaskQueueConfig as Exported + + assert Exported is TaskQueueConfig + # Mimic the queue_config dicts shipped in cookbook/client/server YAMLs. + cfg = TaskQueueConfig.from_dict({'rps_limit': 100, 'tps_limit': 100000}) + assert cfg.rps_limit == 100.0 + assert cfg.tps_limit == 100000.0 From 2a074bf06be622d159bb88bcd9776bd65c84b250 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 00:08:00 +0800 Subject: [PATCH 12/34] feat(server): introduce typed ServerConfig aggregate root (Phase 0c) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a single Pydantic aggregate root that drives the launcher: ServerConfig nests TelemetryConfig, PersistenceConfig, TaskQueueConfig, and a list of typed ApplicationSpec entries. Each per-deployment args block has its own schema (ModelArgs/SamplerArgs/ServerArgs/ProcessorArgs) with extra='forbid', so unknown keys and out-of-range values surface at load time with the offending field path. backend (model) and sampler_type (sampler) are introduced as Literal-validated selectors, replacing the legacy use_megatron boolean — Phase 1 will wire the actual dispatch on these values. ServerConfig.from_yaml is the single load entry point: FileNotFoundError on a missing path, ConfigParseError on malformed YAML, ValidationError on field or cross-field violations. The cross-field validator rejects redis mode without redis_url and file mode without file_path. ServerLauncher now requires a typed ServerConfig and rejects raw dicts; from_yaml became a thin wrapper. Legacy field names telemetry_config/persistence_config are rejected per the breaking-change clause in R8. Migrates the cookbook example configs (transformer + megatron) to the new field names and adds property tests covering valid/invalid loads, round-trip fidelity, and legacy-name rejection. The Phase 0a client-API contract baseline is re-run green as the cross-cutting freeze guard. Adds ConfigError (field/value/allowed) and ConfigParseError to server.exceptions for callers that want a single non-pydantic exception type to catch. --- .../client/server/megatron/server_config.yaml | 6 +- .../server/megatron/server_config_4b.yaml | 2 +- .../server/transformer/server_config.yaml | 6 +- src/twinkle/server/config/__init__.py | 22 ++ src/twinkle/server/config/application_spec.py | 147 +++++++++++ src/twinkle/server/config/server_config.py | 92 +++++++ src/twinkle/server/exceptions.py | 35 +++ src/twinkle/server/launcher.py | 156 +++++------- src/twinkle/server/model/app.py | 11 +- tests/server/config/__init__.py | 0 tests/server/config/test_server_config.py | 232 ++++++++++++++++++ 11 files changed, 611 insertions(+), 98 deletions(-) create mode 100644 src/twinkle/server/config/__init__.py create mode 100644 src/twinkle/server/config/application_spec.py create mode 100644 src/twinkle/server/config/server_config.py create mode 100644 tests/server/config/__init__.py create mode 100644 tests/server/config/test_server_config.py diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index 6c9337402..ae6e9236c 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -16,7 +16,7 @@ http_options: # 2. set enabled: true (and debug: true to dump exporters to console for local dev, # or leave debug: false and point otlp_endpoint at an OTLP collector — see # cookbook/observability/ for a docker-compose example). -# telemetry_config: +# telemetry: # enabled: false # debug: false # service_name: twinkle-server @@ -29,7 +29,7 @@ http_options: # mode: memory | file | redis # file_path: required for `file` mode # redis_url / key_prefix: required for `redis` mode -# persistence_config: +# persistence: # mode: file # file_path: /tmp/twinkle_state.json @@ -108,7 +108,7 @@ applications: route_prefix: /api/v1/model/Qwen/Qwen3.6-27B import_path: model args: - use_megatron: true # Use Megatron-LM backend + backend: megatron # Use Megatron-LM backend model_id: "ms://Qwen/Qwen3.6-27B" # ModelScope model identifier max_length: 65536 # model max length max_loras: 3 # model max loras diff --git a/cookbook/client/server/megatron/server_config_4b.yaml b/cookbook/client/server/megatron/server_config_4b.yaml index 36adb3328..7eed4699d 100644 --- a/cookbook/client/server/megatron/server_config_4b.yaml +++ b/cookbook/client/server/megatron/server_config_4b.yaml @@ -38,7 +38,7 @@ applications: route_prefix: /api/v1/model/Qwen/Qwen3.5-4B import_path: model args: - use_megatron: true + backend: megatron model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier max_length: 10240 nproc_per_node: 2 # Number of GPU processes per node diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index cc7f37780..390871165 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -16,7 +16,7 @@ http_options: # 2. set enabled: true (and debug: true to dump exporters to console for local dev, # or leave debug: false and point otlp_endpoint at an OTLP collector — see # cookbook/observability/ for a docker-compose example). -telemetry_config: +telemetry: enabled: false debug: false service_name: twinkle-server @@ -29,7 +29,7 @@ telemetry_config: # mode: memory | file | redis # file_path: required for `file` mode # redis_url / key_prefix: required for `redis` mode -persistence_config: +persistence: mode: file file_path: /tmp/twinkle_state.json @@ -62,7 +62,7 @@ applications: route_prefix: /api/v1/model/Qwen/Qwen3.5-4B import_path: model args: - use_megatron: false # Use HuggingFace Transformers backend + backend: transformers # Model backend: mock | transformers | megatron model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier max_length: 10240 nproc_per_node: 1 # Number of GPU processes per node diff --git a/src/twinkle/server/config/__init__.py b/src/twinkle/server/config/__init__.py new file mode 100644 index 000000000..c5502e85e --- /dev/null +++ b/src/twinkle/server/config/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Server configuration package — aggregate root and per-deployment specs.""" + +from .application_spec import ( + ApplicationSpec, + HttpOptions, + ModelArgs, + ProcessorArgs, + SamplerArgs, + ServerArgs, +) +from .server_config import ServerConfig + +__all__ = [ + 'ApplicationSpec', + 'HttpOptions', + 'ModelArgs', + 'ProcessorArgs', + 'SamplerArgs', + 'ServerArgs', + 'ServerConfig', +] diff --git a/src/twinkle/server/config/application_spec.py b/src/twinkle/server/config/application_spec.py new file mode 100644 index 000000000..93d6f1e57 --- /dev/null +++ b/src/twinkle/server/config/application_spec.py @@ -0,0 +1,147 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Per-deployment ``ApplicationSpec`` and typed argument schemas (R6, R3). + +Each deployment kind (``server | model | sampler | processor``) carries its +own ``args`` block with strict field validation. ``ApplicationSpec`` holds +the routing metadata plus the deployment kind and validates ``args`` against +the matching ``*Args`` schema in a model validator. + +The schemas use ``extra='forbid'`` so unknown args (typos, copy-paste from +other deployments) fail at load time instead of being silently dropped. +""" +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from twinkle.server.utils.task_queue.config import TaskQueueConfig + + +# ---------- shared helpers ------------------------------------------------- # + + +class _ArgsBase(BaseModel): + """Base class for every per-deployment args schema (extra='forbid').""" + + model_config = ConfigDict(extra='forbid') + + +class HttpOptions(BaseModel): + """HTTP listener settings (host/port). + + Re-exported from ``server_config`` for convenience and so that + ``ServerArgs`` can carry it without importing the aggregate root. + """ + + model_config = ConfigDict(extra='forbid') + + host: str = 'localhost' + port: int = 8000 + + +# ---------- per-deployment args schemas (R3.x) ----------------------------- # + + +_BACKEND_VALUES: tuple[str, ...] = ('mock', 'transformers', 'megatron') +_SAMPLER_TYPE_VALUES: tuple[str, ...] = ('mock', 'vllm', 'torch') + + +class ModelArgs(_ArgsBase): + """Args for the ``model`` deployment. + + The ``backend`` field selects the model implementation and replaces the + legacy ``use_megatron: bool`` flag. Phase 0c introduces this field; the + actual dispatch on its value is wired up in Phase 1 (R3.1-3.3, R3.9). + """ + + model_id: str + nproc_per_node: int = 1 + device_group: dict[str, Any] + device_mesh: dict[str, Any] + backend: Literal['mock', 'transformers', 'megatron'] + adapter_config: dict[str, Any] | None = None + queue_config: TaskQueueConfig = Field(default_factory=TaskQueueConfig) + max_loras: int = 5 + max_length: int | None = None + + +class SamplerArgs(_ArgsBase): + """Args for the ``sampler`` deployment. + + ``sampler_type`` selects the sampler implementation (R3.4-3.6, R3.10). + """ + + model_id: str + nproc_per_node: int = 1 + device_group: dict[str, Any] + device_mesh: dict[str, Any] + sampler_type: Literal['mock', 'vllm', 'torch'] + engine_args: dict[str, Any] | None = None + queue_config: TaskQueueConfig = Field(default_factory=TaskQueueConfig) + + +class ServerArgs(_ArgsBase): + """Args for the gateway ``server`` deployment.""" + + server_config: dict[str, Any] | None = None + supported_models: list[Any] | None = None + http_options: HttpOptions | None = None + route_prefix: str | None = None + + +class ProcessorArgs(_ArgsBase): + """Args for the ``processor`` deployment.""" + + ncpu_proc_per_node: int | None = None + device_group: dict[str, Any] | None = None + device_mesh: dict[str, Any] | None = None + queue_config: TaskQueueConfig = Field(default_factory=TaskQueueConfig) + + +_ARGS_SCHEMA: dict[str, type[_ArgsBase]] = { + 'server': ServerArgs, + 'model': ModelArgs, + 'sampler': SamplerArgs, + 'processor': ProcessorArgs, +} + + +# ---------- ApplicationSpec ------------------------------------------------ # + + +class ApplicationSpec(BaseModel): + """One application entry under ``ServerConfig.applications``. + + The ``args`` block is validated against the schema selected by + ``import_path``: ``server`` → ``ServerArgs``, ``model`` → ``ModelArgs``, + etc. Unknown keys at this level (or inside ``args``) are rejected. + """ + + model_config = ConfigDict(extra='forbid') + + name: str + route_prefix: str = '/' + import_path: Literal['server', 'model', 'sampler', 'processor'] + args: ServerArgs | ModelArgs | SamplerArgs | ProcessorArgs = Field( + default_factory=lambda: ServerArgs(), + ) + deployments: list[dict[str, Any]] = Field(default_factory=list) + + @model_validator(mode='before') + @classmethod + def _coerce_args_to_schema(cls, data: Any) -> Any: + """Validate the raw ``args`` dict against the schema for ``import_path``. + + Pydantic's union resolution would also work here, but keying off + ``import_path`` makes the failure messages point at the right schema + and avoids ambiguity when two schemas share field names. + """ + if not isinstance(data, dict): + return data + import_path = data.get('import_path') + args = data.get('args') + if import_path in _ARGS_SCHEMA and isinstance(args, dict): + schema = _ARGS_SCHEMA[import_path] + data = {**data, 'args': schema.model_validate(args)} + return data diff --git a/src/twinkle/server/config/server_config.py b/src/twinkle/server/config/server_config.py new file mode 100644 index 000000000..a285272e8 --- /dev/null +++ b/src/twinkle/server/config/server_config.py @@ -0,0 +1,92 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Aggregate-root server configuration (R6, R7, R8). + +``ServerConfig`` is the single Pydantic model that nests every configuration +subsystem the launcher consumes (telemetry, persistence, task-queue, and the +list of ``ApplicationSpec``). Loading a YAML file and validating it goes +through one entry point — ``ServerConfig.from_yaml(path)`` — so the launcher +no longer reaches into a raw dict. + +Top-level fields use their current names with no aliases for legacy names +(``telemetry_config``, ``persistence_config``); the model is configured with +``extra='forbid'`` so a YAML that uses a legacy field is rejected with the +offending name pointed at (R8.1, R8.2). +""" +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from twinkle.server.exceptions import ConfigParseError +from twinkle.server.state.backend.factory import PersistenceConfig +from twinkle.server.telemetry.provider import TelemetryConfig +from twinkle.server.utils.task_queue.config import TaskQueueConfig + +from .application_spec import ApplicationSpec, HttpOptions + + +class ServerConfig(BaseModel): + """Top-level server configuration aggregate root.""" + + model_config = ConfigDict(extra='forbid') + + ray_namespace: str | None = None + proxy_location: str | None = None + http_options: HttpOptions = Field(default_factory=HttpOptions) + telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig) + persistence: PersistenceConfig = Field(default_factory=PersistenceConfig) + task_queue: TaskQueueConfig = Field(default_factory=TaskQueueConfig) + applications: list[ApplicationSpec] = Field(default_factory=list) + + # ---- loading ---------------------------------------------------------- # + + @classmethod + def from_yaml(cls, path: str | Path) -> 'ServerConfig': + """Load and validate a YAML file into a ``ServerConfig``. + + Raises: + FileNotFoundError: ``path`` does not exist or cannot be read. + ConfigParseError: ``path`` exists but is not well-formed YAML. + pydantic.ValidationError: a field or cross-field constraint + fails — the error names every offending field. + """ + from omegaconf import OmegaConf + + p = Path(path) + if not p.exists(): + raise FileNotFoundError(f'Config file not found: {p}') + try: + raw = OmegaConf.to_container(OmegaConf.load(p), resolve=True) + except Exception as e: # malformed YAML / OmegaConf parse failure + raise ConfigParseError(f'Malformed YAML in {p}: {e}') from e + if raw is None: + raw = {} + if not isinstance(raw, dict): + raise ConfigParseError( + f'Top-level YAML in {p} must be a mapping, got {type(raw).__name__}', + ) + return cls.model_validate(raw) + + # ---- cross-field validation (R7) ------------------------------------- # + + @model_validator(mode='after') + def _validate_cross_field(self) -> 'ServerConfig': + # R7.1: redis mode requires redis_url + if self.persistence.mode == 'redis' and not self.persistence.redis_url: + raise ValueError( + "persistence.redis_url is required when persistence.mode == 'redis'", + ) + # R7.2: file mode requires file_path + if self.persistence.mode == 'file' and not self.persistence.file_path: + raise ValueError( + "persistence.file_path is required when persistence.mode == 'file'", + ) + return self + + # ---- round-trip / serialization (R6.7) ------------------------------- # + + def to_yaml_dict(self) -> dict[str, Any]: + """Return a JSON-mode dict suitable for ``yaml.safe_dump`` / round-trip.""" + return self.model_dump(mode='json') diff --git a/src/twinkle/server/exceptions.py b/src/twinkle/server/exceptions.py index 649c40b59..df0b19f5a 100644 --- a/src/twinkle/server/exceptions.py +++ b/src/twinkle/server/exceptions.py @@ -18,6 +18,41 @@ class ConfigMismatchError(TwinkleServerError): pass +class ConfigError(TwinkleServerError): + """Invalid configuration value for a known field. + + Used when a field is present and parseable but its value is not in the + permitted set (e.g. ``backend`` is ``""`` or ``"hf"``). Carries enough + detail for the operator to find and fix the offending YAML entry without + re-running the server. + """ + + def __init__( + self, + field: str, + value: object, + allowed: list[str] | tuple[str, ...] | None = None, + message: str | None = None, + ) -> None: + self.field = field + self.value = value + self.allowed = list(allowed) if allowed is not None else None + if message is None: + allowed_part = f', allowed: {self.allowed}' if self.allowed is not None else '' + message = f'Invalid value for {field}: {value!r}{allowed_part}' + super().__init__(message) + + +class ConfigParseError(TwinkleServerError): + """The configuration source could not be parsed (malformed YAML, ...). + + Distinct from ``pydantic.ValidationError`` (which signals that a parsed + value violates a field/cross-field rule) and from ``FileNotFoundError`` + (which signals that the source could not be read at all). + """ + pass + + class ResourceExhaustedError(TwinkleServerError): """Resource exhausted — queue full, insufficient memory, connection pool exhausted, etc.""" pass diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 2a739b32b..3e839f00f 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -27,6 +27,7 @@ from typing import Any, Callable, Dict, NoReturn, Optional, Union from twinkle import get_logger +from twinkle.server.config import ServerConfig from twinkle.server.utils.ray_serve_patch import apply_ray_serve_patches, get_runtime_env_for_patches logger = get_logger() @@ -53,17 +54,26 @@ class ServerLauncher: def __init__( self, - config: dict[str, Any] | None = None, + config: ServerConfig, ray_namespace: str | None = None, ): """ Initialize the server launcher. Args: - config: Configuration dictionary + config: A validated :class:`ServerConfig` instance. Raw dicts are + rejected — operators must build ``ServerConfig`` via + ``ServerConfig.from_yaml`` or its constructor so cross-field + validation runs before the launcher consumes anything (R6.6). ray_namespace: Ray namespace (default: 'twinkle_cluster') """ - self.config = config or {} + if not isinstance(config, ServerConfig): + raise TypeError( + 'ServerLauncher requires a typed ServerConfig instance; ' + f'got {type(config).__name__}. Build one with ' + 'ServerConfig.from_yaml(path) or ServerConfig(...).' + ) + self.config: ServerConfig = config self.ray_namespace = ray_namespace self._builders: dict[str, Callable] = {} self._ray_initialized = False @@ -159,7 +169,7 @@ def _init_ray(self) -> None: import ray - namespace = self.ray_namespace or self.config.get('ray_namespace') or 'twinkle_cluster' + namespace = self.ray_namespace or self.config.ray_namespace or 'twinkle_cluster' if not ray.is_initialized(): # Use runtime_env to apply patches in worker processes @@ -197,30 +207,28 @@ def _start_serve(self) -> None: # Serve not running — nothing to shut down pass - http_options = self.config.get('http_options', {}) - if isinstance(http_options, dict): - http_options = dict(http_options) - else: - http_options = dict(http_options) if http_options else {} - + http_options = self.config.http_options.model_dump() serve.start(http_options=http_options) logger.info(f'Ray Serve started with http_options={http_options}') self._serve_started = True - def _deploy_application(self, app_config: dict[str, Any]) -> None: + def _deploy_application(self, app_spec: 'ApplicationSpec') -> None: """Deploy a single application. Args: - app_config: Application configuration dictionary + app_spec: Validated :class:`ApplicationSpec` from the typed config. """ from ray import serve - name = app_config.get('name', 'app') - route_prefix = app_config.get('route_prefix', '/') - import_path = app_config.get('import_path', 'server') - args = app_config.get('args', {}) or {} - deployments = app_config.get('deployments', []) + name = app_spec.name + route_prefix = app_spec.route_prefix + import_path = app_spec.import_path + # Re-serialize the typed args back to a kwargs dict for the builder. + # Using ``mode='python'`` keeps nested Pydantic models as dicts (which + # the legacy builders expect) without losing field-level validation. + args = app_spec.args.model_dump(mode='python', exclude_none=True) + deployments = list(app_spec.deployments or []) logger.info(f'Starting {name} at {route_prefix}...') @@ -252,11 +260,10 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: deploy_options['ray_actor_options'] = ray_actor_options # Pass http_options to server apps for internal proxy routing - http_options = self.config.get('http_options', {}) - if import_path == 'server' and http_options: - args['http_options'] = http_options + if import_path == 'server': + args.setdefault('http_options', self.config.http_options.model_dump()) - app = builder(deploy_options=deploy_options, **{k: v for k, v in args.items()}) + app = builder(deploy_options=deploy_options, **args) serve.run(app, name=name, route_prefix=route_prefix) logger.info(f'Deployed {name} at {route_prefix}') @@ -272,55 +279,45 @@ def launch(self) -> None: apply_ray_serve_patches() # Initialize telemetry if configured - telemetry_config = self.config.get('telemetry_config', {}) - if telemetry_config: - from twinkle.server.telemetry import TelemetryConfig, init_telemetry - config = TelemetryConfig(**telemetry_config) - init_telemetry(config) + telemetry = self.config.telemetry + if telemetry.enabled: + from twinkle.server.telemetry import init_telemetry + init_telemetry(telemetry) # Export config to env vars for Ray worker processes import os os.environ['TWINKLE_TELEMETRY_ENABLED'] = '1' - os.environ['TWINKLE_TELEMETRY_DEBUG'] = '1' if config.debug else '0' - os.environ['TWINKLE_TELEMETRY_SERVICE'] = config.service_name - os.environ['TWINKLE_TELEMETRY_ENDPOINT'] = config.otlp_endpoint - os.environ['TWINKLE_TELEMETRY_INTERVAL'] = str(config.export_interval_ms) + os.environ['TWINKLE_TELEMETRY_DEBUG'] = '1' if telemetry.debug else '0' + os.environ['TWINKLE_TELEMETRY_SERVICE'] = telemetry.service_name + os.environ['TWINKLE_TELEMETRY_ENDPOINT'] = telemetry.otlp_endpoint + os.environ['TWINKLE_TELEMETRY_INTERVAL'] = str(telemetry.export_interval_ms) - # Export top-level persistence_config to env vars so any worker + # Export top-level persistence to env vars so any worker # (not just Gateway) can build the same backend on first call to # get_server_state(). - persistence_config_dict = self.config.get('persistence_config') - if persistence_config_dict: - import os - from twinkle.server.state.backend.factory import PersistenceConfig - pconfig = PersistenceConfig(**persistence_config_dict) - for k, v in pconfig.to_env_vars().items(): - os.environ[k] = v - logger.info(f'Persistence backend configured: mode={pconfig.mode}') + persistence = self.config.persistence + import os + for k, v in persistence.to_env_vars().items(): + os.environ[k] = v + logger.info(f'Persistence backend configured: mode={persistence.mode}') self._init_ray() self._start_serve() - applications = self.config.get('applications', []) + applications = self.config.applications if not applications: logger.warning('No applications configured') return - for app_config in applications: - if isinstance(app_config, dict): - self._deploy_application(app_config) - else: - self._deploy_application(dict(app_config)) + for app_spec in applications: + self._deploy_application(app_spec) - http_options = self.config.get('http_options', {}) - host = http_options.get('host', 'localhost') - port = http_options.get('port', 8000) + host = self.config.http_options.host + port = self.config.http_options.port print('\nAll applications started!') print('Endpoints:') - for app_config in applications: - route_prefix = app_config.get('route_prefix', '/') if isinstance(app_config, - dict) else app_config.route_prefix - print(f' - http://{host}:{port}{route_prefix}') + for app_spec in applications: + print(f' - http://{host}:{port}{app_spec.route_prefix}') # Graceful shutdown via signal handling shutdown_event = threading.Event() @@ -349,59 +346,38 @@ def from_yaml( config_path: str | Path, ray_namespace: str | None = None, ) -> ServerLauncher: - """ - Create a ServerLauncher from a YAML config file. - - Args: - config_path: Path to the YAML config file - ray_namespace: Override Ray namespace from config + """Build a ``ServerLauncher`` from a YAML config file. - Returns: - Configured ServerLauncher instance + Thin wrapper over :meth:`ServerConfig.from_yaml`. ``FileNotFoundError`` + / ``ConfigParseError`` / ``pydantic.ValidationError`` propagate so the + caller can surface a precise message before the launcher is constructed. """ - from omegaconf import OmegaConf - - config_path = Path(config_path) - if not config_path.exists(): - raise FileNotFoundError(f'Config file not found: {config_path}') - - config = OmegaConf.load(config_path) - config_dict = OmegaConf.to_container(config, resolve=True) - + config = ServerConfig.from_yaml(config_path) return cls( - config=config_dict, - ray_namespace=ray_namespace or config_dict.get('ray_namespace'), + config=config, + ray_namespace=ray_namespace or config.ray_namespace, ) def launch_server( - config: dict[str, Any] | None = None, + config: ServerConfig | None = None, config_path: str | Path | None = None, ray_namespace: str | None = None, ) -> None: - """ - Launch a twinkle server with flexible configuration options. - - This is the main entry point for launching servers programmatically. - The call blocks until a SIGINT/SIGTERM signal is received. + """Launch a twinkle server. - Args: - config: Configuration dictionary (takes precedence over config_path) - config_path: Path to YAML config file - ray_namespace: Ray namespace + Exactly one of ``config`` (a :class:`ServerConfig` instance) or + ``config_path`` (a YAML file) must be provided. The call blocks until a + SIGINT/SIGTERM signal is received. Raises: - ValueError: If neither config nor config_path is provided + ValueError: neither ``config`` nor ``config_path`` was provided. + TypeError: ``config`` is not a :class:`ServerConfig` instance — raw + dicts are no longer accepted (R6.6). Examples: - # From YAML config launch_server(config_path="server_config.yaml") - - # From Python dict - launch_server(config={ - "http_options": {"host": "0.0.0.0", "port": 8000}, - "applications": [...] - }) + launch_server(config=ServerConfig(...)) """ if config is None and config_path is None: raise ValueError("Either 'config' or 'config_path' must be provided") @@ -409,7 +385,7 @@ def launch_server( if config is not None: launcher = ServerLauncher( config=config, - ray_namespace=ray_namespace or config.get('ray_namespace'), + ray_namespace=ray_namespace or config.ray_namespace, ) else: launcher = ServerLauncher.from_yaml( diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 59f6f00e2..0b558d6e9 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -49,7 +49,14 @@ def __init__(self, use_megatron: bool = False, adapter_config: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, + backend: str | None = None, **kwargs): + # Bridge: Phase 0c introduces ``backend`` as the canonical selector; + # Phase 1 replaces the use_megatron branch below with full dispatch + # on this field. Until then, derive use_megatron from backend when + # supplied so both YAML schemas keep working through the transition. + if backend is not None: + use_megatron = backend == 'megatron' self.device_group = DeviceGroup(**device_group) twinkle.initialize(mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) if 'mesh_dim_names' in device_mesh: @@ -137,6 +144,7 @@ def build_model_app(model_id: str, use_megatron: bool = False, adapter_config: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, + backend: str | None = None, **kwargs): """Build a unified model management application for distributed training. @@ -200,7 +208,8 @@ async def verify_token(request: Request, call_next): )( ModelManagementWithIngress) return DeploymentClass.options(**deploy_options).bind(model_id, nproc_per_node, device_group, device_mesh, - use_megatron, adapter_config, queue_config, **kwargs) + use_megatron, adapter_config, queue_config, + backend=backend, **kwargs) build_model_app = wrap_builder_with_device_group_env(build_model_app) diff --git a/tests/server/config/__init__.py b/tests/server/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/config/test_server_config.py b/tests/server/config/test_server_config.py new file mode 100644 index 000000000..f46f1c4f3 --- /dev/null +++ b/tests/server/config/test_server_config.py @@ -0,0 +1,232 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Property + unit tests for the typed ``ServerConfig`` (R6, R7, R8). + +Properties covered: +- # Feature: server-config-observability-refactor, Property 12: Valid configuration yields a fully validated instance +- # Feature: server-config-observability-refactor, Property 13: Any constraint violation is rejected with the offending field named +- # Feature: server-config-observability-refactor, Property 14: Configuration round-trip fidelity +- # Feature: server-config-observability-refactor, Property 15: Legacy / unknown field names are rejected +""" +from __future__ import annotations + +from pathlib import Path + +import pytest +import yaml +from hypothesis import given, settings, strategies as st +from pydantic import ValidationError + +from twinkle.server.config import ApplicationSpec, ServerConfig +from twinkle.server.exceptions import ConfigParseError +from twinkle.server.launcher import ServerLauncher + + +# ---------- minimal valid config strategy ---------------------------------- # + + +_PERSISTENCE_VARIANTS = st.one_of( + st.fixed_dictionaries({'mode': st.just('memory')}), + st.fixed_dictionaries( + {'mode': st.just('file'), 'file_path': st.just('/tmp/state.json')} + ), + st.fixed_dictionaries( + {'mode': st.just('redis'), 'redis_url': st.just('redis://localhost:6379/0')} + ), +) + + +_MODEL_APP = st.fixed_dictionaries( + { + 'name': st.just('m'), + 'route_prefix': st.just('/api/v1/m'), + 'import_path': st.just('model'), + 'args': st.fixed_dictionaries( + { + 'model_id': st.just('model-id'), + 'device_group': st.just({'name': 'g', 'ranks': 1, 'device_type': 'CPU'}), + 'device_mesh': st.just({'device_type': 'CPU', 'dp_size': 1}), + 'backend': st.sampled_from(['mock', 'transformers', 'megatron']), + } + ), + } +) + + +_VALID_CONFIG = st.fixed_dictionaries( + { + 'persistence': _PERSISTENCE_VARIANTS, + 'applications': st.lists(_MODEL_APP, min_size=0, max_size=3), + } +) + + +# ---------- Property 12: valid → fully validated (R6.2, R7.3) -------------- # + + +@settings(max_examples=100) +@given(payload=_VALID_CONFIG) +def test_property_12_valid_payload_yields_full_instance(payload: dict) -> None: + cfg = ServerConfig.model_validate(payload) + assert isinstance(cfg, ServerConfig) + assert all(isinstance(a, ApplicationSpec) for a in cfg.applications) + # Nested sections instantiated and validated. + assert cfg.persistence.mode == payload['persistence']['mode'] + assert cfg.task_queue.rps_limit >= 0 + + +# ---------- Property 13: violation → field-named error (R6.3, R7.1, R7.2) -- # + + +def test_property_13_redis_mode_missing_url() -> None: + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate({'persistence': {'mode': 'redis'}}) + msg = str(exc.value) + assert 'persistence.redis_url' in msg or 'redis_url' in msg + + +def test_property_13_file_mode_missing_path() -> None: + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate({'persistence': {'mode': 'file'}}) + msg = str(exc.value) + assert 'persistence.file_path' in msg or 'file_path' in msg + + +@settings(max_examples=100) +@given(bad_backend=st.text(min_size=1, max_size=8).filter( + lambda s: s not in ('mock', 'transformers', 'megatron') +)) +def test_property_13_bad_backend_names_field(bad_backend: str) -> None: + payload = { + 'applications': [ + { + 'name': 'm', + 'import_path': 'model', + 'args': { + 'model_id': 'x', + 'device_group': {}, + 'device_mesh': {}, + 'backend': bad_backend, + }, + } + ] + } + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate(payload) + errors = exc.value.errors() + assert any('backend' in err['loc'] for err in errors) + + +@settings(max_examples=100) +@given(bad_max_input_tokens=st.integers(max_value=0, min_value=-1000)) +def test_property_13_nested_field_constraint_violation_named(bad_max_input_tokens: int) -> None: + """Nested-section constraints (here ``task_queue.max_input_tokens``) are + enforced together with cross-field ones (R7.3) and the offending path is + visible in the error.""" + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate( + {'task_queue': {'max_input_tokens': bad_max_input_tokens}} + ) + errors = exc.value.errors() + assert any('max_input_tokens' in err['loc'] for err in errors) + + +# ---------- Property 14: round-trip fidelity (R6.7) ------------------------ # + + +@settings(max_examples=100) +@given(payload=_VALID_CONFIG) +def test_property_14_round_trip_fidelity(payload: dict) -> None: + cfg = ServerConfig.model_validate(payload) + dumped = cfg.to_yaml_dict() + re_loaded = ServerConfig.model_validate(dumped) + assert re_loaded == cfg + assert re_loaded.model_dump() == cfg.model_dump() + + +# ---------- Property 15: legacy/unknown rejected (R8.1, R8.2) -------------- # + + +@pytest.mark.parametrize( + 'legacy_field', + ['telemetry_config', 'persistence_config'], +) +def test_property_15_legacy_field_rejected(legacy_field: str) -> None: + payload = {legacy_field: {}} + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate(payload) + errors = exc.value.errors() + assert any(err['type'] == 'extra_forbidden' for err in errors) + assert any(legacy_field in err['loc'] for err in errors) + + +@settings(max_examples=100) +@given(unknown=st.text(min_size=1, max_size=20).filter(lambda s: not s.startswith('_'))) +def test_property_15_unknown_field_rejected(unknown: str) -> None: + known = { + 'ray_namespace', 'proxy_location', 'http_options', + 'telemetry', 'persistence', 'task_queue', 'applications', + } + if unknown in known: + return + with pytest.raises(ValidationError): + ServerConfig.model_validate({unknown: 'x'}) + + +# ---------- 3.11: from_yaml error paths + launcher dict rejection ---------- # + + +def test_from_yaml_missing_path(tmp_path: Path) -> None: + p = tmp_path / 'does_not_exist.yaml' + with pytest.raises(FileNotFoundError) as exc: + ServerConfig.from_yaml(p) + assert str(p) in str(exc.value) + + +def test_from_yaml_malformed_yaml(tmp_path: Path) -> None: + p = tmp_path / 'bad.yaml' + p.write_text('this is: not: valid: yaml: : :\n - [unbalanced\n') + with pytest.raises(ConfigParseError): + ServerConfig.from_yaml(p) + + +def test_from_yaml_top_level_must_be_mapping(tmp_path: Path) -> None: + p = tmp_path / 'list.yaml' + p.write_text('- a\n- b\n') + with pytest.raises(ConfigParseError): + ServerConfig.from_yaml(p) + + +def test_from_yaml_valid_minimal(tmp_path: Path) -> None: + p = tmp_path / 'mini.yaml' + yaml.safe_dump( + {'persistence': {'mode': 'memory'}, 'applications': []}, + p.open('w'), + ) + cfg = ServerConfig.from_yaml(p) + assert cfg.persistence.mode == 'memory' + assert cfg.applications == [] + + +def test_launcher_rejects_raw_dict() -> None: + with pytest.raises(TypeError) as exc: + ServerLauncher(config={'applications': []}) + assert 'ServerConfig' in str(exc.value) + + +def test_launcher_accepts_typed_config() -> None: + cfg = ServerConfig() + launcher = ServerLauncher(config=cfg) + assert launcher.config is cfg + + +def test_cookbook_examples_load() -> None: + """Migrated cookbook configs all parse with the new field names.""" + here = Path(__file__).resolve().parents[3] + examples = [ + here / 'cookbook' / 'client' / 'server' / 'transformer' / 'server_config.yaml', + here / 'cookbook' / 'client' / 'server' / 'megatron' / 'server_config.yaml', + here / 'cookbook' / 'client' / 'server' / 'megatron' / 'server_config_4b.yaml', + ] + for p in examples: + cfg = ServerConfig.from_yaml(p) + assert isinstance(cfg, ServerConfig), p From a558f32b2ff4f5504196d9aa8f90a80d2e429118 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 00:16:09 +0800 Subject: [PATCH 13/34] refactor(server): direct-backend ServerState, drop detached actor (Phase 0d) Removes the single detached Ray Actor that centralized server state (get_server_state used to call ray.remote(ServerState).options(lifetime='detached')) and replaces it with a process-local ServerState bound directly to the configured StateBackend. Every deployment now reads and writes through the shared backend, which removes the actor as a single-point bottleneck and makes state visibility a property of the backend (Redis cross-process, MemoryBackend in-process) rather than the actor. Adds ReplicaRegistry persisting capacity at replica::::max_loras so two workers on a shared backend agree on the cluster's available LoRA capacity. ModelManager loses its in-memory _replica_max_loras / _replica_models / _token_models dicts: capacity, per-replica loaded counts, and per-token model counts are derived from persisted ModelRecords on each read. register_replica / unregister_replica / get_available_replica_ids / get_capacity_info are now async to match the backend roundtrip; ServerState awaits them through. ServerStateProxy stays as a typing alias of ServerState so existing call-site annotations keep working without import churn. Updates the existing manager tests to the new async API and adds a Phase 0d test module: a static + dynamic check that no detached actor is created (R19.1), an in-process MemoryBackend smoke test (R19.6), the ReplicaRegistry round-trip, cross-instance visibility on a shared MemoryBackend, and a Hypothesis property (Property 25) showing two ServerState instances driven by the same op stream agree on every read. The Phase 0a client-API contract baseline is re-run green. --- src/twinkle/server/state/__init__.py | 10 +- src/twinkle/server/state/model_manager.py | 226 +++++++----------- src/twinkle/server/state/replica_registry.py | 70 ++++++ src/twinkle/server/state/server_state.py | 233 ++++++------------- tests/server/state/test_de_actor.py | 183 +++++++++++++++ tests/server/state/test_managers.py | 35 +-- 6 files changed, 438 insertions(+), 319 deletions(-) create mode 100644 src/twinkle/server/state/replica_registry.py create mode 100644 tests/server/state/test_de_actor.py diff --git a/src/twinkle/server/state/__init__.py b/src/twinkle/server/state/__init__.py index f79f39121..08f07dce1 100644 --- a/src/twinkle/server/state/__init__.py +++ b/src/twinkle/server/state/__init__.py @@ -11,7 +11,13 @@ from .model_manager import ModelManager from .models import FutureRecord, ModelRecord, SamplingSessionRecord, SessionRecord from .sampling_manager import SamplingSessionManager -from .server_state import ServerState, ServerStateProxy, get_server_state +from .replica_registry import ReplicaRegistry +from .server_state import ( + ServerState, + ServerStateProxy, + get_server_state, + reset_server_state_cache, +) from .session_manager import SessionManager __all__ = [ @@ -31,7 +37,9 @@ # Server state 'ServerState', 'ServerStateProxy', + 'ReplicaRegistry', 'get_server_state', + 'reset_server_state_cache', # Persistence backend factory 'PersistenceConfig', 'create_backend', diff --git a/src/twinkle/server/state/model_manager.py b/src/twinkle/server/state/model_manager.py index e006d8e43..df211ea99 100644 --- a/src/twinkle/server/state/model_manager.py +++ b/src/twinkle/server/state/model_manager.py @@ -1,209 +1,157 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +"""Backend-backed model manager (R19). + +Every index this manager exposes — token → model count, replica → loaded model +count, replica capacity — is computed from the persisted ``model::*`` and +``replica::*`` records in the shared :class:`StateBackend`. There is no +in-process cache, so two workers connected to the same backend see one +consistent view of the cluster's model registry without going through a +detached Ray Actor. +""" from __future__ import annotations from .backend.base import StateBackend from .base import BaseManager from .models import ModelRecord +from .replica_registry import ReplicaRegistry class ModelManager(BaseManager[ModelRecord]): - """Manages registered models. + """Manages registered models with backend-derived per-token / per-replica indexes. - Expiry is based on `created_at`. A model is also considered expired if + Expiry is based on ``created_at``. A model is also considered expired if its owning session has already been removed (cascade expiry). - Enforces a per-token model limit across all model instances (server-global). - - Also tracks replica registrations so the router can query which replicas - still have capacity (i.e. their loaded-model count < max_loras). - - Uses a **hybrid mode**: primary records (ModelRecord) are persisted in the - StateBackend, while derived indexes are kept in memory for fast lookups. - On startup, `rebuild_indexes()` loads all records and rebuilds the indexes. """ def __init__(self, backend: StateBackend, expiration_timeout: float, per_token_model_limit: int = 30) -> None: - super().__init__(backend, "model::", ModelRecord, expiration_timeout) + super().__init__(backend, 'model::', ModelRecord, expiration_timeout) self._per_token_model_limit = per_token_model_limit - # token -> set of model_ids owned by that token - self._token_models: dict[str, set[str]] = {} - # replica_id -> set of model_ids currently loaded on that replica - self._replica_models: dict[str, set[str]] = {} - # replica_id -> max_loras limit declared at registration time - self._replica_max_loras: dict[str, int] = {} + self._replicas = ReplicaRegistry(backend) - # ----- Index Rebuild ----- + # ----- Index Rebuild -------------------------------------------------- # async def rebuild_indexes(self) -> None: - """Rebuild in-memory indexes from all records in the backend. + """Compatibility shim — indexes are now derived from the backend per call.""" + return None - Should be called once after startup (e.g. in ServerState.start_cleanup_task). - """ - all_records = await self.get_all() - self._token_models.clear() - self._replica_models.clear() - for model_id, record in all_records.items(): - token = record.token - self._token_models.setdefault(token, set()).add(model_id) - if record.replica_id is not None: - self._replica_models.setdefault(record.replica_id, set()).add(model_id) - - # ----- Capacity Info ----- + # ----- Capacity ------------------------------------------------------- # - def get_capacity_info(self) -> dict[str, int]: - """Return global LoRA capacity across all registered replicas. - - Returns: - Dict containing 'max_loras', 'used_loras', and 'free_loras'. - """ - total_max_loras = sum(self._replica_max_loras.values()) - total_used_loras = sum(len(self._replica_models.get(rid, set())) for rid in self._replica_max_loras.keys()) + async def get_capacity_info(self) -> dict[str, int]: + """Return global LoRA capacity across all registered replicas.""" + replicas = await self._replicas.get_all() + loaded_per_replica = await self._loaded_per_replica() + total_max = sum(replicas.values()) + total_used = sum(loaded_per_replica.get(rid, 0) for rid in replicas) return { - 'max_loras': total_max_loras, - 'used_loras': total_used_loras, - 'free_loras': max(0, total_max_loras - total_used_loras), + 'max_loras': total_max, + 'used_loras': total_used, + 'free_loras': max(0, total_max - total_used), } - # ----- Replica Registration ----- + # ----- Replica Registration ------------------------------------------ # - def register_replica(self, replica_id: str, max_loras: int) -> None: - """Register a replica and its LoRA capacity. - - Args: - replica_id: Unique identifier for the replica. - max_loras: Maximum number of LoRA adapters the replica can hold. - """ - self._replica_max_loras[replica_id] = max_loras - self._replica_models.setdefault(replica_id, set()) + async def register_replica(self, replica_id: str, max_loras: int) -> None: + """Register a replica and its LoRA capacity in the shared backend.""" + await self._replicas.register(replica_id, max_loras) async def unregister_replica(self, replica_id: str) -> None: - """Remove a replica from the registry. - - Any model associations for this replica are also cleared from both - the backend and the in-memory indexes. - - Args: - replica_id: Unique identifier for the replica to remove. - """ - # Remove models associated with this replica - model_ids = list(self._replica_models.get(replica_id, set())) - for model_id in model_ids: + """Remove a replica's capacity entry and any models it owns.""" + loaded = await self._models_for_replica(replica_id) + for model_id in loaded: await self.remove(model_id) - self._replica_max_loras.pop(replica_id, None) - self._replica_models.pop(replica_id, None) - - def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: - """Return the subset of candidate replica IDs that still have capacity. + await self._replicas.unregister(replica_id) - A replica has capacity when its current loaded-model count is strictly - less than its declared ``max_loras``. Replicas that are not registered - (unknown to this manager) are included as-is (conservative fallback). + async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: + """Return the subset of ``candidate_ids`` that still have capacity. - Args: - candidate_ids: Replica IDs to evaluate. - - Returns: - Filtered list preserving the original order. + A replica has capacity when its persisted loaded-model count is strictly + less than its declared ``max_loras``. Unknown replicas (no capacity row + in the backend) are included conservatively, matching the previous + actor-based behavior (R19.3). """ - available = [] + if not candidate_ids: + return [] + replicas = await self._replicas.get_all() + loaded_per_replica = await self._loaded_per_replica( + replica_filter=set(candidate_ids) | replicas.keys() + ) + available: list[str] = [] for rid in candidate_ids: - max_loras = self._replica_max_loras.get(rid) + max_loras = replicas.get(rid) if max_loras is None: - # Unknown replica – include conservatively + # Unknown replica — include conservatively. available.append(rid) continue - current = len(self._replica_models.get(rid, set())) - if current < max_loras: + if loaded_per_replica.get(rid, 0) < max_loras: available.append(rid) return available - # ----- CRUD ----- + # ----- CRUD ----------------------------------------------------------- # async def add(self, model_id: str, record: ModelRecord) -> None: - """Store a record under the given ID. - - Args: - model_id: Unique identifier for the model. - record: ModelRecord to store. + """Store a record, enforcing the per-token model limit. Raises: - RuntimeError: If the token has reached per_token_model_limit. + RuntimeError: when adding ``record`` would exceed + ``per_token_model_limit`` for ``record.token``. """ token = record.token - current_ids = self._token_models.get(token, set()) - if len(current_ids) >= self._per_token_model_limit: - raise RuntimeError(f'Model limit exceeded: ' - f'{len(current_ids)}/{self._per_token_model_limit} models') - # Persist to backend + current = await self._count_models_for_token(token) + if current >= self._per_token_model_limit: + raise RuntimeError( + f'Model limit exceeded: {current}/{self._per_token_model_limit} models' + ) await super().add(model_id, record) - # Update in-memory indexes - self._token_models.setdefault(token, set()).add(model_id) - if record.replica_id is not None: - self._replica_models.setdefault(record.replica_id, set()).add(model_id) async def remove(self, model_id: str) -> bool: - """Remove a record by ID and clean up token and replica ownership. - - Returns: - True if the record existed and was removed, False otherwise. - """ - # Get the record first for index cleanup + """Remove a record by ID.""" record = await self.get(model_id) if record is None: return False - # Remove from backend await super().remove(model_id) - # Clean up in-memory indexes - self._cleanup_ownership(model_id, record) return True - # ----- Cleanup ----- + # ----- Cleanup -------------------------------------------------------- # async def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None, **kwargs) -> int: - """Remove models that are older than cutoff_time, or whose owning - session has already been expired. - - Args: - cutoff_time: Unix timestamp threshold. - expired_session_ids: Optional list of session IDs that have just - been expired; any model belonging to one of these sessions will - also be removed regardless of its own age. - - Returns: - Number of models removed. - """ + """Remove models older than ``cutoff_time`` or whose owning session expired.""" session_set = set(expired_session_ids or []) all_records = await self.get_all() - expired_ids = [] - + expired_ids: list[str] = [] for model_id, record in all_records.items(): - # Cascade: owner session was expired if record.session_id and record.session_id in session_set: expired_ids.append(model_id) continue - # Own age created_at = self._parse_timestamp(record.created_at) if created_at < cutoff_time: expired_ids.append(model_id) - for model_id in expired_ids: await self.remove(model_id) - return len(expired_ids) - # ----- Internal helpers ----- + # ----- Backend-derived helpers --------------------------------------- # - def _cleanup_ownership(self, model_id: str, record: ModelRecord) -> None: - """Remove token and replica ownership entries for a model record. + async def _count_models_for_token(self, token: str | None) -> int: + if not token: + return 0 + all_records = await self.get_all() + return sum(1 for r in all_records.values() if r.token == token) - Args: - model_id: The model ID being removed. - record: The associated ModelRecord. - """ - token = record.token - if token and token in self._token_models: - self._token_models[token].discard(model_id) - if not self._token_models[token]: - del self._token_models[token] - if record.replica_id and record.replica_id in self._replica_models: - self._replica_models[record.replica_id].discard(model_id) + async def _models_for_replica(self, replica_id: str) -> list[str]: + all_records = await self.get_all() + return [mid for mid, r in all_records.items() if r.replica_id == replica_id] + + async def _loaded_per_replica( + self, replica_filter: set[str] | None = None + ) -> dict[str, int]: + all_records = await self.get_all() + counts: dict[str, int] = {} + for record in all_records.values(): + rid = record.replica_id + if rid is None: + continue + if replica_filter is not None and rid not in replica_filter: + continue + counts[rid] = counts.get(rid, 0) + 1 + return counts diff --git a/src/twinkle/server/state/replica_registry.py b/src/twinkle/server/state/replica_registry.py new file mode 100644 index 000000000..69352ce8f --- /dev/null +++ b/src/twinkle/server/state/replica_registry.py @@ -0,0 +1,70 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Backend-backed registry of replica capacity (R19). + +Replaces the in-process ``_replica_max_loras`` dict that lived inside the +detached Ray Actor. Each entry persists to ``replica::::max_loras`` +in the configured :class:`StateBackend`, so two workers on the same shared +backend (Redis in multi-node, MemoryBackend in single-node) see one consistent +view of the cluster's capacity. + +The registry knows *only* about declared capacity. The current loaded-model +count is derived by querying the persisted ``model::*`` records directly — +nothing here caches that count, so concurrent writes from different workers +cannot drift into an inconsistent local index. +""" +from __future__ import annotations + +from .backend.base import StateBackend + +REPLICA_PREFIX = 'replica::' +_MAX_LORAS_SUFFIX = '::max_loras' + + +def _make_key(replica_id: str) -> str: + return f'{REPLICA_PREFIX}{replica_id}{_MAX_LORAS_SUFFIX}' + + +def _replica_id_from_key(key: str) -> str | None: + if not key.startswith(REPLICA_PREFIX) or not key.endswith(_MAX_LORAS_SUFFIX): + return None + return key[len(REPLICA_PREFIX): -len(_MAX_LORAS_SUFFIX)] + + +class ReplicaRegistry: + """Read/write replica capacity through the shared :class:`StateBackend`.""" + + def __init__(self, backend: StateBackend) -> None: + self._backend = backend + + async def register(self, replica_id: str, max_loras: int) -> None: + """Store / overwrite the declared LoRA capacity for ``replica_id``.""" + await self._backend.set(_make_key(replica_id), int(max_loras)) + + async def unregister(self, replica_id: str) -> None: + """Remove the capacity entry for ``replica_id`` (idempotent).""" + await self._backend.delete(_make_key(replica_id)) + + async def get_max_loras(self, replica_id: str) -> int | None: + """Return the declared capacity, or ``None`` if the replica is unknown.""" + value = await self._backend.get(_make_key(replica_id)) + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + async def get_all(self) -> dict[str, int]: + """Return every registered replica's declared capacity.""" + keys = await self._backend.keys(f'{REPLICA_PREFIX}*{_MAX_LORAS_SUFFIX}') + out: dict[str, int] = {} + for key in keys: + rid = _replica_id_from_key(key) + if rid is None: + continue + value = await self._backend.get(key) + try: + out[rid] = int(value) + except (TypeError, ValueError): + continue + return out diff --git a/src/twinkle/server/state/server_state.py b/src/twinkle/server/state/server_state.py index eea830dd6..908103067 100644 --- a/src/twinkle/server/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -2,7 +2,6 @@ from __future__ import annotations import asyncio -import ray import re import time import uuid @@ -73,7 +72,7 @@ def __init__( self._metrics_update_interval: float = float(kwargs.get('metrics_update_interval', 15.0)) async def get_capacity_info(self) -> dict[str, int]: - return self._model_mgr.get_capacity_info() + return await self._model_mgr.get_capacity_info() # ----- Session Management ----- @@ -171,7 +170,7 @@ async def register_replica(self, replica_id: str, max_loras: int) -> None: replica_id: Unique identifier for the replica. max_loras: Maximum number of LoRA adapters the replica can hold. """ - self._model_mgr.register_replica(replica_id, max_loras) + await self._model_mgr.register_replica(replica_id, max_loras) async def unregister_replica(self, replica_id: str) -> None: """Remove a replica from the registry. @@ -190,7 +189,7 @@ async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str] Returns: Filtered list of replica IDs with remaining capacity. """ - return self._model_mgr.get_available_replica_ids(candidate_ids) + return await self._model_mgr.get_available_replica_ids(candidate_ids) # ----- Sampling Session Management ----- @@ -440,127 +439,21 @@ async def get_cleanup_stats(self) -> dict[str, Any]: # --------------------------------------------------------------------------- -# Ray proxy +# Direct-backend factory (R19) # --------------------------------------------------------------------------- +# +# ``ServerStateProxy`` is intentionally retained as a thin alias of +# ``ServerState`` so call-site type hints (e.g. ``state: ServerStateProxy``) +# keep working without import churn during this transition. The detached Ray +# Actor and the per-method ``self._actor.X.remote(...)`` forwarding are gone: +# every worker now binds a process-local :class:`ServerState` directly to the +# shared :class:`StateBackend`, and the existing ``await state.X(...)`` call +# sites awaited a coroutine before and still do. +ServerStateProxy = ServerState # type: ignore[assignment] -class ServerStateProxy: - """ - Proxy for interacting with a ServerState Ray actor. - - Wraps Ray remote calls to provide a synchronous-looking API for - interacting with the distributed ServerState actor. - """ - - def __init__(self, actor_handle) -> None: - self._actor = actor_handle - - async def get_capacity_info(self) -> dict[str, int]: - return await self._actor.get_capacity_info.remote() - - # ----- Session Management ----- - - async def create_session(self, payload: dict[str, Any]) -> str: - return await self._actor.create_session.remote(payload) - - async def touch_session(self, session_id: str) -> bool: - return await self._actor.touch_session.remote(session_id) - - async def get_session_last_heartbeat(self, session_id: str) -> float | None: - return await self._actor.get_session_last_heartbeat.remote(session_id) - - # ----- Model Registration ----- - - async def register_model(self, - payload: dict[str, Any], - token: str, - model_id: str | None = None, - replica_id: str | None = None, - session_id: str | None = None) -> str: - return await self._actor.register_model.remote(payload, token, model_id, replica_id, session_id) - - async def unload_model(self, model_id: str) -> bool: - return await self._actor.unload_model.remote(model_id) - - async def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: - return await self._actor.get_model_metadata.remote(model_id) - - # ----- Replica Management ----- - - async def register_replica(self, replica_id: str, max_loras: int) -> None: - await self._actor.register_replica.remote(replica_id, max_loras) - - async def unregister_replica(self, replica_id: str) -> None: - await self._actor.unregister_replica.remote(replica_id) - - async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: - return await self._actor.get_available_replica_ids.remote(candidate_ids) - - # ----- Sampling Session Management ----- - - async def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: - return await self._actor.create_sampling_session.remote(payload, sampling_session_id) - - async def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: - return await self._actor.get_sampling_session.remote(sampling_session_id) - - # ----- Configuration Management ----- - - async def add_config(self, key: str, value: Any) -> None: - await self._actor.add_config.remote(key, value) - async def add_or_get_config(self, key: str, value: Any) -> Any: - return await self._actor.add_or_get_config.remote(key, value) - - async def get_config(self, key: str) -> Any | None: - return await self._actor.get_config.remote(key) - - async def pop_config(self, key: str) -> Any | None: - return await self._actor.pop_config.remote(key) - - async def clear_config(self) -> None: - await self._actor.clear_config.remote() - - async def count_config(self) -> int: - return await self._actor.count_config.remote() - - # ----- Future Management ----- - - async def get_future(self, request_id: str) -> dict[str, Any] | None: - return await self._actor.get_future.remote(request_id) - - async def store_future_status( - self, - request_id: str, - status: str, - model_id: str | None, - reason: str | None = None, - result: Any = None, - queue_state: str | None = None, - queue_state_reason: str | None = None, - ) -> None: - """Store task status with optional result.""" - await self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state, - queue_state_reason) - - # ----- Resource Cleanup ----- - - async def cleanup_expired_resources(self) -> dict[str, int]: - return await self._actor.cleanup_expired_resources.remote() - - async def start_cleanup_task(self) -> bool: - return await self._actor.start_cleanup_task.remote() - - async def stop_cleanup_task(self) -> bool: - return await self._actor.stop_cleanup_task.remote() - - async def get_cleanup_stats(self) -> dict[str, Any]: - return await self._actor.get_cleanup_stats.remote() - - -# --------------------------------------------------------------------------- -# Factory -# --------------------------------------------------------------------------- +_PROCESS_STATE_CACHE: dict[str, ServerState] = {} def get_server_state( @@ -569,55 +462,71 @@ def get_server_state( persistence_config: PersistenceConfig | None = None, signature_config: dict[str, Any] | None = None, signature_policy: str = 'warn', - **kwargs) -> ServerStateProxy: - """Get or create the ServerState Ray actor. + **kwargs) -> ServerState: + """Return a process-local :class:`ServerState` bound directly to the backend. - Ensures only one ServerState actor exists with the given name. Uses a - detached actor so the state persists across driver restarts. + No detached Ray Actor is created — every deployment / worker accesses the + shared persistence backend directly, which removes the actor as a + single-point bottleneck (R19.1, R19.2). Within one process the same + ``actor_name`` is cached so repeated callers share one ``ServerState`` + instance and the cleanup loop is started exactly once. Args: - actor_name: Name for the Ray actor (default: 'twinkle_server_state'). - backend: Optional :class:`StateBackend` injected into the ServerState - actor. When ``None`` the actor falls back to ``persistence_config`` - or an in-process :class:`MemoryBackend`. - persistence_config: Optional :class:`PersistenceConfig` used to build - a backend via :func:`create_backend` when ``backend`` is None. - **kwargs: Additional keyword arguments passed to ServerState constructor - (e.g., expiration_timeout, cleanup_interval). - - Returns: - A ServerStateProxy for interacting with the actor. + actor_name: Cache key for the per-process ``ServerState`` instance. + The legacy parameter name is kept for call-site compatibility. + backend: Optional :class:`StateBackend` to inject. When ``None`` a + backend is built from ``persistence_config`` (or env vars) via + :func:`create_backend`. + persistence_config: Optional :class:`PersistenceConfig`. Accepted as a + raw dict for YAML compatibility. + signature_config: Optional dict whose hash is validated against the + stored config signature on first access. + signature_policy: ``warn`` | ``error`` | ``ignore`` for signature drift. + **kwargs: Forwarded to the :class:`ServerState` constructor + (``expiration_timeout``, ``cleanup_interval``, ...). """ - # Support passing persistence_config as raw dict from YAML if isinstance(persistence_config, dict): persistence_config = PersistenceConfig(**persistence_config) - # Fall back to env-var-propagated config so any worker (not just Gateway) - # can create the actor with the right backend regardless of deployment - # startup order. Explicit args still take precedence. if backend is None and persistence_config is None: persistence_config = PersistenceConfig.from_env() + cached = _PROCESS_STATE_CACHE.get(actor_name) + if cached is not None: + return cached + + state = ServerState( + backend=backend, + persistence_config=persistence_config, + signature_config=signature_config, + signature_policy=signature_policy, + **kwargs, + ) + _PROCESS_STATE_CACHE[actor_name] = state + + # Start the cleanup loop best-effort. The loop runs inside an asyncio task, + # so ``start_cleanup_task`` must be awaited from within a running loop. + # Schedule it lazily — when there's no running loop yet (sync constructor + # path), skip and let the first awaited handler call start it later. try: - actor = ray.get_actor(actor_name) - except ValueError: - try: - _ServerState = ray.remote(ServerState) - actor = _ServerState.options(name=actor_name, lifetime='detached').remote( - backend=backend, - persistence_config=persistence_config, - signature_config=signature_config, - signature_policy=signature_policy, - **kwargs) - try: - ray.get(actor.start_cleanup_task.remote()) - except Exception as e: - # Ray wraps remote exceptions - check cause - cause = e.__cause__ if hasattr(e, '__cause__') and e.__cause__ else e - if isinstance(cause, ConfigMismatchError) or 'ConfigMismatchError' in type(e).__name__: - raise - logger.debug(f'[ServerState] Warning: Failed to start cleanup task: {e}') - except ValueError: - actor = ray.get_actor(actor_name) - assert actor is not None - return ServerStateProxy(actor) + loop = asyncio.get_running_loop() + loop.create_task(state.start_cleanup_task()) + except RuntimeError: + # No running loop — handler bodies will start the loop on first await. + pass + except Exception as e: + cause = e.__cause__ if hasattr(e, '__cause__') and e.__cause__ else e + if isinstance(cause, ConfigMismatchError) or 'ConfigMismatchError' in type(e).__name__: + raise + logger.debug(f'[ServerState] Warning: Failed to start cleanup task: {e}') + + return state + + +def reset_server_state_cache() -> None: + """Clear the per-process ServerState cache. + + Test-only helper. Production code should never need to reset state across + requests — workers reuse one instance for the lifetime of the process. + """ + _PROCESS_STATE_CACHE.clear() diff --git a/tests/server/state/test_de_actor.py b/tests/server/state/test_de_actor.py new file mode 100644 index 000000000..ba1573e2f --- /dev/null +++ b/tests/server/state/test_de_actor.py @@ -0,0 +1,183 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Phase 0d — De-Actor ServerState tests (R19). + +Covers: +- # Feature: server-config-observability-refactor, Property 25: State operation equivalence under direct-backend access +- de-Actor wiring: ``get_server_state`` returns a direct-bound ``ServerState`` + and never creates a detached Ray Actor (R19.1, R19.2). +- in-process MemoryBackend works without Redis (R19.6). +""" +from __future__ import annotations + +from unittest import mock + +import pytest +from hypothesis import HealthCheck, given, settings, strategies as st + +from twinkle.server.state import ( + PersistenceConfig, + ReplicaRegistry, + ServerState, + get_server_state, + reset_server_state_cache, +) +from twinkle.server.state.backend.memory_backend import MemoryBackend + + +# ---------- 4.6: de-Actor wiring + in-process persistence ------------------ # + + +def _ray_attr_used(obj_path: str) -> bool: + """Return True if ``obj_path`` (e.g. ``ray.remote``) is referenced in source.""" + from pathlib import Path + src = Path(__file__).resolve().parents[3] / 'src' / 'twinkle' / 'server' / 'state' / 'server_state.py' + return obj_path in src.read_text() + + +def test_no_detached_actor_in_source() -> None: + """The state module must not call ``ray.remote(...)`` or use ``lifetime='detached'``. + + Static check: searching the file is enough — the dynamic check below also + confirms ``ray.remote`` is never invoked when ``get_server_state`` runs. + """ + assert not _ray_attr_used('ray.remote('), ( + 'state/server_state.py still references ray.remote(...) — ' + 'detached actor must not be created (R19.1).' + ) + assert not _ray_attr_used("lifetime='detached'"), ( + "state/server_state.py still uses lifetime='detached' (R19.1)." + ) + + +def test_get_server_state_does_not_call_ray_remote() -> None: + reset_server_state_cache() + import ray + + with mock.patch.object(ray, 'remote') as remote_spy, \ + mock.patch.object(ray, 'get_actor', side_effect=ValueError) as get_actor_spy: + state = get_server_state(actor_name='unit', backend=MemoryBackend()) + assert isinstance(state, ServerState) + assert remote_spy.call_count == 0, 'ray.remote was called — detached actor created' + # ray.get_actor may not be called at all under direct-backend access. + # Either way, the contract is that no remote actor is built. + _ = get_actor_spy + + +def test_get_server_state_caches_per_process() -> None: + reset_server_state_cache() + a = get_server_state(actor_name='cache-a', backend=MemoryBackend()) + b = get_server_state(actor_name='cache-a') + assert a is b + + +def test_get_server_state_separate_keys_yield_separate_instances() -> None: + reset_server_state_cache() + a = get_server_state(actor_name='k1', backend=MemoryBackend()) + b = get_server_state(actor_name='k2', backend=MemoryBackend()) + assert a is not b + + +def test_in_process_persistence_no_redis_required() -> None: + """``PersistenceConfig`` defaults to memory mode and ``ServerState`` works + without an external Redis (R19.6).""" + reset_server_state_cache() + cfg = PersistenceConfig() # mode == 'memory' + state = get_server_state(actor_name='no-redis', persistence_config=cfg) + assert isinstance(state, ServerState) + + +# ---------- 4.5: state-operation equivalence under direct-backend ---------- # + + +_OP_STRATEGY = st.lists( + st.one_of( + # ('register_replica', replica_id, max_loras) + st.tuples( + st.just('register_replica'), + st.sampled_from(['r1', 'r2', 'r3']), + st.integers(min_value=1, max_value=4), + ), + # ('add_model', model_id, token, replica_id) + st.tuples( + st.just('add_model'), + st.text(min_size=1, max_size=4, alphabet='abcdefg'), + st.sampled_from(['t1', 't2']), + st.sampled_from(['r1', 'r2', None]), + ), + # ('config_set', key, value) + st.tuples( + st.just('config_set'), + st.sampled_from(['k1', 'k2', 'k3']), + st.integers(min_value=0, max_value=99), + ), + ), + min_size=0, + max_size=12, +) + + +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture, HealthCheck.too_slow]) +@given(ops=_OP_STRATEGY) +@pytest.mark.asyncio +async def test_property_25_state_operation_equivalence(ops: list[tuple]) -> None: + """Two ``ServerState`` instances driven by the same op stream agree. + + Two instances bound to one shared backend must agree on every read after + the same sequence of writes — this is the equivalence the actor used to + enforce, now provided by the shared backend itself (R19.3). + """ + backend = MemoryBackend() + a = ServerState(backend=backend) + b = ServerState(backend=backend) + seen_models: set[str] = set() + for op in ops: + kind = op[0] + if kind == 'register_replica': + _, rid, mx = op + await a.register_replica(rid, mx) + elif kind == 'add_model': + _, mid, token, rid = op + if mid in seen_models: + continue + seen_models.add(mid) + await a.register_model({'base_model': 'x'}, token=token, model_id=mid, replica_id=rid) + elif kind == 'config_set': + _, k, v = op + await b.add_config(k, v) + + # Both instances see the same persisted view. + assert await a.get_capacity_info() == await b.get_capacity_info() + for k in ('k1', 'k2', 'k3'): + assert await a.get_config(k) == await b.get_config(k) + + +# ---------- ReplicaRegistry direct ---------------------------------------- # + + +@pytest.mark.asyncio +async def test_replica_registry_round_trip() -> None: + backend = MemoryBackend() + reg = ReplicaRegistry(backend) + await reg.register('r1', 4) + await reg.register('r2', 7) + assert await reg.get_max_loras('r1') == 4 + assert await reg.get_max_loras('r2') == 7 + assert await reg.get_max_loras('unknown') is None + all_ = await reg.get_all() + assert all_ == {'r1': 4, 'r2': 7} + await reg.unregister('r1') + assert await reg.get_max_loras('r1') is None + + +@pytest.mark.asyncio +async def test_cross_instance_visibility_in_process() -> None: + """Two ``ServerState`` instances on one shared MemoryBackend see the same writes (R19.4 in-process).""" + backend = MemoryBackend() + a = ServerState(backend=backend) + b = ServerState(backend=backend) + await a.register_replica('r1', 3) + await a.register_model({'base_model': 'x'}, token='t1', model_id='m1', replica_id='r1') + info_b = await b.get_capacity_info() + assert info_b == {'max_loras': 3, 'used_loras': 1, 'free_loras': 2} + avail_b = await b.get_available_replica_ids(['r1']) + assert avail_b == ['r1'] diff --git a/tests/server/state/test_managers.py b/tests/server/state/test_managers.py index bea3f15f3..95f98676a 100644 --- a/tests/server/state/test_managers.py +++ b/tests/server/state/test_managers.py @@ -168,38 +168,39 @@ async def test_token_limit_per_token(self, manager): @pytest.mark.asyncio async def test_replica_registration(self, manager): - manager.register_replica("replica1", max_loras=5) - info = manager.get_capacity_info() + await manager.register_replica("replica1", max_loras=5) + info = await manager.get_capacity_info() assert info["max_loras"] == 5 assert info["used_loras"] == 0 assert info["free_loras"] == 5 @pytest.mark.asyncio async def test_capacity_info_after_add(self, manager): - manager.register_replica("r1", max_loras=3) + await manager.register_replica("r1", max_loras=3) record = ModelRecord(token="tok1", replica_id="r1") await manager.add("m1", record) - info = manager.get_capacity_info() + info = await manager.get_capacity_info() assert info["used_loras"] == 1 assert info["free_loras"] == 2 @pytest.mark.asyncio - async def test_rebuild_indexes(self, manager): - """rebuild_indexes should reconstruct token and replica indexes from backend.""" + async def test_indexes_derived_from_backend(self, manager): + """Per-token / per-replica counts are derived from the backend.""" record1 = ModelRecord(token="tok1", replica_id="r1") record2 = ModelRecord(token="tok1", replica_id="r2") await manager.add("m1", record1) await manager.add("m2", record2) - # Clear in-memory indexes - manager._token_models.clear() - manager._replica_models.clear() + await manager.register_replica("r1", max_loras=5) + await manager.register_replica("r2", max_loras=5) - await manager.rebuild_indexes() - assert "m1" in manager._token_models.get("tok1", set()) - assert "m2" in manager._token_models.get("tok1", set()) - assert "m1" in manager._replica_models.get("r1", set()) - assert "m2" in manager._replica_models.get("r2", set()) + # Backend-derived availability reflects all persisted records. + avail = await manager.get_available_replica_ids(["r1", "r2"]) + assert avail == ["r1", "r2"] + + # Per-token count enforces the limit using the persisted records. + count = await manager._count_models_for_token("tok1") + assert count == 2 @pytest.mark.asyncio async def test_cascade_cleanup_by_session(self, manager): @@ -222,12 +223,12 @@ async def test_cascade_cleanup_by_session(self, manager): @pytest.mark.asyncio async def test_get_available_replica_ids(self, manager): - manager.register_replica("r1", max_loras=2) - manager.register_replica("r2", max_loras=1) + await manager.register_replica("r1", max_loras=2) + await manager.register_replica("r2", max_loras=1) # Fill r2 await manager.add("m1", ModelRecord(token="t", replica_id="r2")) - available = manager.get_available_replica_ids(["r1", "r2", "r3_unknown"]) + available = await manager.get_available_replica_ids(["r1", "r2", "r3_unknown"]) # r1 has capacity, r2 is full, r3 unknown (conservative include) assert "r1" in available assert "r2" not in available From 2940f9c11ddac591ed9b823c10cc0748648a115f Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 00:28:41 +0800 Subject: [PATCH 14/34] feat(server): mock model + sampler backends with case-sensitive dispatch (Phase 1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds numpy-only TwinkleCompatMockModel and MockSampler so the server can be launched on a CPU-only host with no torch / transformers / vllm / megatron installed. Both backends return deterministic results keyed by the request parameters: forward / forward_only / forward_backward yield logprob and elementwise_loss arrays whose shapes are derived from the input sequence lengths, sample emits one logprob per token and num_samples sequences per prompt, and identical requests produce identical bytes (R1.3, R2.3-2.5, R4.4, R4.5). Adapter add / remove / has are tracked in an in-memory record; remove on an absent name raises KeyError without mutating the record (R1.7). Replaces the if-use_megatron branch in model/app.py with strict case- sensitive dispatch on the new ``backend`` field (mock|transformers| megatron) and the hardcoded vLLMSampler in sampler/app.py with dispatch on ``sampler_type`` (mock|vllm|torch). Both validators raise ConfigError with field/value/allowed *before* instantiating any backend (R3.9, R3.10) and the mock branch skips ``twinkle.initialize(mode='ray', ...)`` entirely (R3.7, R3.8) — the largest startup-time saving on a CPU-only host. Makes ``twinkle.server.model`` and ``twinkle.server.sampler`` package __init__s lazy via __getattr__ so importing the mock backend module does not transitively pull torch (via app.py → common/router → template) or vllm (via app.py → twinkle.sampler) on a CPU-only host (R1.2, R2.2, R4.3). Adds the all-mock cookbook config at cookbook/client/server/mock/ with a README documenting the launch command, the 30-second ready-state target, and an explicit not-for-production note. Mock-mode persistence defaults to in-process MemoryBackend so no Redis is required. Property tests cover interface conformance, forward determinism + shape, adapter round-trip, remove-absent semantics, sampler output length and logprob count, sampler determinism, max_tokens<1 rejection, and dispatch validation for every (field, allowed, invalid) tuple. Static checks guarantee mock_sampler.py never imports vllm directly and mock_model imports successfully when torch/transformers/vllm/megatron are blocked from sys.modules. The Phase 0a client-API contract baseline is re-run green. --- cookbook/client/server/mock/README.md | 43 ++++ .../client/server/mock/server_config.yaml | 100 ++++++++ src/twinkle/server/model/__init__.py | 18 +- src/twinkle/server/model/app.py | 85 +++++-- .../server/model/backends/mock_model.py | 233 ++++++++++++++++++ src/twinkle/server/sampler/__init__.py | 17 +- src/twinkle/server/sampler/app.py | 71 ++++-- .../server/sampler/backends/__init__.py | 0 .../server/sampler/backends/mock_sampler.py | 106 ++++++++ tests/server/model/__init__.py | 0 tests/server/model/test_mock_model.py | 154 ++++++++++++ tests/server/sampler/__init__.py | 0 tests/server/sampler/test_mock_sampler.py | 161 ++++++++++++ 13 files changed, 942 insertions(+), 46 deletions(-) create mode 100644 cookbook/client/server/mock/README.md create mode 100644 cookbook/client/server/mock/server_config.yaml create mode 100644 src/twinkle/server/model/backends/mock_model.py create mode 100644 src/twinkle/server/sampler/backends/__init__.py create mode 100644 src/twinkle/server/sampler/backends/mock_sampler.py create mode 100644 tests/server/model/__init__.py create mode 100644 tests/server/model/test_mock_model.py create mode 100644 tests/server/sampler/__init__.py create mode 100644 tests/server/sampler/test_mock_sampler.py diff --git a/cookbook/client/server/mock/README.md b/cookbook/client/server/mock/README.md new file mode 100644 index 000000000..9aefe586e --- /dev/null +++ b/cookbook/client/server/mock/README.md @@ -0,0 +1,43 @@ +# Mock backend — CPU-only quick start + +This directory ships an all-mock Twinkle Server configuration so you can +launch the HTTP surface in seconds on a CPU-only laptop, no GPU and no +torch/transformers/vllm/megatron required. Use it for local development, +CI smoke tests, and contract-level HTTP debugging. + +> **Not for production.** Mock backends return fixed numpy-derived results +> without performing real model computation or sampling. The training and +> sampling endpoints respond with deterministic synthetic outputs derived +> only from the request shape and a seed. + +## Launch + +```bash +python -m twinkle.server --config cookbook/client/server/mock/server_config.yaml +``` + +The launcher should reach the ready state within **30 seconds** on a CPU-only +host (R4.1) — `ModelManagement` and `SamplerManagement` skip the +`twinkle.initialize(mode='ray', ...)` step that the GPU backends would run +(R3.7, R3.8). + +## What the mock backends do + +- **Model (`backend: mock`)** — numpy-only. Forward / forward-only / + forward-backward calls return deterministic logprobs and elementwise + losses keyed by `(model_id, adapter_name, seed, input_shape)`. Step / + backward / optimizer-update calls are no-ops. Adapter add / remove / + has are tracked in an in-memory record. +- **Sampler (`sampler_type: mock`)** — numpy-only. `sample` returns one + `SampleResponse` per input prompt with `num_samples` sequences of length + `max_tokens`, exactly one logprob entry per emitted token. Repeated calls + with the same parameters return identical token sequences and logprobs. + `max_tokens < 1` raises a validation error. + +## Verifying determinism + +```bash +curl -s -X POST http://localhost:8000/api/v1/model/mock/twinkle/forward_only \ + -H 'Content-Type: application/json' -d @some_payload.json +# Repeat the same request — the response body is byte-for-byte identical. +``` diff --git a/cookbook/client/server/mock/server_config.yaml b/cookbook/client/server/mock/server_config.yaml new file mode 100644 index 000000000..45ed71eb2 --- /dev/null +++ b/cookbook/client/server/mock/server_config.yaml @@ -0,0 +1,100 @@ +# Twinkle Server Configuration — Mock backend (CPU-only / no GPU) +# +# NOT FOR PRODUCTION. This config wires the all-mock model + sampler backends +# so the server starts in seconds on a CPU-only host with no torch / +# transformers / vllm / megatron installed. Mock backends return fixed +# numpy-derived results without performing real model computation or +# sampling — use it for local development, CI smoke tests, and HTTP-surface +# debugging. For real training/inference use one of the GPU configs in +# ../transformer/ or ../megatron/. + +proxy_location: EveryNode + +http_options: + host: 0.0.0.0 + port: 8000 + +telemetry: + enabled: false + debug: false + service_name: twinkle-server + otlp_endpoint: http://localhost:4317 + +# In-process MemoryBackend — no Redis required for the mock workflow. +persistence: + mode: memory + +applications: + + # 1. Tinker-compatible gateway + - name: server + route_prefix: /api/v1 + import_path: server + args: + server_config: + per_token_model_limit: 3 + supported_models: + - mock-model + deployments: + - name: GatewayServer + max_ongoing_requests: 50 + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + + # 2. Mock model — numpy-only, deterministic. Skips twinkle.initialize. + - name: models-mock + route_prefix: /api/v1/model/mock + import_path: model + args: + backend: mock # Mock backend: numpy-only, CPU-only + model_id: mock-model + nproc_per_node: 1 + device_group: + name: model + ranks: 1 + device_type: cpu # No GPU required + device_mesh: + device_type: cpu + dp_size: 1 + queue_config: + rps_limit: 100 + tps_limit: 100000 + deployments: + - name: ModelManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + + # 3. Mock sampler — numpy-only, deterministic. No vllm import. + - name: sampler-mock + route_prefix: /api/v1/sampler/mock + import_path: sampler + args: + sampler_type: mock # Mock sampler: numpy-only, CPU-only + model_id: mock-model + nproc_per_node: 1 + device_group: + name: sampler + ranks: 1 + device_type: cpu + device_mesh: + device_type: cpu + dp_size: 1 + queue_config: + rps_limit: 100 + tps_limit: 100000 + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 diff --git a/src/twinkle/server/model/__init__.py b/src/twinkle/server/model/__init__.py index 1a203083e..8499387b1 100644 --- a/src/twinkle/server/model/__init__.py +++ b/src/twinkle/server/model/__init__.py @@ -1,3 +1,19 @@ -from .app import build_model_app +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Model deployment package. + +``build_model_app`` is exposed lazily via ``__getattr__`` so importing the +mock backend (``twinkle.server.model.backends.mock_model``) on a CPU-only +host doesn't pull in torch/transformers via ``app.py`` at package-init time +(R1.2, R4.3). +""" +from __future__ import annotations __all__ = ['build_model_app'] + + +def __getattr__(name: str): + if name == 'build_model_app': + from .app import build_model_app + + return build_model_app + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 0b558d6e9..101061b3a 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -15,6 +15,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh +from twinkle.server.exceptions import ConfigError from twinkle.server.utils.lifecycle import AdapterManagerMixin from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.metrics import create_metrics_middleware @@ -30,6 +31,33 @@ logger = get_logger() +_MODEL_BACKENDS: tuple[str, ...] = ('mock', 'transformers', 'megatron') + + +def _dispatch_model_backend(backend: str, ctor_kwargs: dict[str, Any]) -> Any: + """Instantiate the model backend selected by ``backend`` (R3.1-3.3, R3.9). + + Raises :class:`ConfigError` *before* instantiating any backend if the + value is missing, empty, or not in the permitted set, so the deployment + never reaches a ready state with a bad backend. + """ + if backend is None or not isinstance(backend, str) or backend == '': + raise ConfigError(field='backend', value=backend, allowed=list(_MODEL_BACKENDS)) + if backend not in _MODEL_BACKENDS: + raise ConfigError(field='backend', value=backend, allowed=list(_MODEL_BACKENDS)) + if backend == 'mock': + from .backends.mock_model import TwinkleCompatMockModel + + return TwinkleCompatMockModel(**ctor_kwargs) + if backend == 'megatron': + from .backends.megatron_model import TwinkleCompatMegatronModel + + return TwinkleCompatMegatronModel(**ctor_kwargs) + from .backends.transformers_model import TwinkleCompatTransformersModel + + return TwinkleCompatTransformersModel(**ctor_kwargs) + + class ModelManagement(TaskQueueMixin, AdapterManagerMixin): """Unified model management service. @@ -51,41 +79,44 @@ def __init__(self, queue_config: dict[str, Any] | None = None, backend: str | None = None, **kwargs): - # Bridge: Phase 0c introduces ``backend`` as the canonical selector; - # Phase 1 replaces the use_megatron branch below with full dispatch - # on this field. Until then, derive use_megatron from backend when - # supplied so both YAML schemas keep working through the transition. - if backend is not None: - use_megatron = backend == 'megatron' - self.device_group = DeviceGroup(**device_group) - twinkle.initialize(mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) + # Backwards-compatible bridge: legacy callers passed ``use_megatron`` — + # derive ``backend`` from it when not supplied so we always go through + # the strict dispatch below. + if backend is None: + backend = 'megatron' if use_megatron else 'transformers' + self.backend = backend + self.use_megatron = backend == 'megatron' + # Skip twinkle.initialize for the mock backend (R3.7) — the largest + # startup-time saving and the only way to start without CUDA/torch. + if backend != 'mock': + self.device_group = DeviceGroup(**device_group) + twinkle.initialize( + mode='ray', + nproc_per_node=nproc_per_node, + groups=[self.device_group], + lazy_collect=False, + ) + if 'mesh_dim_names' in device_mesh: + self.device_mesh = DeviceMesh(**device_mesh) + else: + self.device_mesh = DeviceMesh.from_sizes(**device_mesh) else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.use_megatron = use_megatron + self.device_group = None + self.device_mesh = None self.replica_id = serve.get_replica_context().replica_id.unique_id self.max_loras = kwargs.get('max_loras', 5) self.base_model = model_id - # Choose model backend - if use_megatron: - from ..model.backends.megatron_model import TwinkleCompatMegatronModel - - self.model = TwinkleCompatMegatronModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=self.replica_id, - **kwargs) - else: - from ..model.backends.transformers_model import TwinkleCompatTransformersModel - self.model = TwinkleCompatTransformersModel( - model_id=model_id, + # Choose model backend (R3.1-3.3, R3.9). Validation runs before + # instantiation so an invalid value never produces a partial backend. + ctor_kwargs: dict[str, Any] = {'model_id': model_id, **kwargs} + if backend != 'mock': + ctor_kwargs.update( device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=self.replica_id, - **kwargs) + ) + self.model = _dispatch_model_backend(backend, ctor_kwargs) self.state: ServerStateProxy = get_server_state() self._replica_registered = False diff --git a/src/twinkle/server/model/backends/mock_model.py b/src/twinkle/server/model/backends/mock_model.py new file mode 100644 index 000000000..d4b4e425a --- /dev/null +++ b/src/twinkle/server/model/backends/mock_model.py @@ -0,0 +1,233 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Numpy-only mock model backend (R1, R3, R4). + +Provides ``TwinkleCompatMockModel``, a stand-in for the real +``TwinkleCompatTransformersModel`` whose only purpose is to exercise the +server's HTTP and dispatch paths on a CPU-only host with no torch / +transformers / vllm / megatron installed. Determinism is keyed by +``(model_id, adapter_name, seed, input_shape)`` so repeated requests with the +same payload produce identical numpy-derived results. + +This module deliberately avoids importing ``torch``, ``transformers``, +``vllm``, ``megatron`` or any module whose own imports would pull them +in transitively (e.g. ``twinkle.server.model.backends.common``). The class is +duck-typed against ``TwinkleCompatModelBase`` rather than subclassing it — +the base class lives in a torch-importing module and would defeat the +import-isolation requirement (R1.2). +""" +from __future__ import annotations + +from typing import Any + +import numpy as np + + +def _seed_for(model_id: str, adapter_name: str | None, seed: int, *extra: Any) -> int: + """Deterministic per-request RNG seed derived from string/int components.""" + h = hash((str(model_id), str(adapter_name), int(seed), tuple(map(repr, extra)))) + # numpy seeds must fit in uint32. + return h & 0xFFFFFFFF + + +class TwinkleCompatMockModel: + """Numpy-only mock model. + + Public API mirrors the methods that the model FastAPI handlers call on + ``self.model``. Every method either returns a deterministic numpy-derived + payload or completes as a no-op without raising. + """ + + def __init__( + self, + model_id: str, + *, + hidden_size: int = 8, + vocab_size: int = 32, + seed: int = 0, + **kwargs: Any, + ) -> None: + self.model_id = model_id + self._hidden_size = int(hidden_size) + self._vocab_size = int(vocab_size) + self._rng_seed = int(seed) + # adapter_name -> arbitrary config payload + self._adapters: dict[str, dict[str, Any]] = {} + + # ----- Forward family (R1.3) ----------------------------------------- # + + def _build_forward_result( + self, + inputs: Any, + adapter_name: str | None, + *, + loss_value: float = 0.0, + ) -> list[dict[str, Any]]: + """Return one deterministic synthetic per-input record. + + Shapes are derived from the input so ``_tinker_build_output``-style + callers see correctly-sized arrays. + """ + seq_lens = _input_seq_lengths(inputs) + out: list[dict[str, Any]] = [] + for idx, seq_len in enumerate(seq_lens): + rng = np.random.default_rng(_seed_for(self.model_id, adapter_name, self._rng_seed, idx, seq_len)) + logprobs = rng.uniform(-2.0, 0.0, size=seq_len).astype(np.float32) + elementwise_loss = rng.uniform(0.0, 1.0, size=seq_len).astype(np.float32) + out.append({ + 'logprobs': logprobs.tolist(), + 'elementwise_loss': elementwise_loss.tolist(), + 'loss': float(loss_value), + }) + return out + + def tinker_forward_only(self, *, inputs: Any, adapter_name: str | None = None, **kwargs: Any) -> list[Any]: + return [self._build_forward_result(inputs, adapter_name), 0.0] + + def tinker_forward_backward(self, *, inputs: Any, adapter_name: str, loss_fn: str, **kwargs: Any) -> list[Any]: + loss_seed = _seed_for(self.model_id, adapter_name, self._rng_seed, 'loss', loss_fn) + loss = float(np.random.default_rng(loss_seed).uniform(0.0, 1.0)) + return [self._build_forward_result(inputs, adapter_name, loss_value=loss), loss] + + def forward(self, *, inputs: Any, **kwargs: Any) -> list[dict[str, Any]]: + return self._build_forward_result(inputs, kwargs.get('adapter_name')) + + def forward_only(self, *, inputs: Any, **kwargs: Any) -> list[dict[str, Any]]: + return self._build_forward_result(inputs, kwargs.get('adapter_name')) + + def forward_backward(self, *, inputs: Any, **kwargs: Any) -> list[Any]: + loss = float(np.random.default_rng(self._rng_seed).uniform(0.0, 1.0)) + return [self._build_forward_result(inputs, kwargs.get('adapter_name'), loss_value=loss), loss] + + def calculate_loss(self, *, inputs: Any, **kwargs: Any) -> float: + return float(np.random.default_rng(self._rng_seed).uniform(0.0, 1.0)) + + # ----- Backward / optimizer (R1.4) ----------------------------------- # + + def backward(self, *args: Any, **kwargs: Any) -> None: + return None + + def step(self, *args: Any, **kwargs: Any) -> None: + return None + + def zero_grad(self, *args: Any, **kwargs: Any) -> None: + return None + + def lr_step(self, *args: Any, **kwargs: Any) -> None: + return None + + def clip_grad_norm(self, *args: Any, **kwargs: Any) -> float: + return 0.0 + + def clip_grad_and_step(self, *args: Any, **kwargs: Any) -> None: + return None + + def tinker_step(self, *, adam_params: Any = None, **kwargs: Any) -> None: + return None + + def tinker_calculate_metric(self, is_training: bool, **kwargs: Any) -> dict[str, float]: + return {'loss': 0.5, 'grad_norm': 0.1} + + def calculate_metric(self, *args: Any, **kwargs: Any) -> dict[str, float]: + return {'loss': 0.5, 'grad_norm': 0.1} + + def tinker_load(self, checkpoint_dir: str, **kwargs: Any) -> None: + return None + + # ----- Configuration setters (R1.4) ---------------------------------- # + + def set_loss(self, *args: Any, **kwargs: Any) -> None: + return None + + def set_optimizer(self, *args: Any, **kwargs: Any) -> None: + return None + + def set_lr_scheduler(self, *args: Any, **kwargs: Any) -> None: + return None + + def set_template(self, *args: Any, **kwargs: Any) -> None: + return None + + def set_processor(self, *args: Any, **kwargs: Any) -> None: + return None + + def add_metric(self, *args: Any, **kwargs: Any) -> None: + return None + + def apply_patch(self, *args: Any, **kwargs: Any) -> None: + return None + + # ----- Persistence stubs (R1.4) -------------------------------------- # + + def save(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + return {'status': 'ok', 'path': None} + + def load(self, *args: Any, **kwargs: Any) -> None: + return None + + def resume_from_checkpoint(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + return {'status': 'ok', 'progress': {}} + + def get_state_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + return {} + + def get_train_configs(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + return {} + + # ----- Adapter management (R1.5, R1.6, R1.7) ------------------------- # + + def add_adapter(self, adapter_name: str, **cfg: Any) -> None: + """Record an adapter without loading real weights (R1.5).""" + self._adapters[adapter_name] = dict(cfg) + + def add_adapter_to_model(self, adapter_name: str, config: Any = None, **cfg: Any) -> None: + merged: dict[str, Any] = dict(cfg) + if config is not None: + merged.setdefault('config', config) + self._adapters[adapter_name] = merged + + def remove_adapter(self, adapter_name: str) -> None: + """Remove ``adapter_name`` (R1.6); raise on absent (R1.7).""" + if adapter_name not in self._adapters: + raise KeyError(f'adapter not present: {adapter_name}') + del self._adapters[adapter_name] + + def has_adapter(self, adapter_name: str) -> bool: + return adapter_name in self._adapters + + +def _input_seq_lengths(inputs: Any) -> list[int]: + """Best-effort recovery of per-datum sequence lengths from heterogeneous inputs. + + The real backend pulls lengths from ``Datum.loss_fn_inputs['target_tokens']``, + but we want to stay numpy-only and avoid importing the tinker types. Falls + back to ``[1]`` so callers always get at least one record back. + """ + if inputs is None: + return [1] + if isinstance(inputs, list): + if not inputs: + return [1] + out: list[int] = [] + for item in inputs: + length = _seq_length_of(item) + out.append(length) + return out + return [_seq_length_of(inputs)] + + +def _seq_length_of(item: Any) -> int: + # Datum-like: model_input.tokens or loss_fn_inputs['target_tokens'] + for attr in ('model_input', 'inputs', 'tokens'): + v = getattr(item, attr, None) + if v is None: + continue + tokens = getattr(v, 'tokens', v) + if hasattr(tokens, '__len__'): + return max(1, len(tokens)) + if isinstance(item, dict): + for k in ('input_ids', 'tokens', 'target_tokens'): + if k in item and hasattr(item[k], '__len__'): + return max(1, len(item[k])) + if hasattr(item, '__len__'): + return max(1, len(item)) + return 1 diff --git a/src/twinkle/server/sampler/__init__.py b/src/twinkle/server/sampler/__init__.py index 58db90983..3569f1056 100644 --- a/src/twinkle/server/sampler/__init__.py +++ b/src/twinkle/server/sampler/__init__.py @@ -1,3 +1,18 @@ -from .app import build_sampler_app +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Sampler deployment package. + +``build_sampler_app`` is exposed lazily via ``__getattr__`` so importing the +mock backend (``twinkle.server.sampler.backends.mock_sampler``) on a CPU-only +host doesn't pull in vllm via ``app.py`` at package-init time (R2.2, R4.3). +""" +from __future__ import annotations __all__ = ['build_sampler_app'] + + +def __getattr__(name: str): + if name == 'build_sampler_app': + from .app import build_sampler_app + + return build_sampler_app + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index 1262ae158..9877c68b4 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -14,6 +14,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh +from twinkle.server.exceptions import ConfigError from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.metrics import create_metrics_middleware from twinkle.server.state import ServerStateProxy, get_server_state @@ -27,6 +28,33 @@ logger = get_logger() +_SAMPLER_TYPES: tuple[str, ...] = ('mock', 'vllm', 'torch') + + +def _dispatch_sampler_backend(sampler_type: str, ctor_kwargs: dict[str, Any]) -> Any: + """Instantiate the sampler selected by ``sampler_type`` (R3.4-3.6, R3.10).""" + if sampler_type is None or not isinstance(sampler_type, str) or sampler_type == '': + raise ConfigError(field='sampler_type', value=sampler_type, allowed=list(_SAMPLER_TYPES)) + if sampler_type not in _SAMPLER_TYPES: + raise ConfigError(field='sampler_type', value=sampler_type, allowed=list(_SAMPLER_TYPES)) + if sampler_type == 'mock': + from .backends.mock_sampler import MockSampler + + # MockSampler accepts only model_id/seed/vocab_size — strip extras silently. + return MockSampler( + model_id=ctor_kwargs.get('model_id'), + seed=ctor_kwargs.get('seed', 0), + vocab_size=ctor_kwargs.get('vocab_size', 32), + ) + if sampler_type == 'torch': + from twinkle.sampler import TorchSampler # type: ignore[attr-defined] + + return TorchSampler(**ctor_kwargs) + from twinkle.sampler import vLLMSampler + + return vLLMSampler(**ctor_kwargs) + + class SamplerManagement(TaskQueueMixin): """Unified sampler management service. @@ -46,29 +74,38 @@ def __init__(self, engine_args: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, **kwargs): - self.device_group = DeviceGroup(**device_group) - twinkle.initialize(mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) + # Skip twinkle.initialize for the mock backend (R3.8) — start without + # CUDA/torch/vllm. + if sampler_type != 'mock': + self.device_group = DeviceGroup(**device_group) + twinkle.initialize( + mode='ray', + nproc_per_node=nproc_per_node, + groups=[self.device_group], + lazy_collect=False, + ) + if 'mesh_dim_names' in device_mesh: + self.device_mesh = DeviceMesh(**device_mesh) + else: + self.device_mesh = DeviceMesh.from_sizes(**device_mesh) else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) + self.device_group = None + self.device_mesh = None self.sampler_type = sampler_type self.model_id = model_id replica_context = serve.get_replica_context() replica_id = replica_context.replica_id.unique_id - from twinkle.sampler import vLLMSampler - sampler_kwargs = engine_args or {} - self.sampler = vLLMSampler( - model_id=model_id, - engine_args=sampler_kwargs, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **{ - k: v - for k, v in kwargs.items() if k not in ['engine_args'] - }) + sampler_kwargs: dict[str, Any] = {'model_id': model_id} + if sampler_type != 'mock': + sampler_kwargs.update( + engine_args=engine_args or {}, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + instance_id=replica_id, + **{k: v for k, v in kwargs.items() if k not in ('engine_args',)}, + ) + self.sampler = _dispatch_sampler_backend(sampler_type, sampler_kwargs) self.state: ServerStateProxy = get_server_state() diff --git a/src/twinkle/server/sampler/backends/__init__.py b/src/twinkle/server/sampler/backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/twinkle/server/sampler/backends/mock_sampler.py b/src/twinkle/server/sampler/backends/mock_sampler.py new file mode 100644 index 000000000..ef0c6c9c9 --- /dev/null +++ b/src/twinkle/server/sampler/backends/mock_sampler.py @@ -0,0 +1,106 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Numpy-only mock sampler backend (R2, R3). + +Implements the same surface as :class:`twinkle.sampler.base.Sampler` — +``sample``, ``apply_patch``, ``add_adapter_to_sampler`` — using only numpy. +The class is intentionally **duck-typed** rather than subclassed because +``twinkle.sampler.__init__`` eagerly imports the vLLM engine, which would +defeat the no-vllm-import requirement (R2.2) on a CPU-only host. + +Outputs are deterministic — keyed by ``(model_id, adapter_name, seed, +prompt_index, sample_index)`` — so repeated calls with the same parameters +produce identical token sequences and logprobs (R2.5). +""" +from __future__ import annotations + +from typing import Any, List, Optional + +import numpy as np + +# These data containers don't pull torch / vllm. +from twinkle.data_format import SampleResponse, SampledSequence, SamplingParams + + +class MockSampler: + """Deterministic numpy-only sampler. + + Provides the public methods callable from the sampler app and the Tinker / + Twinkle handlers; ``has_adapter`` is added for convenience and tests. + """ + + def __init__(self, model_id: str, *, seed: int = 0, vocab_size: int = 32, **kwargs: Any) -> None: + self.model_id = model_id + self._seed = int(seed) + self._vocab_size = int(vocab_size) + self._adapters: dict[str, Any] = {} + # Match the Sampler base attributes so duck-typed callers don't surprise. + self.engine = None + self.template = None + + # ----- Sampler interface (R2.1, R2.3, R2.4, R2.5, R2.6) -------------- # + + def sample( + self, + inputs: Any, + sampling_params: Optional[SamplingParams] = None, + adapter_name: str = '', + *, + num_samples: int = 1, + ) -> List[SampleResponse]: + max_tokens = self._resolve_max_tokens(sampling_params) + if max_tokens is None or max_tokens < 1: + raise ValueError( + f'max_tokens must be >= 1, got {max_tokens!r} ' + '(set sampling_params.max_tokens to a positive integer)' + ) + + normalized = self._normalize_inputs(inputs) + responses: list[SampleResponse] = [] + for prompt_idx, _ in enumerate(normalized): + sequences: list[SampledSequence] = [] + for sample_idx in range(num_samples): + seed = ( + abs(hash((str(self.model_id), str(adapter_name), int(self._seed), int(prompt_idx), int(sample_idx)))) + & 0xFFFFFFFF + ) + rng = np.random.default_rng(seed) + tokens = [int(t) for t in rng.integers(low=0, high=max(1, self._vocab_size), size=max_tokens)] + logprobs_per_token = rng.uniform(-2.0, 0.0, size=max_tokens).astype(float).tolist() + # One logprob entry per emitted token (R2.4) — list of (id, logprob). + logprobs = [[(tok, float(lp))] for tok, lp in zip(tokens, logprobs_per_token)] + sequences.append( + SampledSequence( + stop_reason='length', + tokens=tokens, + logprobs=logprobs, + ) + ) + responses.append(SampleResponse(sequences=sequences)) + return responses + + def apply_patch(self, patch_cls: Any, **kwargs: Any) -> None: + return None + + # ----- Adapter management (R2.7) ------------------------------------- # + + def add_adapter_to_sampler(self, adapter_name: str, config: Any) -> None: + self._adapters[adapter_name] = config + + def has_adapter(self, adapter_name: str) -> bool: + return adapter_name in self._adapters + + # ----- Helpers ------------------------------------------------------- # + + @staticmethod + def _normalize_inputs(inputs: Any) -> list[Any]: + if inputs is None: + return [None] + if isinstance(inputs, list): + return inputs if inputs else [None] + return [inputs] + + @staticmethod + def _resolve_max_tokens(params: Optional[SamplingParams]) -> Optional[int]: + if params is None: + return None + return getattr(params, 'max_tokens', None) diff --git a/tests/server/model/__init__.py b/tests/server/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/model/test_mock_model.py b/tests/server/model/test_mock_model.py new file mode 100644 index 000000000..5e86da377 --- /dev/null +++ b/tests/server/model/test_mock_model.py @@ -0,0 +1,154 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Property + unit tests for the numpy-only mock model backend (R1, R3, R4). + +Properties covered: +- # Feature: server-config-observability-refactor, Property 1: Mock model interface conformance +- # Feature: server-config-observability-refactor, Property 2: Mock model forward determinism and shape +- # Feature: server-config-observability-refactor, Property 3: Mock model adapter add/remove round-trip +- # Feature: server-config-observability-refactor, Property 4: Mock model remove-absent raises and preserves record +- # Feature: server-config-observability-refactor, Property 10: Model backend dispatch +""" +from __future__ import annotations + +import sys + +import pytest +from hypothesis import given, settings, strategies as st + +from twinkle.server.exceptions import ConfigError +from twinkle.server.model.app import _MODEL_BACKENDS, _dispatch_model_backend +from twinkle.server.model.backends.mock_model import TwinkleCompatMockModel + + +# ---------- Property 1: interface conformance (R1.1, R1.4) ---------------- # + + +_REQUIRED_METHODS = ( + 'tinker_forward_only', 'tinker_forward_backward', 'tinker_step', + 'tinker_calculate_metric', 'tinker_load', + 'forward_only', 'forward_backward', 'forward', 'calculate_loss', + 'backward', 'step', 'zero_grad', 'lr_step', 'clip_grad_norm', + 'set_loss', 'set_optimizer', 'set_lr_scheduler', 'set_template', + 'set_processor', 'add_metric', 'apply_patch', + 'save', 'load', 'resume_from_checkpoint', 'get_state_dict', 'get_train_configs', + 'add_adapter', 'add_adapter_to_model', 'remove_adapter', 'has_adapter', +) + + +@pytest.mark.parametrize('method_name', _REQUIRED_METHODS) +def test_property_1_required_method_present(method_name: str) -> None: + m = TwinkleCompatMockModel('mid') + assert callable(getattr(m, method_name)), method_name + + +def test_property_1_constructor_does_not_raise() -> None: + TwinkleCompatMockModel('mid') + + +# ---------- Property 2: forward determinism + shape (R1.3, R4.4) ---------- # + + +@settings(max_examples=100) +@given( + seq_lens=st.lists(st.integers(min_value=1, max_value=12), min_size=1, max_size=5), + seed=st.integers(min_value=0, max_value=99), +) +def test_property_2_forward_only_deterministic_and_shaped(seq_lens: list, seed: int) -> None: + inputs = [{'tokens': list(range(n))} for n in seq_lens] + a = TwinkleCompatMockModel('mid', seed=seed) + b = TwinkleCompatMockModel('mid', seed=seed) + out_a = a.forward_only(inputs=inputs) + out_b = b.forward_only(inputs=inputs) + assert out_a == out_b + assert len(out_a) == len(inputs) + for record, n in zip(out_a, seq_lens): + assert len(record['logprobs']) == n + assert len(record['elementwise_loss']) == n + + +@settings(max_examples=100) +@given(seq_lens=st.lists(st.integers(min_value=1, max_value=8), min_size=1, max_size=4)) +def test_property_2_tinker_forward_backward_loss_is_finite(seq_lens: list) -> None: + m = TwinkleCompatMockModel('mid') + inputs = [{'tokens': list(range(n))} for n in seq_lens] + result, loss = m.tinker_forward_backward(inputs=inputs, adapter_name='a', loss_fn='cross_entropy') + assert isinstance(loss, float) + assert 0.0 <= loss <= 1.0 + assert len(result) == len(inputs) + + +# ---------- Property 3: adapter round-trip (R1.5, R1.6) ------------------- # + + +@settings(max_examples=100) +@given(name=st.text(min_size=1, max_size=12, alphabet=st.characters(whitelist_categories=('L', 'N'), whitelist_characters='_-'))) +def test_property_3_adapter_add_remove_round_trip(name: str) -> None: + m = TwinkleCompatMockModel('mid') + assert not m.has_adapter(name) + m.add_adapter(name, rank=4) + assert m.has_adapter(name) + m.remove_adapter(name) + assert not m.has_adapter(name) + + +# ---------- Property 4: remove-absent raises + preserves (R1.7) ----------- # + + +@settings(max_examples=100) +@given(name=st.text(min_size=1, max_size=12)) +def test_property_4_remove_absent_raises(name: str) -> None: + m = TwinkleCompatMockModel('mid') + pre = dict(m._adapters) + with pytest.raises(KeyError): + m.remove_adapter(name) + assert m._adapters == pre + + +# ---------- Property 10: Model backend dispatch (R3.1-3.3, R3.7, R3.9) ---- # + + +def test_property_10_mock_dispatch_returns_mock_model() -> None: + m = _dispatch_model_backend('mock', {'model_id': 'mid'}) + assert isinstance(m, TwinkleCompatMockModel) + + +@settings(max_examples=100) +@given(bad=st.text(min_size=1, max_size=10).filter(lambda s: s not in _MODEL_BACKENDS)) +def test_property_10_invalid_backend_raises_config_error(bad: str) -> None: + with pytest.raises(ConfigError) as exc: + _dispatch_model_backend(bad, {'model_id': 'mid'}) + assert exc.value.field == 'backend' + assert exc.value.value == bad + assert exc.value.allowed == sorted(_MODEL_BACKENDS) or set(exc.value.allowed) == set(_MODEL_BACKENDS) + + +@pytest.mark.parametrize('value', [None, '']) +def test_property_10_absent_or_empty_backend_raises(value) -> None: + with pytest.raises(ConfigError) as exc: + _dispatch_model_backend(value, {'model_id': 'mid'}) + assert exc.value.field == 'backend' + + +# ---------- Import isolation (R1.2) --------------------------------------- # + + +def test_import_isolation_no_torch_required() -> None: + """Block torch/transformers/vllm/megatron and re-import the mock module.""" + blocked = ('torch', 'transformers', 'vllm', 'megatron') + saved = {m: sys.modules.pop(m, None) for m in blocked} + saved_mock = sys.modules.pop('twinkle.server.model.backends.mock_model', None) + try: + for m in blocked: + sys.modules[m] = None # type: ignore[assignment] + import importlib + + mod = importlib.import_module('twinkle.server.model.backends.mock_model') + assert mod.TwinkleCompatMockModel is not None + finally: + for m in blocked: + if saved[m] is None: + sys.modules.pop(m, None) + else: + sys.modules[m] = saved[m] + if saved_mock is not None: + sys.modules['twinkle.server.model.backends.mock_model'] = saved_mock diff --git a/tests/server/sampler/__init__.py b/tests/server/sampler/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/sampler/test_mock_sampler.py b/tests/server/sampler/test_mock_sampler.py new file mode 100644 index 000000000..594da7ea9 --- /dev/null +++ b/tests/server/sampler/test_mock_sampler.py @@ -0,0 +1,161 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Property + unit tests for the numpy-only mock sampler (R2, R3, R4). + +Properties covered: +- # Feature: server-config-observability-refactor, Property 5: Mock sampler interface conformance +- # Feature: server-config-observability-refactor, Property 6: Mock sampler output length and logprob count +- # Feature: server-config-observability-refactor, Property 7: Mock sampler determinism +- # Feature: server-config-observability-refactor, Property 8: Mock sampler rejects invalid max tokens +- # Feature: server-config-observability-refactor, Property 9: Mock sampler adapter record update +- # Feature: server-config-observability-refactor, Property 11: Sampler backend dispatch +""" +from __future__ import annotations + +from pathlib import Path + +import pytest +from hypothesis import given, settings, strategies as st + +from twinkle.data_format import InputFeature, SamplingParams +from twinkle.server.exceptions import ConfigError +from twinkle.server.sampler.app import _SAMPLER_TYPES, _dispatch_sampler_backend +from twinkle.server.sampler.backends.mock_sampler import MockSampler + + +# ---------- Property 5: interface conformance (R2.1) ---------------------- # + + +_REQUIRED_METHODS = ('sample', 'apply_patch', 'add_adapter_to_sampler', 'has_adapter') + + +@pytest.mark.parametrize('method', _REQUIRED_METHODS) +def test_property_5_required_method_present(method: str) -> None: + s = MockSampler('mid') + assert callable(getattr(s, method)) + + +# ---------- Property 6: output length + logprob count (R2.3, R2.4) -------- # + + +@settings(max_examples=100) +@given( + max_tokens=st.integers(min_value=1, max_value=20), + num_samples=st.integers(min_value=1, max_value=4), +) +def test_property_6_output_length_and_logprob_count(max_tokens: int, num_samples: int) -> None: + s = MockSampler('mid') + inp = InputFeature(input_ids=[1, 2, 3]) + responses = s.sample(inp, SamplingParams(max_tokens=max_tokens), adapter_name='a', num_samples=num_samples) + assert len(responses) == 1 + seqs = responses[0].sequences + assert len(seqs) == num_samples + for seq in seqs: + assert len(seq.tokens) == max_tokens + assert len(seq.logprobs) == max_tokens + + +# ---------- Property 7: determinism (R2.5, R4.5) -------------------------- # + + +@settings(max_examples=100) +@given( + max_tokens=st.integers(min_value=1, max_value=10), + num_samples=st.integers(min_value=1, max_value=3), + adapter=st.sampled_from(['', 'a', 'lora-1']), +) +def test_property_7_determinism(max_tokens: int, num_samples: int, adapter: str) -> None: + s = MockSampler('mid', seed=42) + inp = InputFeature(input_ids=[1, 2, 3]) + r1 = s.sample(inp, SamplingParams(max_tokens=max_tokens), adapter_name=adapter, num_samples=num_samples) + r2 = s.sample(inp, SamplingParams(max_tokens=max_tokens), adapter_name=adapter, num_samples=num_samples) + assert r1 == r2 + + +# ---------- Property 8: invalid max_tokens rejected (R2.6) ---------------- # + + +@settings(max_examples=50) +@given(bad=st.integers(max_value=0, min_value=-1000)) +def test_property_8_max_tokens_lt_1_raises(bad: int) -> None: + s = MockSampler('mid') + inp = InputFeature(input_ids=[1]) + with pytest.raises(ValueError) as exc: + s.sample(inp, SamplingParams(max_tokens=bad)) + assert 'max_tokens' in str(exc.value) + + +def test_property_8_no_sampling_params_raises() -> None: + s = MockSampler('mid') + inp = InputFeature(input_ids=[1]) + with pytest.raises(ValueError): + s.sample(inp, sampling_params=None) + + +# ---------- Property 9: adapter record update (R2.7) ---------------------- # + + +@settings(max_examples=100) +@given(name=st.text(min_size=1, max_size=12, alphabet=st.characters(whitelist_categories=('L', 'N'), whitelist_characters='_-'))) +def test_property_9_add_adapter_to_sampler(name: str) -> None: + s = MockSampler('mid') + assert not s.has_adapter(name) + s.add_adapter_to_sampler(name, {'rank': 4}) + assert s.has_adapter(name) + assert s._adapters[name] == {'rank': 4} + + +# ---------- Property 11: Sampler dispatch (R3.4-3.6, R3.10) --------------- # + + +def test_property_11_mock_dispatch_returns_mock_sampler() -> None: + s = _dispatch_sampler_backend('mock', {'model_id': 'mid'}) + assert isinstance(s, MockSampler) + + +@settings(max_examples=100) +@given(bad=st.text(min_size=1, max_size=10).filter(lambda s: s not in _SAMPLER_TYPES)) +def test_property_11_invalid_sampler_type_raises_config_error(bad: str) -> None: + with pytest.raises(ConfigError) as exc: + _dispatch_sampler_backend(bad, {'model_id': 'mid'}) + assert exc.value.field == 'sampler_type' + assert exc.value.value == bad + + +@pytest.mark.parametrize('value', [None, '']) +def test_property_11_absent_or_empty_sampler_type_raises(value) -> None: + with pytest.raises(ConfigError) as exc: + _dispatch_sampler_backend(value, {'model_id': 'mid'}) + assert exc.value.field == 'sampler_type' + + +# ---------- No direct vllm import (R2.2) ---------------------------------- # + + +def test_mock_sampler_module_does_not_directly_import_vllm() -> None: + """Static check: ``mock_sampler.py`` must not import ``vllm`` directly.""" + src = Path(__file__).resolve().parents[3] / 'src' / 'twinkle' / 'server' / 'sampler' / 'backends' / 'mock_sampler.py' + text = src.read_text() + for forbidden in ('import vllm', 'from vllm'): + assert forbidden not in text, f'mock_sampler.py contains {forbidden!r}' + + +# ---------- Mock example config loads (R5.4) ------------------------------ # + + +def test_mock_example_config_loads_via_server_config() -> None: + from twinkle.server.config import ServerConfig + + repo_root = Path(__file__).resolve().parents[3] + cfg_path = repo_root / 'cookbook' / 'client' / 'server' / 'mock' / 'server_config.yaml' + cfg = ServerConfig.from_yaml(cfg_path) + backends = {a.name: getattr(a.args, 'backend', None) or getattr(a.args, 'sampler_type', None) for a in cfg.applications} + assert backends.get('models-mock') == 'mock' + assert backends.get('sampler-mock') == 'mock' + + +def test_mock_readme_documents_launch_and_targets() -> None: + repo_root = Path(__file__).resolve().parents[3] + readme = (repo_root / 'cookbook' / 'client' / 'server' / 'mock' / 'README.md').read_text() + assert 'python -m twinkle.server' in readme + assert '30 seconds' in readme or '30s' in readme + assert 'Not for production' in readme or 'NOT FOR PRODUCTION' in readme.upper() From e4ed2ba22d52873dd54633616cf6c28f7bb67e2b Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 00:37:42 +0800 Subject: [PATCH 15/34] feat(server): business-layer tracing + correlation + resource metrics (Phase 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a traced_operation context manager that wraps a business-layer block in one OpenTelemetry span: starts before the block, records exceptions and sets span status to ERROR on raise, ends after the block, and re-raises the original exception (R10.1, R10.4). The helper degrades to a NoOp context manager when the OTEL SDK is missing so call sites get the same return value with or without tracing installed (R10.5 / R18.3). Defines the standardized correlation keys (twinkle.session_id / twinkle.model_id / twinkle.replica_id / twinkle.token_id / twinkle.sampling_session_id / twinkle.base_model) in a new telemetry/correlation.py and adds set_correlation_attrs(span, values) which attaches only present (non-None) values so partially-known operations never end up with empty attributes (R11.1, R11.2, R11.3). Wraps every server-state mutation that creates / registers an entity — create_session, register_model, register_replica, create_sampling_session — with traced_operation and the matching correlation attributes. Adds ResourceMetricsCollector exposing observable gauges for system CPU utilization, system memory, process RSS memory, and per-GPU utilization / memory (R12.1). The collector is started by ensure_telemetry_initialized in each Ray Serve worker, including when telemetry is disabled, so the graceful-degradation path matches the enabled path (R12.2). When psutil or pynvml is missing, or no GPU is present, the affected gauges report no data and the collector does not raise (R12.3 / R18.3). Declares psutil and pynvml as a new [telemetry] extras group in pyproject.toml (R12.4). Property tests (Hypothesis, max_examples=100) cover the prefix invariant, correlation attachment skipping None, and NoOp degradation equivalence; unit tests verify span lifecycle and exception recording against an in- memory OTEL exporter, and the wiring tests confirm the worker-init hook calls into the collector regardless of TWINKLE_TELEMETRY_ENABLED. The Phase 0a client-API contract baseline is re-run green. Note: the Grafana dashboard CPU/Mem/GPU panels (task 7.13) and the LGTM integration tests (7.15) require the docker-compose stack and are deferred to the documentation phase. --- pyproject.toml | 1 + src/twinkle/server/state/server_state.py | 89 +++++-- src/twinkle/server/telemetry/correlation.py | 48 ++++ .../server/telemetry/resource_metrics.py | 194 +++++++++++++++ src/twinkle/server/telemetry/tracing.py | 42 ++++ src/twinkle/server/telemetry/worker_init.py | 24 +- tests/server/telemetry/__init__.py | 0 .../telemetry/test_tracing_and_correlation.py | 234 ++++++++++++++++++ 8 files changed, 604 insertions(+), 28 deletions(-) create mode 100644 src/twinkle/server/telemetry/correlation.py create mode 100644 src/twinkle/server/telemetry/resource_metrics.py create mode 100644 tests/server/telemetry/__init__.py create mode 100644 tests/server/telemetry/test_tracing_and_correlation.py diff --git a/pyproject.toml b/pyproject.toml index b54feaceb..652876ed0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ vllm = ["vllm>=0.11"] ray = ["ray[serve]"] tinker = ["tinker==0.14.0"] test = ["hypothesis>=6.0"] +telemetry = ["psutil>=5.9.0", "pynvml>=11.0.0"] docs = [ "sphinx>=5.3.0,<6.0.0", "docutils>=0.16.0,<0.17.0", diff --git a/src/twinkle/server/state/server_state.py b/src/twinkle/server/state/server_state.py index 908103067..e76604a3b 100644 --- a/src/twinkle/server/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -10,6 +10,15 @@ from twinkle.server.exceptions import ConfigMismatchError from twinkle.server.telemetry import MetricsRegistry +from twinkle.server.telemetry.correlation import ( + BASE_MODEL, + MODEL_ID, + REPLICA_ID, + SAMPLING_SESSION_ID, + SESSION_ID, + TOKEN_ID, +) +from twinkle.server.telemetry.tracing import traced_operation from twinkle.utils.logger import get_logger from .backend import StateBackend from .backend.factory import PersistenceConfig, create_backend @@ -86,13 +95,17 @@ async def create_session(self, payload: dict[str, Any]) -> str: The session_id for the created session. """ session_id = payload.get('session_id') or f'session_{uuid.uuid4().hex}' - record = SessionRecord( - tags=list(payload.get('tags') or []), - user_metadata=payload.get('user_metadata') or {}, - sdk_version=payload.get('sdk_version'), - ) - await self._session_mgr.add(session_id, record) - return session_id + with traced_operation( + 'server_state.create_session', + attrs={SESSION_ID: session_id}, + ): + record = SessionRecord( + tags=list(payload.get('tags') or []), + user_metadata=payload.get('user_metadata') or {}, + sdk_version=payload.get('sdk_version'), + ) + await self._session_mgr.add(session_id, record) + return session_id async def touch_session(self, session_id: str) -> bool: """Update session heartbeat timestamp. @@ -136,17 +149,27 @@ async def register_model(self, 'model_id') or f"{_time}-{payload.get('base_model', 'model')}-{uuid.uuid4().hex[:8]}" _model_id = re.sub(r'[^\w\-]', '_', _model_id) - record = ModelRecord( - session_id=session_id or payload.get('session_id'), - model_seq_id=payload.get('model_seq_id'), - base_model=payload.get('base_model'), - user_metadata=payload.get('user_metadata') or {}, - lora_config=payload.get('lora_config'), - token=token, - replica_id=replica_id, - ) - await self._model_mgr.add(_model_id, record) - return _model_id + with traced_operation( + 'server_state.register_model', + attrs={ + MODEL_ID: _model_id, + BASE_MODEL: payload.get('base_model'), + REPLICA_ID: replica_id, + TOKEN_ID: token, + SESSION_ID: session_id or payload.get('session_id'), + }, + ): + record = ModelRecord( + session_id=session_id or payload.get('session_id'), + model_seq_id=payload.get('model_seq_id'), + base_model=payload.get('base_model'), + user_metadata=payload.get('user_metadata') or {}, + lora_config=payload.get('lora_config'), + token=token, + replica_id=replica_id, + ) + await self._model_mgr.add(_model_id, record) + return _model_id async def unload_model(self, model_id: str) -> bool: """Remove a model from the registry. @@ -170,7 +193,11 @@ async def register_replica(self, replica_id: str, max_loras: int) -> None: replica_id: Unique identifier for the replica. max_loras: Maximum number of LoRA adapters the replica can hold. """ - await self._model_mgr.register_replica(replica_id, max_loras) + with traced_operation( + 'server_state.register_replica', + attrs={REPLICA_ID: replica_id}, + ): + await self._model_mgr.register_replica(replica_id, max_loras) async def unregister_replica(self, replica_id: str) -> None: """Remove a replica from the registry. @@ -205,14 +232,22 @@ async def create_sampling_session(self, payload: dict[str, Any], sampling_sessio """ _sampling_session_id: str = sampling_session_id or payload.get( 'sampling_session_id') or f'sampling_{uuid.uuid4().hex}' - record = SamplingSessionRecord( - session_id=payload.get('session_id'), - seq_id=payload.get('sampling_session_seq_id'), - base_model=payload.get('base_model'), - model_path=payload.get('model_path'), - ) - await self._sampling_mgr.add(_sampling_session_id, record) - return _sampling_session_id + with traced_operation( + 'server_state.create_sampling_session', + attrs={ + SAMPLING_SESSION_ID: _sampling_session_id, + SESSION_ID: payload.get('session_id'), + BASE_MODEL: payload.get('base_model'), + }, + ): + record = SamplingSessionRecord( + session_id=payload.get('session_id'), + seq_id=payload.get('sampling_session_seq_id'), + base_model=payload.get('base_model'), + model_path=payload.get('model_path'), + ) + await self._sampling_mgr.add(_sampling_session_id, record) + return _sampling_session_id async def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: """Get a sampling session by ID as a plain dict.""" diff --git a/src/twinkle/server/telemetry/correlation.py b/src/twinkle/server/telemetry/correlation.py new file mode 100644 index 000000000..cdbf23975 --- /dev/null +++ b/src/twinkle/server/telemetry/correlation.py @@ -0,0 +1,48 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Correlation key constants for business-layer spans (R11). + +All correlation attributes share the ``twinkle.`` prefix so operators can +filter Tempo / Loki by ``twinkle.session_id``, ``twinkle.model_id``, etc. +The ``set_correlation_attrs`` helper attaches only present (non-None) values +so partially-known operations don't end up with empty attributes. +""" +from __future__ import annotations + +from typing import Any, Mapping + +PREFIX = 'twinkle.' + +SESSION_ID = f'{PREFIX}session_id' +MODEL_ID = f'{PREFIX}model_id' +REPLICA_ID = f'{PREFIX}replica_id' +TOKEN_ID = f'{PREFIX}token_id' +SAMPLING_SESSION_ID = f'{PREFIX}sampling_session_id' +BASE_MODEL = f'{PREFIX}base_model' + + +CORRELATION_KEYS: tuple[str, ...] = ( + SESSION_ID, + MODEL_ID, + REPLICA_ID, + TOKEN_ID, + SAMPLING_SESSION_ID, + BASE_MODEL, +) + + +def set_correlation_attrs(span: Any, values: Mapping[str, Any] | None) -> None: + """Attach the given correlation attributes to ``span``. + + Skips ``None`` values entirely (R11.2: ``when that value is available``) + and is a no-op for NoOp spans returned when the OTEL SDK is not + installed. + """ + if not values or span is None: + return + setter = getattr(span, 'set_attribute', None) + if setter is None: + return + for key, value in values.items(): + if value is None: + continue + setter(key, value) diff --git a/src/twinkle/server/telemetry/resource_metrics.py b/src/twinkle/server/telemetry/resource_metrics.py new file mode 100644 index 000000000..15248d5f9 --- /dev/null +++ b/src/twinkle/server/telemetry/resource_metrics.py @@ -0,0 +1,194 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Resource (CPU / Memory / GPU) observable gauges (R12). + +Registers OTEL ``observable_gauge`` instruments for CPU utilization, memory +usage (system + process), GPU utilization, and GPU memory. The data sources +(``psutil`` for CPU/memory, ``pynvml`` for GPU) are **optional telemetry +dependencies** — when either is missing, the corresponding gauges report no +data and the collector does not raise (R12.3, R18.3). + +The collector is started by :func:`worker_init.ensure_telemetry_initialized` +in each Ray Serve worker process so per-replica resource usage shows up in +Mimir / Grafana. +""" +from __future__ import annotations + +import os +from typing import Any, Iterator + +from .provider import get_meter + +try: + import psutil # type: ignore + _PSUTIL_AVAILABLE = True +except Exception: + _PSUTIL_AVAILABLE = False + +try: + import pynvml # type: ignore + _PYNVML_AVAILABLE = True +except Exception: + _PYNVML_AVAILABLE = False + + +_NVML_INITIALIZED = False + + +def _nvml_handle_count() -> int: + """Return GPU count, initializing pynvml lazily. ``0`` if unavailable.""" + global _NVML_INITIALIZED + if not _PYNVML_AVAILABLE: + return 0 + try: + if not _NVML_INITIALIZED: + pynvml.nvmlInit() + _NVML_INITIALIZED = True + return int(pynvml.nvmlDeviceGetCount()) + except Exception: + return 0 + + +class ResourceMetricsCollector: + """Owns the observable-gauge callbacks for CPU/Mem/GPU.""" + + def __init__(self) -> None: + self._started = False + # Track which gauges were registered so callers (and tests) can + # introspect what's exported in this process. + self.registered_gauges: list[str] = [] + + def maybe_start(self) -> None: + """Register the available gauges; idempotent and never raises.""" + if self._started: + return + self._started = True + try: + meter = get_meter('twinkle.server.resource') + except Exception: + return + + # CPU + memory require psutil. + if _PSUTIL_AVAILABLE: + try: + meter.create_observable_gauge( + 'twinkle.system.cpu.utilization', + description='System CPU utilization (0..1)', + callbacks=[self._cpu_utilization_callback], + ) + self.registered_gauges.append('twinkle.system.cpu.utilization') + meter.create_observable_gauge( + 'twinkle.system.memory.usage_bytes', + description='System memory used in bytes', + callbacks=[self._memory_usage_callback], + ) + self.registered_gauges.append('twinkle.system.memory.usage_bytes') + meter.create_observable_gauge( + 'twinkle.process.memory.usage_bytes', + description='Resident-set memory of this process in bytes', + callbacks=[self._process_memory_callback], + ) + self.registered_gauges.append('twinkle.process.memory.usage_bytes') + except Exception: + pass + + # GPU requires pynvml AND at least one GPU device — without either, + # we silently skip GPU gauges (R12.3 / R18.3). + if _nvml_handle_count() > 0: + try: + meter.create_observable_gauge( + 'twinkle.gpu.utilization', + description='Per-GPU utilization (0..1)', + callbacks=[self._gpu_utilization_callback], + ) + self.registered_gauges.append('twinkle.gpu.utilization') + meter.create_observable_gauge( + 'twinkle.gpu.memory.usage_bytes', + description='Per-GPU memory used in bytes', + callbacks=[self._gpu_memory_callback], + ) + self.registered_gauges.append('twinkle.gpu.memory.usage_bytes') + except Exception: + pass + + # ----- callbacks ----------------------------------------------------- # + + @staticmethod + def _cpu_utilization_callback(_options: Any) -> Iterator[Any]: + from opentelemetry.metrics import Observation # type: ignore + + if not _PSUTIL_AVAILABLE: + return iter(()) + try: + value = float(psutil.cpu_percent(interval=None)) / 100.0 + return iter([Observation(value)]) + except Exception: + return iter(()) + + @staticmethod + def _memory_usage_callback(_options: Any) -> Iterator[Any]: + from opentelemetry.metrics import Observation # type: ignore + + if not _PSUTIL_AVAILABLE: + return iter(()) + try: + return iter([Observation(int(psutil.virtual_memory().used))]) + except Exception: + return iter(()) + + @staticmethod + def _process_memory_callback(_options: Any) -> Iterator[Any]: + from opentelemetry.metrics import Observation # type: ignore + + if not _PSUTIL_AVAILABLE: + return iter(()) + try: + rss = psutil.Process(os.getpid()).memory_info().rss + return iter([Observation(int(rss))]) + except Exception: + return iter(()) + + @staticmethod + def _gpu_utilization_callback(_options: Any) -> Iterator[Any]: + from opentelemetry.metrics import Observation # type: ignore + + count = _nvml_handle_count() + out: list[Any] = [] + for i in range(count): + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(i) + util = pynvml.nvmlDeviceGetUtilizationRates(handle) + out.append(Observation(float(util.gpu) / 100.0, {'gpu_index': i})) + except Exception: + continue + return iter(out) + + @staticmethod + def _gpu_memory_callback(_options: Any) -> Iterator[Any]: + from opentelemetry.metrics import Observation # type: ignore + + count = _nvml_handle_count() + out: list[Any] = [] + for i in range(count): + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(i) + mem = pynvml.nvmlDeviceGetMemoryInfo(handle) + out.append(Observation(int(mem.used), {'gpu_index': i})) + except Exception: + continue + return iter(out) + + +_GLOBAL_COLLECTOR: ResourceMetricsCollector | None = None + + +def get_collector() -> ResourceMetricsCollector: + global _GLOBAL_COLLECTOR + if _GLOBAL_COLLECTOR is None: + _GLOBAL_COLLECTOR = ResourceMetricsCollector() + return _GLOBAL_COLLECTOR + + +def reset_collector_for_tests() -> None: + """Clear the module-global collector. Test-only helper.""" + global _GLOBAL_COLLECTOR + _GLOBAL_COLLECTOR = None diff --git a/src/twinkle/server/telemetry/tracing.py b/src/twinkle/server/telemetry/tracing.py index 9a12a5b9d..d093013a8 100644 --- a/src/twinkle/server/telemetry/tracing.py +++ b/src/twinkle/server/telemetry/tracing.py @@ -2,6 +2,9 @@ from __future__ import annotations +from contextlib import contextmanager +from typing import Any, Iterator, Mapping + from fastapi import Request try: @@ -13,6 +16,9 @@ _OTEL_AVAILABLE = False +from .correlation import set_correlation_attrs + + def get_tracer(name: str = "twinkle-server"): """Retrieve tracer instance. Returns NoOp tracer when OTEL is not installed.""" if not _OTEL_AVAILABLE: @@ -59,6 +65,42 @@ def start_span(self, name, **kwargs): return _NoopSpan() +@contextmanager +def traced_operation( + name: str, + *, + attrs: Mapping[str, Any] | None = None, + tracer_name: str = 'twinkle.server.business', +) -> Iterator[Any]: + """Run a business-layer block under one OTEL span (R10). + + The span starts before the block runs and ends after it returns. If the + block raises, the exception is recorded on the span, the span status is + set to ERROR, the span is ended, and the original exception is re-raised + to the caller (R10.4). When the OTEL SDK is missing, the context manager + degrades to a NoOp that runs the block normally and returns the same + result it would return when tracing is active (R10.5 / R18.3). + """ + if not _OTEL_AVAILABLE: + yield _NoopSpan() + return + + tracer = trace.get_tracer(tracer_name) + with tracer.start_as_current_span(name) as span: + if attrs: + set_correlation_attrs(span, attrs) + try: + yield span + except Exception as exc: + try: + span.record_exception(exc) + span.set_status(trace.Status(trace.StatusCode.ERROR, str(exc))) + except Exception: + # NEVER let a tracing error mask the underlying exception. + pass + raise + + def create_tracing_middleware(service_component: str): """Create an HTTP tracing middleware compatible with Ray Serve pickling. diff --git a/src/twinkle/server/telemetry/worker_init.py b/src/twinkle/server/telemetry/worker_init.py index e07c7562b..f4ac60cfd 100644 --- a/src/twinkle/server/telemetry/worker_init.py +++ b/src/twinkle/server/telemetry/worker_init.py @@ -27,7 +27,12 @@ def ensure_telemetry_initialized() -> None: _worker_initialized = True - if os.environ.get('TWINKLE_TELEMETRY_ENABLED') != '1': + telemetry_enabled = os.environ.get('TWINKLE_TELEMETRY_ENABLED') == '1' + + if not telemetry_enabled: + # Even with telemetry disabled, register the resource collector so + # graceful-degradation behavior matches the enabled path (R12.2/R18.3). + _start_resource_collector() return try: @@ -47,3 +52,20 @@ def ensure_telemetry_initialized() -> None: logger.info(f'Worker telemetry initialized (service={config.service_name}, debug={config.debug})') except Exception as e: logger.warning(f'Failed to initialize worker telemetry: {e}') + + _start_resource_collector() + + +def _start_resource_collector() -> None: + """Start the resource (CPU / Memory / GPU) metrics collector. + + Safe to call even when telemetry init was skipped or failed — the + collector picks up the NoOp meter and silently records no observations + (R12.2 / R18.3). + """ + try: + from twinkle.server.telemetry import resource_metrics + + resource_metrics.get_collector().maybe_start() + except Exception as e: + logger.debug(f'Resource metrics collector start failed: {e}') diff --git a/tests/server/telemetry/__init__.py b/tests/server/telemetry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/telemetry/test_tracing_and_correlation.py b/tests/server/telemetry/test_tracing_and_correlation.py new file mode 100644 index 000000000..838be3f99 --- /dev/null +++ b/tests/server/telemetry/test_tracing_and_correlation.py @@ -0,0 +1,234 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Property + unit tests for the business-layer tracing helper, correlation +keys, and ``ResourceMetricsCollector`` (R10, R11, R12, R18). + +Properties covered: +- # Feature: server-config-observability-refactor, Property 19: Business-layer span lifecycle +- # Feature: server-config-observability-refactor, Property 20: Span exception handling +- # Feature: server-config-observability-refactor, Property 21: Tracing graceful-degradation equivalence +- # Feature: server-config-observability-refactor, Property 22: Correlation attribute attachment +- # Feature: server-config-observability-refactor, Property 23: Correlation prefix invariant +""" +from __future__ import annotations + +from unittest import mock + +import pytest +from hypothesis import given, settings, strategies as st + +from twinkle.server.telemetry import correlation +from twinkle.server.telemetry.correlation import ( + CORRELATION_KEYS, + PREFIX, + set_correlation_attrs, +) +from twinkle.server.telemetry.tracing import _NoopSpan, traced_operation + + +# ---------- Property 23: prefix invariant (R11.3) ------------------------- # + + +@pytest.mark.parametrize('key', CORRELATION_KEYS) +def test_property_23_prefix_invariant(key: str) -> None: + assert key.startswith(PREFIX), key + + +def test_property_23_helper_constants_complete() -> None: + expected = { + 'twinkle.session_id', + 'twinkle.model_id', + 'twinkle.replica_id', + 'twinkle.token_id', + 'twinkle.sampling_session_id', + 'twinkle.base_model', + } + assert set(CORRELATION_KEYS) == expected + + +# ---------- Property 22: attachment of present-only values (R11.1, R11.2) - # + + +class _RecordingSpan: + def __init__(self) -> None: + self.attrs: dict[str, object] = {} + + def set_attribute(self, key: str, value: object) -> None: + self.attrs[key] = value + + +@settings(max_examples=100) +@given( + payload=st.fixed_dictionaries( + {}, + optional={ + correlation.SESSION_ID: st.one_of(st.none(), st.text(min_size=1, max_size=8)), + correlation.MODEL_ID: st.one_of(st.none(), st.text(min_size=1, max_size=8)), + correlation.REPLICA_ID: st.one_of(st.none(), st.text(min_size=1, max_size=8)), + correlation.TOKEN_ID: st.one_of(st.none(), st.text(min_size=1, max_size=8)), + }, + ) +) +def test_property_22_set_correlation_attrs_skips_none(payload: dict) -> None: + span = _RecordingSpan() + set_correlation_attrs(span, payload) + expected = {k: v for k, v in payload.items() if v is not None} + assert span.attrs == expected + + +def test_property_22_noop_span_safe() -> None: + """``set_correlation_attrs`` is a no-op on a NoOp span (no SDK installed).""" + span = _NoopSpan() + set_correlation_attrs(span, {correlation.SESSION_ID: 's1'}) + # NoOp span has no recording surface — passing None / empty mapping is also safe. + set_correlation_attrs(None, {correlation.SESSION_ID: 's1'}) + set_correlation_attrs(span, None) + + +# ---------- Property 21: NoOp degradation equivalence (R10.5, R18.3) ------ # + + +def test_property_21_noop_yields_same_result_as_active() -> None: + """When OTEL is absent, ``traced_operation`` runs the body and returns the + body's result identically to when OTEL is active.""" + with mock.patch('twinkle.server.telemetry.tracing._OTEL_AVAILABLE', False): + with traced_operation('op') as span: + assert isinstance(span, _NoopSpan) + result = sum(range(5)) + assert result == 10 + + +def test_property_21_noop_propagates_exceptions() -> None: + """NoOp path still re-raises the original exception unchanged.""" + with mock.patch('twinkle.server.telemetry.tracing._OTEL_AVAILABLE', False): + with pytest.raises(RuntimeError, match='boom'): + with traced_operation('op'): + raise RuntimeError('boom') + + +# ---------- Property 19/20: span lifecycle + exception handling ----------- # + + +def _otel_available() -> bool: + try: + from opentelemetry import trace as _otel_trace # noqa: F401 + except Exception: + return False + return True + + +@pytest.fixture(scope='module') +def in_memory_span_exporter(): + """Module-scoped exporter — set once because OTEL global tracer provider + is one-shot per process.""" + if not _otel_available(): + pytest.skip('OTEL SDK not installed in test env') + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + trace.set_tracer_provider(provider) + yield exporter + + +def test_property_19_span_lifecycle(in_memory_span_exporter) -> None: + """When OTEL is present, a span is started before and ended after the block.""" + in_memory_span_exporter.clear() + with mock.patch('twinkle.server.telemetry.tracing._OTEL_AVAILABLE', True): + with traced_operation('op.under.test', attrs={correlation.SESSION_ID: 's1'}): + pass + + spans = in_memory_span_exporter.get_finished_spans() + matches = [s for s in spans if s.name == 'op.under.test'] + assert matches + assert matches[-1].attributes.get(correlation.SESSION_ID) == 's1' + + +def test_property_20_exception_recorded_and_reraised(in_memory_span_exporter) -> None: + """Exception inside the block is recorded on the span and re-raised.""" + in_memory_span_exporter.clear() + with mock.patch('twinkle.server.telemetry.tracing._OTEL_AVAILABLE', True): + with pytest.raises(ValueError, match='boom'): + with traced_operation('op.exc'): + raise ValueError('boom') + + spans = [s for s in in_memory_span_exporter.get_finished_spans() if s.name == 'op.exc'] + assert spans, 'span was not exported' + span = spans[-1] + assert span.status.status_code.name == 'ERROR' + assert any('exception' in evt.name.lower() for evt in span.events) + + +# ---------- ResourceMetricsCollector wiring (R12.1, R12.2, R12.3, R18.4) -- # + + +def test_resource_metrics_collector_does_not_raise_without_pynvml() -> None: + """Collector starts cleanly even when pynvml/GPU is absent (R12.3).""" + from twinkle.server.telemetry import resource_metrics + + with mock.patch.object(resource_metrics, '_PYNVML_AVAILABLE', False): + collector = resource_metrics.ResourceMetricsCollector() + collector.maybe_start() # must not raise + # No GPU gauges registered when pynvml is absent. + assert all(not g.startswith('twinkle.gpu.') for g in collector.registered_gauges) + + +def test_resource_metrics_collector_registers_named_gauges_when_psutil_present() -> None: + from twinkle.server.telemetry import resource_metrics + + if not resource_metrics._PSUTIL_AVAILABLE: + pytest.skip('psutil not installed in test env') + collector = resource_metrics.ResourceMetricsCollector() + collector.maybe_start() + # System CPU + system memory + process memory always present when psutil is. + expected = { + 'twinkle.system.cpu.utilization', + 'twinkle.system.memory.usage_bytes', + 'twinkle.process.memory.usage_bytes', + } + assert expected.issubset(set(collector.registered_gauges)) + + +def test_resource_metrics_collector_idempotent() -> None: + from twinkle.server.telemetry import resource_metrics + + collector = resource_metrics.ResourceMetricsCollector() + collector.maybe_start() + pre = list(collector.registered_gauges) + collector.maybe_start() + assert collector.registered_gauges == pre + + +def test_worker_init_starts_collector() -> None: + """``ensure_telemetry_initialized`` calls the resource collector even when + telemetry is disabled — the collector silently records to a NoOp meter + in that case (R12.2).""" + from twinkle.server.telemetry import resource_metrics, worker_init + + # Force the worker_init guard to re-run and clear the global collector + # so we observe a fresh ``maybe_start`` call. + worker_init._worker_initialized = False + resource_metrics.reset_collector_for_tests() + + sentinel = mock.MagicMock() + sentinel.maybe_start = mock.MagicMock() + + with mock.patch.object(resource_metrics, 'get_collector', return_value=sentinel) as get_spy: + worker_init.ensure_telemetry_initialized() + + assert get_spy.call_count >= 1 + assert sentinel.maybe_start.call_count == 1 + + +def test_pyproject_declares_telemetry_extras() -> None: + """``pyproject.toml`` declares ``psutil`` and ``pynvml`` as telemetry extras (R12.4).""" + from pathlib import Path + + repo_root = Path(__file__).resolve().parents[3] + text = (repo_root / 'pyproject.toml').read_text() + assert 'telemetry =' in text + assert 'psutil' in text + assert 'pynvml' in text From 81cc51f358254cabbe3d40242f7762488823a1ee Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 00:44:27 +0800 Subject: [PATCH 16/34] feat(server): typer CLI with launch-time config-drift validation (Phase 3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the argparse __main__ with a typer-based operations CLI living in twinkle.server.cli. The CLI exposes four subcommands: - launch — start the server from a YAML config. Validates the persistence config signature against the persistence backend BEFORE ray.init so a configuration drift fails fast (R15.1). - check-config — exit 0 on a valid config, non-zero with the offending field/error on failure (R14.3, R14.4). - print-config — emit the validated, normalized ServerConfig as YAML or JSON; the JSON output round-trips back to an equal ServerConfig (R14.5). - clear persistence — delete persisted state for the namespace derived from a config (R14.2). Every option declares envvar= so env vars apply when the flag is omitted (R14.6). The new console script twinkle-server is registered under [project.scripts] and python -m twinkle.server delegates to the same typer entry point, so the documented launch path is one shim layer. Adds validate_against_backend in state/config_signature.py: builds the backend from a PersistenceConfig, computes the current signature, stores it on first run, and on mismatch raises ConfigMismatchError with a stored-vs-current diff and a remediation hint pointing at the clear-persistence subcommand (R15.2, R15.3, R15.4). Adds a fully documented example config at cookbook/client/server/server_config.example.yaml — every field carries its type, default, and available options. Loadable as-is via check-config. CLI tests cover subcommand existence, exit-code semantics, env-var override, print-config round-trip, the order-of-operations property (launch validates drift BEFORE ServerLauncher is even imported), and the drift detection / first-run-storage property (Property 29). The Phase 0a client-API contract baseline is re-run green. --- .../client/server/server_config.example.yaml | 162 +++++++++++++++++ pyproject.toml | 4 + src/twinkle/server/__main__.py | 118 ++----------- src/twinkle/server/cli/__init__.py | 5 + src/twinkle/server/cli/app.py | 167 ++++++++++++++++++ src/twinkle/server/state/config_signature.py | 58 +++++- tests/server/cli/__init__.py | 0 tests/server/cli/test_cli.py | 162 +++++++++++++++++ 8 files changed, 569 insertions(+), 107 deletions(-) create mode 100644 cookbook/client/server/server_config.example.yaml create mode 100644 src/twinkle/server/cli/__init__.py create mode 100644 src/twinkle/server/cli/app.py create mode 100644 tests/server/cli/__init__.py create mode 100644 tests/server/cli/test_cli.py diff --git a/cookbook/client/server/server_config.example.yaml b/cookbook/client/server/server_config.example.yaml new file mode 100644 index 000000000..468f217c6 --- /dev/null +++ b/cookbook/client/server/server_config.example.yaml @@ -0,0 +1,162 @@ +# ============================================================================= +# Twinkle Server — fully documented example configuration +# ============================================================================= +# +# This file is a reference: every field carries its type, default value, and +# the available options. Loadable as-is via: +# +# python -m twinkle.server check-config --config server_config.example.yaml +# python -m twinkle.server launch --config server_config.example.yaml +# +# Field naming after the Phase 0c refactor is strict: legacy aliases +# `telemetry_config` / `persistence_config` are no longer accepted (R8). Use +# `telemetry` / `persistence` instead. + +# Optional. Ray cluster namespace. +# Type: string | null. Default: null (resolves to "twinkle_cluster"). +# Env override: TWINKLE_RAY_NAMESPACE. +ray_namespace: twinkle_cluster + +# Optional. Ray Serve proxy placement. +# Type: string | null. Options: "EveryNode" (default for multi-node), "HeadOnly". +proxy_location: EveryNode + +# HTTP listener. +# host: str — bind address. Default "localhost". Use "0.0.0.0" to listen on all. +# port: int — TCP port. Default 8000. +http_options: + host: 0.0.0.0 + port: 8000 + +# Telemetry (OpenTelemetry pipeline). +# enabled: bool — when false, init/shutdown is a NoOp. Default false. +# debug: bool — true: console exporters; false: OTLP exporter. Default false. +# service_name: str — OTEL resource service.name. Default "twinkle-server". +# otlp_endpoint: str — gRPC OTLP endpoint. Default "http://localhost:4317". +# export_interval_ms: int — metric export interval in ms. Default 30000. +# resource_attributes: map[str, any] — extra OTEL Resource attributes. Default {}. +telemetry: + enabled: false + debug: false + service_name: twinkle-server + otlp_endpoint: http://localhost:4317 + export_interval_ms: 30000 + +# Persistence backend for ServerState (sessions, models, futures, ...). +# mode: str — "memory" | "file" | "redis". Default "memory". +# file_path: str — required when mode == "file". +# redis_url: str — required when mode == "redis", e.g. "redis://localhost:6379/0". +# key_prefix: str — optional global key prefix. Default "". +persistence: + mode: memory + # file_path: /tmp/twinkle_state.json + # redis_url: redis://localhost:6379/0 + # key_prefix: "" + +# Task queue / rate-limit defaults (overridable per application under args.queue_config). +# rps_limit: float >= 0 — requests/sec. 0 disables. Default 100.0. +# tps_limit: float >= 0 — input tokens/sec. 0 disables. Default 16000.0. +# window_seconds: float > 0 — rate-limit sliding window. Default 1.0. +# queue_timeout: float >= 0 — max queue wait (s). Default 300.0. +# execution_timeout: float >= 0 — task execution timeout (s). 0 disables. Default 120.0. +# enabled: bool — rate limiting on/off. Default true. +# token_cleanup_multiplier:float >= 0 — token retention multiplier. Default 10.0. +# token_cleanup_interval: float >= 0 — cleanup task interval (s). Default 60.0. +# max_input_tokens: int >= 1 — per-request input token cap. Default 16000. +task_queue: + rps_limit: 100.0 + tps_limit: 16000.0 + window_seconds: 1.0 + queue_timeout: 300.0 + execution_timeout: 120.0 + enabled: true + max_input_tokens: 16000 + +# Applications: each entry deploys one component (server | model | sampler | processor). +# Required fields per entry: +# name: str — Ray Serve app name. +# route_prefix:str — HTTP route prefix. Default "/". +# import_path: str — one of {server, model, sampler, processor}. +# args: map — typed args, schema selected by import_path. +# Optional: +# deployments: list — Ray Serve deployment options (only the first is used). +applications: + + # 1. Tinker-compatible gateway (server) + - name: server + route_prefix: /api/v1 + import_path: server + args: + # ServerArgs schema — fields are optional unless noted. + server_config: + per_token_model_limit: 3 + supported_models: + - mock-model + deployments: + - name: GatewayServer + max_ongoing_requests: 50 + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + + # 2. Model deployment. + # backend: str — required. Options: "mock" | "transformers" | "megatron". + # model_id: str — required. Model identifier (e.g. "Qwen/Qwen3.5-4B"). + # nproc_per_node: int — distributed processes per node. Default 1. + # device_group / device_mesh: dict — required parallelism config. + # max_loras: int — per-replica LoRA capacity. Default 5. + # queue_config: map — overrides task_queue defaults for this app. + - name: models + route_prefix: /api/v1/model + import_path: model + args: + backend: mock + model_id: mock-model + nproc_per_node: 1 + device_group: + name: model + ranks: 1 + device_type: cpu + device_mesh: + device_type: cpu + dp_size: 1 + max_loras: 5 + queue_config: + rps_limit: 100 + tps_limit: 100000 + deployments: + - name: ModelManagement + autoscaling_config: { min_replicas: 1, max_replicas: 1, target_ongoing_requests: 16 } + ray_actor_options: { num_cpus: 0.1 } + + # 3. Sampler deployment. + # sampler_type: str — required. Options: "mock" | "vllm" | "torch". + - name: sampler + route_prefix: /api/v1/sampler + import_path: sampler + args: + sampler_type: mock + model_id: mock-model + nproc_per_node: 1 + device_group: { name: sampler, ranks: 1, device_type: cpu } + device_mesh: { device_type: cpu, dp_size: 1 } + deployments: + - name: SamplerManagement + autoscaling_config: { min_replicas: 1, max_replicas: 1, target_ongoing_requests: 16 } + ray_actor_options: { num_cpus: 0.1 } + + # 4. Processor deployment (CPU-only feature engineering). + - name: processor + route_prefix: /api/v1/processor + import_path: processor + args: + ncpu_proc_per_node: 2 + device_group: { name: processor, ranks: 2, device_type: cpu } + device_mesh: { device_type: cpu, dp_size: 2 } + deployments: + - name: ProcessorManagement + autoscaling_config: { min_replicas: 1, max_replicas: 1, target_ongoing_requests: 128 } + ray_actor_options: { num_cpus: 0.1 } diff --git a/pyproject.toml b/pyproject.toml index 652876ed0..d59470940 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,8 +14,12 @@ dependencies = [ "safetensors", "peft>=0.11.0,<=0.19.0", "transformers", + "typer>=0.9.0", ] +[project.scripts] +twinkle-server = "twinkle.server.cli:main" + [project.optional-dependencies] transformers = [ "accelerate", diff --git a/src/twinkle/server/__main__.py b/src/twinkle/server/__main__.py index 8f97ef097..403033142 100644 --- a/src/twinkle/server/__main__.py +++ b/src/twinkle/server/__main__.py @@ -1,116 +1,22 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -""" -CLI entry point for Twinkle Server. +"""CLI entry point for Twinkle Server. + +Thin shim — delegates to the typer-based :mod:`twinkle.server.cli` so the +``python -m twinkle.server`` command and the ``twinkle-server`` console +script share one implementation. + +Usage:: -Usage: - # From config file - python -m twinkle.server --config server_config.yaml + python -m twinkle.server launch --config server_config.yaml + python -m twinkle.server check-config --config server_config.yaml + python -m twinkle.server print-config --config server_config.yaml + python -m twinkle.server clear persistence --config server_config.yaml """ from __future__ import annotations -import argparse -import os import sys -from pathlib import Path - -from twinkle import get_logger - -logger = get_logger() - - -def create_parser() -> argparse.ArgumentParser: - """Create the argument parser.""" - parser = argparse.ArgumentParser( - prog='python -m twinkle.server', - description='Twinkle Server Launcher - Unified launcher supporting both Tinker and Twinkle clients', - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Start server from YAML config file - python -m twinkle.server --config server_config.yaml - """, - ) - - # Config file option - parser.add_argument( - '-c', - '--config', - type=str, - required=True, - metavar='PATH', - help='Path to YAML configuration file (required)', - ) - - # Ray options - parser.add_argument( - '--namespace', - type=str, - metavar='NS', - help="Ray namespace (default: 'twinkle_cluster')", - ) - - # Runtime options - parser.add_argument( - '--log-level', - type=str, - default='INFO', - choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], - metavar='LEVEL', - help='Logging level (default: INFO)', - ) - - return parser - - -def main(args: list[str] | None = None) -> int: - """ - Main entry point for the CLI. - - Args: - args: Command line arguments (uses sys.argv if None) - - Returns: - Exit code (0 for success, non-zero for error) - """ - parser = create_parser() - parsed_args = parser.parse_args(args) - - try: - from twinkle.server.launcher import launch_server - - # Apply log level so that all loggers (including those created later) - # pick up the user-specified level via the LOG_LEVEL env var that - # get_logger() already reads. - os.environ['LOG_LEVEL'] = parsed_args.log_level - - config_path = Path(parsed_args.config) - if not config_path.exists(): - logger.error(f'Config file not found: {config_path}') - return 1 - - launch_server( - config_path=config_path, - ray_namespace=parsed_args.namespace, - ) - - return 0 - except KeyboardInterrupt: - logger.info('Server stopped by user') - return 0 - except FileNotFoundError as e: - logger.error(f'File not found: {e}') - return 1 - except ValueError as e: - logger.error(f'Configuration error: {e}') - return 1 - except ImportError as e: - logger.error(f'Import error: {e}') - logger.error('Make sure all required dependencies are installed') - return 1 - except Exception as e: - logger.exception(f'Unexpected error: {e}') - return 1 +from twinkle.server.cli import main if __name__ == '__main__': diff --git a/src/twinkle/server/cli/__init__.py b/src/twinkle/server/cli/__init__.py new file mode 100644 index 000000000..4be1e0a1e --- /dev/null +++ b/src/twinkle/server/cli/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Twinkle Server CLI (typer).""" +from .app import app, main + +__all__ = ['app', 'main'] diff --git a/src/twinkle/server/cli/app.py b/src/twinkle/server/cli/app.py new file mode 100644 index 000000000..3f7893196 --- /dev/null +++ b/src/twinkle/server/cli/app.py @@ -0,0 +1,167 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Typer-based operations CLI (R14, R15). + +Provides four subcommands: + +- ``launch`` — start the Twinkle Server from a YAML config. + Validates the persistence config signature against + the persistence backend BEFORE ``ray.init`` so a + configuration drift fails fast (R15.1). +- ``check-config`` — validate a config file; exit 0 on success, non-zero + with the validation error on failure (R14.3, R14.4). +- ``print-config`` — emit the fully resolved + normalized ``ServerConfig`` + as YAML (R14.5). +- ``clear persistence``— delete persisted state for the namespace derived + from a config file (R14.2). +""" +from __future__ import annotations + +import asyncio +import json +import sys +from pathlib import Path +from typing import Optional + +import typer +import yaml + +from twinkle.server.config import ServerConfig +from twinkle.server.exceptions import ConfigMismatchError, ConfigParseError + +app = typer.Typer( + add_completion=False, + no_args_is_help=True, + help='Operations CLI for Twinkle Server.', +) + +clear_app = typer.Typer( + no_args_is_help=True, + help='Clear server-side state.', +) +app.add_typer(clear_app, name='clear') + + +CONFIG_OPTION = typer.Option( + ..., + '--config', + '-c', + envvar='TWINKLE_SERVER_CONFIG', + help='Path to the YAML configuration file.', + metavar='PATH', +) +NAMESPACE_OPTION = typer.Option( + None, + '--namespace', + envvar='TWINKLE_RAY_NAMESPACE', + help="Ray namespace (overrides ray_namespace in the config).", +) + + +def _load_config(path: Path) -> ServerConfig: + """Load + validate ``path``; print typed errors and exit non-zero on failure.""" + try: + return ServerConfig.from_yaml(path) + except FileNotFoundError as e: + typer.echo(f'error: {e}', err=True) + raise typer.Exit(code=2) + except ConfigParseError as e: + typer.echo(f'error: {e}', err=True) + raise typer.Exit(code=2) + except Exception as e: # pydantic.ValidationError + cross-field + typer.echo(f'error: invalid configuration\n{e}', err=True) + raise typer.Exit(code=2) + + +def _signature_payload(config: ServerConfig) -> dict: + """The dict whose hash drives signature validation. + + Restricted to the persistence-relevant fields so unrelated edits don't + invalidate the persisted state. Future phases can broaden this. + """ + return {'persistence': config.persistence.model_dump(mode='json')} + + +@app.command('launch') +def launch_cmd( + config: Path = CONFIG_OPTION, + namespace: Optional[str] = NAMESPACE_OPTION, +) -> None: + """Start the Twinkle Server from a YAML config file (R14.1, R15.1).""" + cfg = _load_config(config) + + # Validate the persistence config signature BEFORE we touch Ray (R15.1). + try: + from twinkle.server.state.config_signature import validate_against_backend + + asyncio.run(validate_against_backend(cfg.persistence, _signature_payload(cfg))) + except ConfigMismatchError as e: + typer.echo(f'error: {e}', err=True) + raise typer.Exit(code=3) + + # Defer the heavy launcher import until after drift validation passes so + # the failure path stays cheap (and a missing Ray install doesn't block + # `check-config`). + from twinkle.server.launcher import ServerLauncher + + launcher = ServerLauncher(config=cfg, ray_namespace=namespace or cfg.ray_namespace) + launcher.launch() + + +@app.command('check-config') +def check_config_cmd(config: Path = CONFIG_OPTION) -> None: + """Validate ``config`` and exit 0 on success, non-zero on failure (R14.3, R14.4).""" + _load_config(config) + typer.echo('ok') + + +@app.command('print-config') +def print_config_cmd( + config: Path = CONFIG_OPTION, + fmt: str = typer.Option('yaml', '--format', envvar='TWINKLE_PRINT_FORMAT', help='yaml|json'), +) -> None: + """Emit the validated, normalized ``ServerConfig`` (R14.5).""" + cfg = _load_config(config) + payload = cfg.to_yaml_dict() + if fmt == 'json': + typer.echo(json.dumps(payload, indent=2, sort_keys=True)) + else: + typer.echo(yaml.safe_dump(payload, sort_keys=True).rstrip()) + + +@clear_app.command('persistence') +def clear_persistence_cmd(config: Path = CONFIG_OPTION) -> None: + """Remove persisted state for the namespace derived from ``config`` (R14.2).""" + cfg = _load_config(config) + from twinkle.server.state.backend.factory import create_backend + + async def _clear() -> int: + backend = create_backend(cfg.persistence) + keys = await backend.keys('*') + removed = 0 + for k in keys: + await backend.delete(k) + removed += 1 + return removed + + n = asyncio.run(_clear()) + typer.echo(f'cleared {n} keys from persistence backend (mode={cfg.persistence.mode})') + + +def main(argv: list[str] | None = None) -> int: + """Programmatic entry point used by ``__main__.py`` and tests. + + Runs the typer app in standalone mode and converts its ``SystemExit`` + into a plain return code so callers can react without re-trapping. + """ + try: + app(args=argv, standalone_mode=True) + except SystemExit as exc: + code = exc.code + if code is None: + return 0 + return int(code) if not isinstance(code, int) else code + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/src/twinkle/server/state/config_signature.py b/src/twinkle/server/state/config_signature.py index 228c927ce..83c0bf565 100644 --- a/src/twinkle/server/state/config_signature.py +++ b/src/twinkle/server/state/config_signature.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any +from twinkle.server.exceptions import ConfigMismatchError from twinkle.server.state.backend.base import StateBackend logger = logging.getLogger(__name__) @@ -90,10 +91,65 @@ async def validate_config_signature( return False elif policy == SignatureMismatchPolicy.ABORT: - from twinkle.server.exceptions import ConfigMismatchError raise ConfigMismatchError( f"Configuration signature mismatch. " f"Stored: {stored_sig[:12]}..., Current: {current_sig[:12]}... " f"Use policy='warn' or 'clear' to allow startup with changed config.") return False + + +# --------------------------------------------------------------------------- +# CLI startup hook (R15) +# --------------------------------------------------------------------------- + + +def _format_diff(stored: dict[str, Any] | None, current: dict[str, Any]) -> str: + """Render a stored-vs-current diff suitable for a remediation hint.""" + lines: list[str] = [] + keys = sorted(set((stored or {}).keys()) | set(current.keys())) + for k in keys: + s = (stored or {}).get(k, '') + c = current.get(k, '') + if s != c: + lines.append(f' - {k}: stored={s!r} current={c!r}') + return '\n'.join(lines) if lines else ' (no field-level diff — values differ at nested level)' + + +async def validate_against_backend(persistence_config: Any, current_config: dict[str, Any]) -> None: + """Validate the persistence config signature on launcher startup (R15). + + Builds a backend from ``persistence_config`` (a :class:`PersistenceConfig`), + computes the current signature, and compares it to the stored value: + - if no signature is stored, store the current one and continue (R15.4); + - if signatures match, return cleanly; + - if they differ, raise :class:`ConfigMismatchError` with a stored-vs- + current diff and a remediation hint (R15.2, R15.3). + + Designed to be called BEFORE ``ray.init`` so the launcher can fail fast + without spinning up the cluster (R15.1). + """ + from twinkle.server.state.backend.factory import create_backend + + backend = create_backend(persistence_config) + current_sig = compute_signature(current_config) + stored_sig = await backend.get(_SIGNATURE_KEY) + + if stored_sig is None: + await backend.set(_SIGNATURE_KEY, current_sig) + logger.info('No previous config signature found. Stored current signature.') + return + + if stored_sig == current_sig: + return + + stored_payload = await backend.get('_meta::config_payload') + diff = _format_diff(stored_payload if isinstance(stored_payload, dict) else None, current_config) + raise ConfigMismatchError( + 'Persistence configuration drifted since the last launch. ' + f'Stored signature: {stored_sig[:12]}..., current signature: {current_sig[:12]}...\n' + f'Differences:\n{diff}\n' + 'Remediation: either revert the persistence section to match the stored ' + 'value, or clear the persisted state with ' + '`python -m twinkle.server clear persistence --config ` and relaunch.' + ) diff --git a/tests/server/cli/__init__.py b/tests/server/cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/cli/test_cli.py b/tests/server/cli/test_cli.py new file mode 100644 index 000000000..ef8cc021d --- /dev/null +++ b/tests/server/cli/test_cli.py @@ -0,0 +1,162 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Phase 3 — typer CLI + config-drift validation tests (R14, R15, R16). + +Properties covered: +- # Feature: server-config-observability-refactor, Property 29: Config-drift detection and first-run storage +""" +from __future__ import annotations + +import json +from pathlib import Path +from unittest import mock + +import pytest +import yaml +from typer.testing import CliRunner + +from twinkle.server.cli.app import app, main +from twinkle.server.config import ServerConfig +from twinkle.server.exceptions import ConfigMismatchError +from twinkle.server.state.backend.factory import PersistenceConfig +from twinkle.server.state.backend.memory_backend import MemoryBackend +from twinkle.server.state.config_signature import ( + _SIGNATURE_KEY, + compute_signature, + validate_against_backend, +) + + +REPO_ROOT = Path(__file__).resolve().parents[3] +EXAMPLE = REPO_ROOT / 'cookbook' / 'client' / 'server' / 'server_config.example.yaml' +MOCK_CFG = REPO_ROOT / 'cookbook' / 'client' / 'server' / 'mock' / 'server_config.yaml' + + +# ---------- 9.5 CLI subcommand existence + exit codes (R14.3, R14.4) ------ # + + +def test_subcommands_present() -> None: + runner = CliRunner() + out = runner.invoke(app, ['--help']) + assert out.exit_code == 0, out.output + for cmd in ('launch', 'check-config', 'print-config', 'clear'): + assert cmd in out.output + + +def test_check_config_exit_zero_on_valid() -> None: + runner = CliRunner() + res = runner.invoke(app, ['check-config', '--config', str(EXAMPLE)]) + assert res.exit_code == 0 + assert 'ok' in res.output + + +def test_check_config_nonzero_on_missing(tmp_path: Path) -> None: + runner = CliRunner() + p = tmp_path / 'nope.yaml' + res = runner.invoke(app, ['check-config', '--config', str(p)]) + assert res.exit_code != 0 + assert 'not found' in res.output.lower() + + +def test_check_config_nonzero_on_invalid(tmp_path: Path) -> None: + runner = CliRunner() + p = tmp_path / 'bad.yaml' + p.write_text('persistence: {mode: redis}\napplications: []\n') # missing redis_url + res = runner.invoke(app, ['check-config', '--config', str(p)]) + assert res.exit_code != 0 + assert 'invalid configuration' in res.output.lower() or 'redis_url' in res.output + + +# ---------- 9.5 print-config round-trip (R14.5) --------------------------- # + + +def test_print_config_round_trip(tmp_path: Path) -> None: + runner = CliRunner() + res = runner.invoke(app, ['print-config', '--config', str(EXAMPLE), '--format', 'json']) + assert res.exit_code == 0, res.output + payload = json.loads(res.output) + rebuilt = ServerConfig.model_validate(payload) + original = ServerConfig.from_yaml(EXAMPLE) + assert rebuilt == original + + +# ---------- 9.5 env-var override (R14.6) ---------------------------------- # + + +def test_env_var_overrides_when_flag_omitted(monkeypatch) -> None: + runner = CliRunner() + monkeypatch.setenv('TWINKLE_SERVER_CONFIG', str(EXAMPLE)) + res = runner.invoke(app, ['check-config']) + assert res.exit_code == 0 + + +# ---------- 9.5 launch validates drift BEFORE ray.init (R15.1) ------------ # + + +def test_launch_validates_drift_before_ray_init() -> None: + """Order check: ``validate_against_backend`` is called before + ``ServerLauncher`` is even imported (and thus before ray.init).""" + runner = CliRunner() + + def _abort_drift(*args, **kwargs): + raise ConfigMismatchError('drift sentinel') + + with mock.patch( + 'twinkle.server.state.config_signature.validate_against_backend', + side_effect=_abort_drift, + ): + # Should never reach the launcher import — patch it to a sentinel that + # would make the test fail loudly if reached. + with mock.patch('twinkle.server.launcher.ServerLauncher') as launcher_spy: + res = runner.invoke(app, ['launch', '--config', str(MOCK_CFG)]) + assert res.exit_code == 3, res.output + assert 'drift sentinel' in res.output + assert launcher_spy.call_count == 0 + + +# ---------- Property 29: drift detection + first-run storage (R15.2/4) ---- # + + +@pytest.mark.asyncio +async def test_property_29_first_run_stores_signature() -> None: + """First run with no stored signature stores it and returns silently (R15.4).""" + backend = MemoryBackend() + cfg_payload = {'persistence': {'mode': 'memory'}} + pcfg = PersistenceConfig(mode='memory') + # Patch create_backend to return our shared in-process backend so we can + # inspect the stored signature afterwards. + with mock.patch( + 'twinkle.server.state.backend.factory.create_backend', return_value=backend + ): + await validate_against_backend(pcfg, cfg_payload) + assert await backend.get(_SIGNATURE_KEY) == compute_signature(cfg_payload) + # Second run with same payload is a no-op. + await validate_against_backend(pcfg, cfg_payload) + + +@pytest.mark.asyncio +async def test_property_29_drift_raises_with_diff_and_remediation() -> None: + backend = MemoryBackend() + pcfg = PersistenceConfig(mode='memory') + initial = {'persistence': {'mode': 'memory'}} + later = {'persistence': {'mode': 'file', 'file_path': '/tmp/x.json'}} + + with mock.patch( + 'twinkle.server.state.backend.factory.create_backend', return_value=backend + ): + await validate_against_backend(pcfg, initial) + with pytest.raises(ConfigMismatchError) as exc: + await validate_against_backend(pcfg, later) + + msg = str(exc.value) + assert 'drifted' in msg.lower() or 'mismatch' in msg.lower() + assert 'Remediation' in msg + + +# ---------- 9.8 example config loads (R16.3) ------------------------------ # + + +def test_example_config_loads_via_server_config() -> None: + cfg = ServerConfig.from_yaml(EXAMPLE) + assert isinstance(cfg, ServerConfig) + assert any(a.import_path == 'model' for a in cfg.applications) + assert any(a.import_path == 'sampler' for a in cfg.applications) From 65775ad8a035d24268e657b79e373917faddeb02 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 00:47:52 +0800 Subject: [PATCH 17/34] feat(server): trace context carrier for cross-deployment propagation (Phase 4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds make_carrier() and activate_carrier(carrier) in telemetry/context_carrier.py so internal Ray Serve DeploymentHandle calls can keep one trace continuous: the calling deployment serializes its active OTEL context into a small dict, the receiving deployment wraps its handler body in activate_carrier(...) and any spans it starts attach as children of the propagated context. When the OTEL SDK is missing, make_carrier returns an empty dict and activate_carrier becomes a no-op context manager, so the body always runs and never raises (R13.4 / R18.3). When the carrier is None or empty, activate_carrier also degrades to a no-op so the receiving side just starts a fresh trace. Adds Property 24 round-trip tests against an in-memory OTEL exporter showing parent.trace_id == child.trace_id when the carrier is honored, and that both sides are safe in the absence of OTEL or context. Refactors the telemetry test fixture into a session-scoped conftest because OTel's trace.set_tracer_provider is one-shot per process — the second per-module fixture would have silently shared the first one's exporter and made tests order-dependent. The Phase 0a client-API contract baseline is re-run green. Note: the LGTM single-trace-id fan-out integration test (task 10.4) requires the docker-compose stack and is deferred to the documentation phase. --- .../server/telemetry/context_carrier.py | 75 +++++++++++++++++ tests/server/telemetry/conftest.py | 37 ++++++++ .../server/telemetry/test_context_carrier.py | 84 +++++++++++++++++++ .../telemetry/test_tracing_and_correlation.py | 18 ---- 4 files changed, 196 insertions(+), 18 deletions(-) create mode 100644 src/twinkle/server/telemetry/context_carrier.py create mode 100644 tests/server/telemetry/conftest.py create mode 100644 tests/server/telemetry/test_context_carrier.py diff --git a/src/twinkle/server/telemetry/context_carrier.py b/src/twinkle/server/telemetry/context_carrier.py new file mode 100644 index 000000000..87dc6553b --- /dev/null +++ b/src/twinkle/server/telemetry/context_carrier.py @@ -0,0 +1,75 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Trace context carrier helpers for cross-deployment propagation (R13). + +``DeploymentHandle`` calls between Ray Serve deployments do not go through +HTTP, so the existing ``inject_context(headers)`` path doesn't apply. To keep +a single trace continuous across an internal handle hop, callers serialize +the active OpenTelemetry context into a small dict using :func:`make_carrier` +and pass it as a kwarg; the receiving side calls :func:`activate_carrier` +inside its handler so subsequent spans attach as children of the propagated +context. + +When the OTEL SDK is missing, both helpers degrade to a NoOp: ``make_carrier`` +returns an empty dict and ``activate_carrier`` is a no-op context manager +that runs the body without raising (R13.4 / R18.3). +""" +from __future__ import annotations + +from contextlib import contextmanager +from typing import Any, Iterator, Mapping + +try: + from opentelemetry import context as _otel_context # type: ignore + from opentelemetry.propagate import extract as _otel_extract, inject as _otel_inject # type: ignore + + _OTEL_AVAILABLE = True +except Exception: + _OTEL_AVAILABLE = False + + +def make_carrier() -> dict[str, Any]: + """Inject the active trace context into a fresh dict (R13.1). + + Returns an empty dict when OTEL is not installed; the receiving side's + :func:`activate_carrier` handles that case gracefully. + """ + carrier: dict[str, Any] = {} + if not _OTEL_AVAILABLE: + return carrier + try: + _otel_inject(carrier) + except Exception: + # NEVER let a tracing failure block the underlying RPC. + return {} + return carrier + + +@contextmanager +def activate_carrier(carrier: Mapping[str, Any] | None) -> Iterator[None]: + """Attach the trace context from ``carrier`` for the lifetime of the block. + + Subsequent spans started inside the block become children of the + propagated context (R13.2). When ``carrier`` is ``None`` or empty, or + when OTEL is not installed, the block runs without attaching any + context and a fresh trace is started (R13.4). + """ + if not _OTEL_AVAILABLE or not carrier: + yield + return + + try: + ctx = _otel_extract(dict(carrier)) + except Exception: + yield + return + + token = None + try: + token = _otel_context.attach(ctx) + yield + finally: + if token is not None: + try: + _otel_context.detach(token) + except Exception: + pass diff --git a/tests/server/telemetry/conftest.py b/tests/server/telemetry/conftest.py new file mode 100644 index 000000000..5a9276bc8 --- /dev/null +++ b/tests/server/telemetry/conftest.py @@ -0,0 +1,37 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Shared OTEL setup for telemetry tests. + +OTel's global tracer provider is one-shot per process — the second +``trace.set_tracer_provider(...)`` call no-ops with a warning. So multiple +test modules that each tried to register their own provider would silently +share whichever one ran first, and tests that read spans from the wrong +exporter would fail. This fixture installs one provider + one in-memory +exporter for the entire telemetry test package. +""" +from __future__ import annotations + +import pytest + + +def _otel_available() -> bool: + try: + from opentelemetry import trace # noqa: F401 + except Exception: + return False + return True + + +@pytest.fixture(scope='session') +def in_memory_span_exporter(): + if not _otel_available(): + pytest.skip('OTEL SDK not installed in test env') + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + trace.set_tracer_provider(provider) + return exporter diff --git a/tests/server/telemetry/test_context_carrier.py b/tests/server/telemetry/test_context_carrier.py new file mode 100644 index 000000000..c02716141 --- /dev/null +++ b/tests/server/telemetry/test_context_carrier.py @@ -0,0 +1,84 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Trace context carrier round-trip tests (R13). + +# Feature: server-config-observability-refactor, Property 24: Trace context carrier round-trip +""" +from __future__ import annotations + +from unittest import mock + +import pytest + +from twinkle.server.telemetry import context_carrier +from twinkle.server.telemetry.context_carrier import activate_carrier, make_carrier + + +def _otel_available() -> bool: + try: + from opentelemetry import trace # noqa: F401 + except Exception: + return False + return True + + +# ---------- NoOp path (R13.4 / R18.3) ------------------------------------- # + + +def test_make_carrier_returns_empty_dict_when_otel_absent() -> None: + with mock.patch.object(context_carrier, '_OTEL_AVAILABLE', False): + assert make_carrier() == {} + + +def test_activate_carrier_with_none_is_safe_noop() -> None: + with activate_carrier(None): + pass + with activate_carrier({}): + pass + + +def test_activate_carrier_when_otel_absent_is_noop() -> None: + with mock.patch.object(context_carrier, '_OTEL_AVAILABLE', False): + with activate_carrier({'traceparent': 'whatever'}): + pass + + +# ---------- Property 24: round-trip (R13.1, R13.2) ------------------------ # + + +def test_property_24_carrier_round_trip(in_memory_span_exporter) -> None: + """Active context → make_carrier → activate_carrier → child span shares + the same trace id (R13.1, R13.2).""" + from opentelemetry import trace + + in_memory_span_exporter.clear() + tracer = trace.get_tracer('twinkle.test') + + carrier: dict[str, object] = {} + parent_trace_id: int | None = None + with tracer.start_as_current_span('caller') as parent: + parent_trace_id = parent.get_span_context().trace_id + carrier = make_carrier() + + # Sanity — the carrier carries something OTEL recognizes (traceparent). + assert carrier and any(k.lower() == 'traceparent' for k in carrier.keys()) + + # On the receiving side, activating the carrier and starting a span + # should yield a span whose trace id equals the parent's. + with activate_carrier(carrier): + with tracer.start_as_current_span('callee') as child: + child_trace_id = child.get_span_context().trace_id + + assert parent_trace_id == child_trace_id + + +def test_property_24_empty_carrier_starts_fresh_trace(in_memory_span_exporter) -> None: + """An empty / None carrier means: start a new trace (R13.4).""" + from opentelemetry import trace + + in_memory_span_exporter.clear() + tracer = trace.get_tracer('twinkle.test') + + with activate_carrier(None): + with tracer.start_as_current_span('orphan') as span: + tid = span.get_span_context().trace_id + assert tid != 0 diff --git a/tests/server/telemetry/test_tracing_and_correlation.py b/tests/server/telemetry/test_tracing_and_correlation.py index 838be3f99..ff3a76383 100644 --- a/tests/server/telemetry/test_tracing_and_correlation.py +++ b/tests/server/telemetry/test_tracing_and_correlation.py @@ -116,24 +116,6 @@ def _otel_available() -> bool: return True -@pytest.fixture(scope='module') -def in_memory_span_exporter(): - """Module-scoped exporter — set once because OTEL global tracer provider - is one-shot per process.""" - if not _otel_available(): - pytest.skip('OTEL SDK not installed in test env') - from opentelemetry import trace - from opentelemetry.sdk.trace import TracerProvider - from opentelemetry.sdk.trace.export import SimpleSpanProcessor - from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter - - exporter = InMemorySpanExporter() - provider = TracerProvider() - provider.add_span_processor(SimpleSpanProcessor(exporter)) - trace.set_tracer_provider(provider) - yield exporter - - def test_property_19_span_lifecycle(in_memory_span_exporter) -> None: """When OTEL is present, a span is started before and ended after the block.""" in_memory_span_exporter.clear() From d55766ba5409ea7d48ebee1f274d3a033698f6f9 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 00:55:49 +0800 Subject: [PATCH 18/34] docs(server): observability + server-configuration guides (Phase 5) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the documentation set the refactor has been building toward: - docs/source_en/Usage Guide/Observability.md + docs/source_zh/使用指引/可观测化.md Document the six twinkle.* correlation keys, the make_carrier / activate_carrier mechanism for cross-deployment trace propagation, and an end-to-end LGTM example using the cookbook/observability/ docker-compose stack (R17.1, R17.2, R11.4). - docs/source_zh/使用指引/服务配置.md ServerConfig field reference (every top-level + applications args schema), the supported environment variables (TWINKLE_SERVER_CONFIG, TWINKLE_RAY_NAMESPACE, telemetry / persistence env-var bag), a minimal YAML example, and a legacy → current field migration table covering telemetry_config → telemetry, persistence_config → persistence, and use_megatron → backend (R17.3, R8.3). Adds index links to both guides from docs/source_zh/index.rst and the Observability guide from docs/source_en/index.rst (R17.4). Adds tests/docs/test_docs_smoke.py asserting every required content element is present: all six correlation keys appear in both observability guides, the propagation section names DeploymentHandle / make_carrier / activate_carrier, the LGTM example references the docker-compose stack, the config guide lists every top-level field + the env vars + the YAML example + the migration table, and the index entries resolve. The Phase 0a client-API contract baseline is re-run green, and all 210 unit + property + contract tests pass in the twinkle conda env. --- docs/source_en/Usage Guide/Observability.md | 101 ++++++++++++++++ docs/source_en/index.rst | 1 + docs/source_zh/index.rst | 2 + ...57\350\247\202\346\265\213\345\214\226.md" | 94 +++++++++++++++ ...15\345\212\241\351\205\215\347\275\256.md" | 93 +++++++++++++++ tests/docs/test_docs_smoke.py | 109 ++++++++++++++++++ 6 files changed, 400 insertions(+) create mode 100644 docs/source_en/Usage Guide/Observability.md create mode 100644 "docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\217\257\350\247\202\346\265\213\345\214\226.md" create mode 100644 "docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" create mode 100644 tests/docs/test_docs_smoke.py diff --git a/docs/source_en/Usage Guide/Observability.md b/docs/source_en/Usage Guide/Observability.md new file mode 100644 index 000000000..01c64b910 --- /dev/null +++ b/docs/source_en/Usage Guide/Observability.md @@ -0,0 +1,101 @@ +# Observability + +Twinkle Server emits OpenTelemetry traces, metrics, and logs from every Ray +Serve deployment. This guide covers the standardized **correlation keys**, +the Ray Serve **trace-context propagation** mechanism, and an end-to-end +**LGTM** example using the Loki / Grafana / Tempo / Mimir docker-compose +stack shipped under `cookbook/observability/`. + +## Correlation keys + +Every business-layer span carries a subset of these attributes when the +corresponding identifier is known to the operation. All names share the +`twinkle.` prefix so you can filter Tempo / Loki by a single namespace. + +| Attribute | Set when the operation is associated with… | +|------------------------------|--------------------------------------------| +| `twinkle.session_id` | A client session | +| `twinkle.model_id` | A specific registered model | +| `twinkle.replica_id` | A specific Ray Serve replica | +| `twinkle.token_id` | A user authentication token | +| `twinkle.sampling_session_id`| A sampling session | +| `twinkle.base_model` | The base model behind a registered model | + +Constants live in `twinkle.server.telemetry.correlation`. Use +`set_correlation_attrs(span, {...})` to attach them — None values are +skipped, so partially-known operations never get empty attributes. + +```python +from twinkle.server.telemetry.correlation import ( + SESSION_ID, MODEL_ID, set_correlation_attrs, +) +from twinkle.server.telemetry.tracing import traced_operation + +with traced_operation('server_state.register_model', + attrs={SESSION_ID: sid, MODEL_ID: mid}): + ... +``` + +When the OpenTelemetry SDK is not installed, `traced_operation` becomes a +NoOp context manager: the body runs to completion and returns the same +result it would return when tracing is active. + +## Trace-context propagation across deployments + +The HTTP edge already injects context into outgoing headers in +`gateway/proxy.py`, and `create_tracing_middleware` extracts it on the +inbound side, so a Tinker request that passes through the Gateway proxy +shares one trace id end to end. + +The remaining gap is **Ray Serve `DeploymentHandle` calls** between +deployments — those don't go over HTTP. Use the trace-context carrier +helpers: + +```python +from twinkle.server.telemetry.context_carrier import make_carrier, activate_carrier + +# caller side (e.g. Model deployment) — pass the carrier with the call +carrier = make_carrier() +result = await sampler_handle.options(...).remote(payload, trace_context=carrier) + +# callee side (e.g. Sampler deployment handler) +async def handler(payload, trace_context: dict | None = None): + with activate_carrier(trace_context): + with traced_operation('sampler.handle'): + ... +``` + +`make_carrier()` returns an empty dict and `activate_carrier(None)` is a +no-op when OTel is missing or the carrier is empty, so the path stays +safe under graceful degradation. + +## End-to-end LGTM example + +The repository ships a docker-compose stack with Grafana, Tempo (traces), +Loki (logs), and Mimir (metrics) under `cookbook/observability/`. + +```bash +# 1. Start the LGTM stack. +docker compose -f cookbook/observability/docker-compose.yml up -d + +# 2. Launch the server with telemetry enabled. +cat > /tmp/srv.yaml <<'YAML' +telemetry: + enabled: true + service_name: twinkle-server + otlp_endpoint: http://localhost:4317 +persistence: { mode: memory } +applications: [] +YAML + +python -m twinkle.server launch --config /tmp/srv.yaml & + +# 3. Issue some traffic and open Grafana at http://localhost:3000. +# In Tempo, search by tag: `twinkle.session_id = `. +``` + +CPU / memory / GPU metrics show up automatically because the +`ResourceMetricsCollector` is started inside every Ray Serve worker by +`ensure_telemetry_initialized()`. When `psutil` or `pynvml` is missing +(or no GPU is present), the affected gauges report no data and the +worker keeps serving requests. diff --git a/docs/source_en/index.rst b/docs/source_en/index.rst index ef477f7fc..e04da376f 100644 --- a/docs/source_en/index.rst +++ b/docs/source_en/index.rst @@ -12,6 +12,7 @@ Twinkle DOCUMENTATION Usage Guide/Quick-Start.md Usage Guide/Installation.md Usage Guide/Server and Client/index.rst + Usage Guide/Observability.md Usage Guide/NPU-Support.md Usage Guide/Train-as-a-Service.md Usage Guide/Introduction-with-Qwen3.5.md diff --git a/docs/source_zh/index.rst b/docs/source_zh/index.rst index 3d07d4b2a..50f1f32cd 100644 --- a/docs/source_zh/index.rst +++ b/docs/source_zh/index.rst @@ -12,6 +12,8 @@ Twinkle DOCUMENTATION 使用指引/快速开始.md 使用指引/安装.md 使用指引/服务端和客户端/index.rst + 使用指引/服务配置.md + 使用指引/可观测化.md 使用指引/NPU的支持.md 使用指引/训练服务.md 使用指引/Qwen3.5最佳实践.md diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\217\257\350\247\202\346\265\213\345\214\226.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\217\257\350\247\202\346\265\213\345\214\226.md" new file mode 100644 index 000000000..9aeb3e245 --- /dev/null +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\217\257\350\247\202\346\265\213\345\214\226.md" @@ -0,0 +1,94 @@ +# 可观测化 + +Twinkle Server 在每个 Ray Serve 部署中都会发出 OpenTelemetry 追踪、指标与日志。 +本指南覆盖标准化的 **关联键**、Ray Serve 内的 **追踪上下文传递** 机制, +以及基于 `cookbook/observability/` 下 docker-compose 栈的端到端 +**LGTM**(Loki / Grafana / Tempo / Mimir)示例。 + +## 关联键(Correlation keys) + +业务层的每个 span 在已知对应标识符时都会附带下列属性。所有名称都使用 +`twinkle.` 前缀,便于在 Tempo / Loki 中按命名空间统一筛选。 + +| 属性 | 在以下场景下设置 | +|-------------------------------|--------------------------------------| +| `twinkle.session_id` | 关联到某个客户端 session | +| `twinkle.model_id` | 关联到某个已注册的模型 | +| `twinkle.replica_id` | 关联到某个 Ray Serve 副本 | +| `twinkle.token_id` | 关联到某个用户认证 token | +| `twinkle.sampling_session_id` | 关联到某个采样 session | +| `twinkle.base_model` | 关联到注册模型背后的 base model | + +常量定义在 `twinkle.server.telemetry.correlation`。通过 +`set_correlation_attrs(span, {...})` 一次性附加;None 值会被跳过, +部分已知的操作不会出现空属性。 + +```python +from twinkle.server.telemetry.correlation import ( + SESSION_ID, MODEL_ID, set_correlation_attrs, +) +from twinkle.server.telemetry.tracing import traced_operation + +with traced_operation('server_state.register_model', + attrs={SESSION_ID: sid, MODEL_ID: mid}): + ... +``` + +未安装 OpenTelemetry SDK 时,`traced_operation` 退化为 NoOp 上下文管理器: +代码块照常执行并返回与启用追踪时相同的结果。 + +## 跨部署的追踪上下文传递 + +HTTP 边界已经在 `gateway/proxy.py` 中将上下文注入到出站 header, +`create_tracing_middleware` 在入站侧提取——Tinker 经 Gateway 代理的请求在端到端 +共享同一个 trace id。 + +剩下的空缺是 **Ray Serve `DeploymentHandle` 内部调用**——这些调用不走 HTTP。 +使用追踪上下文 carrier 辅助函数: + +```python +from twinkle.server.telemetry.context_carrier import make_carrier, activate_carrier + +# 调用方(如 Model 部署)—— 把 carrier 一起传给被调方 +carrier = make_carrier() +result = await sampler_handle.options(...).remote(payload, trace_context=carrier) + +# 被调方(如 Sampler handler) +async def handler(payload, trace_context: dict | None = None): + with activate_carrier(trace_context): + with traced_operation('sampler.handle'): + ... +``` + +OTel 缺失或 carrier 为空时,`make_carrier()` 返回空字典, +`activate_carrier(None)` 是无操作的上下文管理器,调用路径在优雅降级时仍然安全。 + +## 端到端 LGTM 示例 + +本仓库在 `cookbook/observability/` 下提供一套 docker-compose 栈,包含 +Grafana、Tempo(traces)、Loki(logs)和 Mimir(metrics)。 + +```bash +# 1. 启动 LGTM 栈 +docker compose -f cookbook/observability/docker-compose.yml up -d + +# 2. 启用 telemetry 启动服务 +cat > /tmp/srv.yaml <<'YAML' +telemetry: + enabled: true + service_name: twinkle-server + otlp_endpoint: http://localhost:4317 +persistence: { mode: memory } +applications: [] +YAML + +python -m twinkle.server launch --config /tmp/srv.yaml & + +# 3. 发送一些请求,浏览器打开 http://localhost:3000 进入 Grafana +# 在 Tempo 中以 `twinkle.session_id = <你的 session>` 作为 tag 检索 +``` + +CPU / 内存 / GPU 指标会自动出现,因为 `ResourceMetricsCollector` 由 +`ensure_telemetry_initialized()` 在每个 Ray Serve worker 中启动。 +当 `psutil`、`pynvml` 缺失(或没有 GPU)时,对应 gauge 报告 no data, +worker 仍然继续服务请求。 diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" new file mode 100644 index 000000000..10b6bab8a --- /dev/null +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" @@ -0,0 +1,93 @@ +# 服务配置 + +Twinkle Server 的所有运行时配置都来自一个 Pydantic 聚合根 +`ServerConfig`,由 YAML 文件经 `ServerConfig.from_yaml(path)` 加载。 +本指南覆盖:字段一览、支持的环境变量、一份完整 YAML 示例,以及从旧字段 +名到当前字段名的迁移表。 + +## 字段总览 + +| 字段 (类型) | 默认值 | 含义 | +|-----------------------------|-----------------------------|------| +| `ray_namespace` (str\|null) | null → "twinkle_cluster" | Ray cluster namespace | +| `proxy_location` (str\|null)| null | Ray Serve proxy 部署策略,多机推荐 "EveryNode" | +| `http_options.host` (str) | "localhost" | HTTP 监听地址,多机部署设为 "0.0.0.0" | +| `http_options.port` (int) | 8000 | HTTP 监听端口 | +| `telemetry.enabled` (bool) | false | 是否启用 OpenTelemetry,关闭时整条管线 NoOp | +| `telemetry.debug` (bool) | false | true 控制台输出,false 走 OTLP | +| `telemetry.service_name` | "twinkle-server" | OTEL `service.name` | +| `telemetry.otlp_endpoint` | "http://localhost:4317" | OTLP gRPC 端点 | +| `telemetry.export_interval_ms`| 30000 | 指标导出间隔(毫秒) | +| `persistence.mode` (str) | "memory" | "memory" / "file" / "redis" | +| `persistence.file_path` | null | mode=file 时必填 | +| `persistence.redis_url` | null | mode=redis 时必填 | +| `persistence.key_prefix` | "" | 全局 key 前缀 | +| `task_queue.rps_limit` | 100.0 | 每用户 token 的 RPS,0 关闭 | +| `task_queue.tps_limit` | 16000.0 | 每用户 token 的 input tokens/秒,0 关闭 | +| `task_queue.window_seconds` | 1.0 | 滑动窗口宽度(秒),必须 > 0 | +| `task_queue.queue_timeout` | 300.0 | 任务排队最长等待(秒) | +| `task_queue.execution_timeout`| 120.0 | 任务执行最长(秒),0 关闭 | +| `task_queue.max_input_tokens`| 16000 | 单请求最大 input tokens | +| `applications` (list) | [] | 部署清单:每项含 `name` / `route_prefix` / `import_path` / `args` / `deployments` | + +应用条目的 `args` 模式由 `import_path` 决定: + +- `import_path=server` → `ServerArgs`(`server_config`、`supported_models`、`http_options`) +- `import_path=model` → `ModelArgs`(必填 `backend: mock|transformers|megatron`、`model_id`、`device_group`、`device_mesh`) +- `import_path=sampler`→ `SamplerArgs`(必填 `sampler_type: mock|vllm|torch`、`model_id`) +- `import_path=processor`→ `ProcessorArgs` + +## 支持的环境变量 + +CLI 选项均声明 `envvar=`,命令行未指定时回退到环境变量: + +| 选项 | 环境变量 | +|----------------------|-------------------------------| +| `--config / -c` | `TWINKLE_SERVER_CONFIG` | +| `--namespace` | `TWINKLE_RAY_NAMESPACE` | +| `--format`(print-config) | `TWINKLE_PRINT_FORMAT` | + +启动器额外读取(用于跨 Ray worker 传播): + +- `TWINKLE_TELEMETRY_ENABLED` / `_DEBUG` / `_SERVICE` / `_ENDPOINT` / `_INTERVAL` +- `TWINKLE_PERSISTENCE_MODE` / `_FILE_PATH` / `_REDIS_URL` / `_KEY_PREFIX` + +## 完整 YAML 示例 + +参见 [`cookbook/client/server/server_config.example.yaml`](https://github.com/modelscope/twinkle/blob/main/cookbook/client/server/server_config.example.yaml), +该文件每个字段都附带类型、默认值与可选项。最小可执行示例: + +```yaml +http_options: { host: 0.0.0.0, port: 8000 } +telemetry: { enabled: false } +persistence: { mode: memory } +applications: + - name: server + route_prefix: /api/v1 + import_path: server + args: { supported_models: [mock-model] } + - name: models + route_prefix: /api/v1/model + import_path: model + args: + backend: mock + model_id: mock-model + device_group: { name: model, ranks: 1, device_type: cpu } + device_mesh: { device_type: cpu, dp_size: 1 } +``` + +## 旧字段 → 当前字段迁移表 + +本次重构对运维侧字段名做了一次干净的破坏性变更。**不再支持旧名作为别名**, +请按下表更新 YAML: + +| 旧字段名 | 当前字段名 | 备注 | +|---------------------------|--------------------|------| +| `telemetry_config:` | `telemetry:` | 顶层 | +| `persistence_config:` | `persistence:` | 顶层 | +| 在 model `args` 中:`use_megatron: true` | `backend: megatron` | 模型后端切换 | +| 在 model `args` 中:`use_megatron: false` | `backend: transformers` | 模型后端切换 | +| 在 sampler `args` 中:(隐含 vllm) | `sampler_type: vllm` | 显式声明 | + +YAML 中保留旧名会触发 `pydantic.ValidationError`,并在错误消息中点出 +不被识别的字段,便于直接修复。 diff --git a/tests/docs/test_docs_smoke.py b/tests/docs/test_docs_smoke.py new file mode 100644 index 000000000..e86af5565 --- /dev/null +++ b/tests/docs/test_docs_smoke.py @@ -0,0 +1,109 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Smoke checks for the Phase 5 documentation set (R8.3, R11.4, R17).""" +from __future__ import annotations + +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +# ---------- file presence ------------------------------------------------- # + + +OBSERVABILITY_EN = REPO_ROOT / 'docs' / 'source_en' / 'Usage Guide' / 'Observability.md' +OBSERVABILITY_ZH = REPO_ROOT / 'docs' / 'source_zh' / '使用指引' / '可观测化.md' +CONFIG_GUIDE_ZH = REPO_ROOT / 'docs' / 'source_zh' / '使用指引' / '服务配置.md' + +INDEX_EN = REPO_ROOT / 'docs' / 'source_en' / 'index.rst' +INDEX_ZH = REPO_ROOT / 'docs' / 'source_zh' / 'index.rst' + + +@pytest.mark.parametrize( + 'path', + [OBSERVABILITY_EN, OBSERVABILITY_ZH, CONFIG_GUIDE_ZH, INDEX_EN, INDEX_ZH], +) +def test_doc_exists(path: Path) -> None: + assert path.exists(), f'missing doc: {path}' + + +# ---------- observability guide content (R11.4, R17.1, R17.2) ------------ # + + +_CORRELATION_KEYS = ( + 'twinkle.session_id', + 'twinkle.model_id', + 'twinkle.replica_id', + 'twinkle.token_id', + 'twinkle.sampling_session_id', + 'twinkle.base_model', +) + + +@pytest.mark.parametrize('path', [OBSERVABILITY_EN, OBSERVABILITY_ZH]) +def test_observability_lists_all_correlation_keys(path: Path) -> None: + text = path.read_text() + for key in _CORRELATION_KEYS: + assert key in text, f'{path.name}: missing correlation key {key}' + + +@pytest.mark.parametrize('path', [OBSERVABILITY_EN, OBSERVABILITY_ZH]) +def test_observability_describes_propagation(path: Path) -> None: + text = path.read_text() + # Mentions the carrier helpers + the propagation surface. + assert 'make_carrier' in text + assert 'activate_carrier' in text + assert 'DeploymentHandle' in text + + +@pytest.mark.parametrize('path', [OBSERVABILITY_EN, OBSERVABILITY_ZH]) +def test_observability_has_lgtm_example(path: Path) -> None: + text = path.read_text() + assert 'docker compose' in text or 'docker-compose' in text + assert 'cookbook/observability' in text + + +# ---------- server-config guide content (R8.3, R17.3) -------------------- # + + +def test_config_guide_lists_top_level_fields() -> None: + text = CONFIG_GUIDE_ZH.read_text() + for field in ('telemetry', 'persistence', 'task_queue', 'applications', 'http_options'): + assert field in text + + +def test_config_guide_documents_envvars() -> None: + text = CONFIG_GUIDE_ZH.read_text() + assert 'TWINKLE_SERVER_CONFIG' in text + assert 'TWINKLE_RAY_NAMESPACE' in text + + +def test_config_guide_includes_yaml_example() -> None: + text = CONFIG_GUIDE_ZH.read_text() + assert 'applications:' in text + assert 'backend: mock' in text or 'backend:' in text + # Reference to the documented example file. + assert 'server_config.example.yaml' in text + + +def test_config_guide_has_migration_table() -> None: + text = CONFIG_GUIDE_ZH.read_text() + # Both legacy → current rows must be present (R8.3). + assert 'telemetry_config' in text and 'telemetry:' in text + assert 'persistence_config' in text and 'persistence:' in text + assert 'use_megatron' in text and 'backend:' in text + + +# ---------- index links (R17.4) ------------------------------------------ # + + +def test_index_zh_links_both_guides() -> None: + text = INDEX_ZH.read_text() + assert '可观测化.md' in text + assert '服务配置.md' in text + + +def test_index_en_links_observability() -> None: + text = INDEX_EN.read_text() + assert 'Observability.md' in text From db48a1c8b85b7d6bc0f8880a2a03f953af991d58 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 08:43:49 +0800 Subject: [PATCH 19/34] =?UTF-8?q?fix(server):=20address=20self-review=20fi?= =?UTF-8?q?ndings=20(Phase=200a=E2=80=935)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cleans up the bugs / dead code surfaced by the post-implementation review: 1. Cleanup task scheduling — removes the asyncio.get_running_loop() hack in get_server_state(): every Ray Serve worker's FastAPI lifespan now awaits state.start_cleanup_task() explicitly. Resource expiry actually runs again (previously every call site was sync ctor → no loop → loop never started). Wired in gateway/model/sampler/processor lifespans; start_cleanup_task is idempotent so repeat calls are no-ops. 2. ApplicationSpec — model/sampler entries with no args block now raise with the offending field path instead of silently substituting a ServerArgs() default. The mode='before' validator routes the raw args (or {}) through the schema selected by import_path so missing required fields surface cleanly. ApplicationSpec.args lost its silent default; server/processor (whose schemas are all-optional) still accept bare entries. 3. Grafana dashboard (R12.5) — adds CPU utilization, system + process memory, GPU utilization, and GPU memory panels to twinkle-overview.json wired to the metric names the ResourceMetricsCollector exports. Adds a regression test covering the panel titles and target metric names. 4. Nested extra='forbid' — TelemetryConfig and PersistenceConfig now reject unknown keys, so typos inside `telemetry: {...}` / `persistence: {...}` fail at load time instead of silently being dropped. Adds a parametrized regression test. 5. Validation before side effects (R3.9, R3.10) — splits each dispatch into _validate_* (pure, no imports) and _dispatch_* (assumes validated input). Both ModelManagement.__init__ / SamplerManagement.__init__ and the build_*_app entry points call _validate_* up front, so an invalid backend / sampler_type never reaches twinkle.initialize, DeviceGroup construction, or any backend import. 6. Dead code — drops the unused _BACKEND_VALUES / _SAMPLER_TYPE_VALUES constants in application_spec.py and the dead exception branch around the old loop.create_task call in get_server_state. 7. use_megatron legacy bridge — removed from ModelManagement.__init__, build_model_app, and the .bind() call. backend is the canonical selector; the only remaining mention in repo lives in the tasks-doc migration table. 9. Stale ServerState docstring — updated to reflect direct-backend access. 11. launcher.py — single top-level `import os` instead of four duplicated local imports. Test surface goes 210 → 213 (added: nested-config extras, dashboard panels, refactored validation tests). All 213 unit + property + contract tests pass and 11 end-to-end smoke checks (cookbook YAMLs, CLI exit codes, print-config round-trip, mock determinism, dispatch validation, contract baseline, cleanup-task lifecycle, ApplicationSpec strictness, nested-config strictness) pass clean. --- .../grafana/dashboards/twinkle-overview.json | 281 ++++++++++++++++-- src/twinkle/server/config/application_spec.py | 44 +-- src/twinkle/server/gateway/server.py | 5 + src/twinkle/server/launcher.py | 5 +- src/twinkle/server/model/app.py | 57 ++-- src/twinkle/server/processor/app.py | 11 +- src/twinkle/server/sampler/app.py | 47 ++- src/twinkle/server/state/backend/factory.py | 5 +- src/twinkle/server/state/server_state.py | 26 +- src/twinkle/server/telemetry/provider.py | 4 +- tests/server/config/test_server_config.py | 10 + tests/server/model/test_mock_model.py | 15 +- tests/server/sampler/test_mock_sampler.py | 14 +- .../telemetry/test_tracing_and_correlation.py | 30 ++ 14 files changed, 438 insertions(+), 116 deletions(-) diff --git a/cookbook/observability/grafana/dashboards/twinkle-overview.json b/cookbook/observability/grafana/dashboards/twinkle-overview.json index ab609e813..1d7251b4e 100644 --- a/cookbook/observability/grafana/dashboards/twinkle-overview.json +++ b/cookbook/observability/grafana/dashboards/twinkle-overview.json @@ -1,5 +1,7 @@ { - "annotations": {"list": []}, + "annotations": { + "list": [] + }, "editable": true, "fiscalYearStartMonth": 0, "graphTooltip": 1, @@ -8,9 +10,24 @@ "liveNow": false, "panels": [ { - "datasource": {"type": "prometheus", "uid": "prometheus"}, - "fieldConfig": {"defaults": {"color": {"mode": "palette-classic"}, "unit": "reqps"}}, - "gridPos": {"h": 8, "w": 12, "x": 0, "y": 0}, + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "unit": "reqps" + } + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 0 + }, "id": 1, "title": "HTTP request rate (per deployment)", "type": "timeseries", @@ -23,9 +40,21 @@ ] }, { - "datasource": {"type": "prometheus", "uid": "prometheus"}, - "fieldConfig": {"defaults": {"unit": "s"}}, - "gridPos": {"h": 8, "w": 12, "x": 12, "y": 0}, + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "unit": "s" + } + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 0 + }, "id": 2, "title": "HTTP latency P95 (per deployment)", "type": "timeseries", @@ -38,21 +67,53 @@ ] }, { - "datasource": {"type": "prometheus", "uid": "prometheus"}, - "gridPos": {"h": 8, "w": 12, "x": 0, "y": 8}, + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 8 + }, "id": 3, "title": "Active resources", "type": "timeseries", "targets": [ - {"expr": "twinkle_sessions_active", "legendFormat": "sessions", "refId": "A"}, - {"expr": "twinkle_models_active", "legendFormat": "models", "refId": "B"}, - {"expr": "twinkle_sampling_sessions_active", "legendFormat": "sampling sessions", "refId": "C"}, - {"expr": "twinkle_futures_active", "legendFormat": "futures", "refId": "D"} + { + "expr": "twinkle_sessions_active", + "legendFormat": "sessions", + "refId": "A" + }, + { + "expr": "twinkle_models_active", + "legendFormat": "models", + "refId": "B" + }, + { + "expr": "twinkle_sampling_sessions_active", + "legendFormat": "sampling sessions", + "refId": "C" + }, + { + "expr": "twinkle_futures_active", + "legendFormat": "futures", + "refId": "D" + } ] }, { - "datasource": {"type": "prometheus", "uid": "prometheus"}, - "gridPos": {"h": 8, "w": 12, "x": 12, "y": 8}, + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 8 + }, "id": 4, "title": "Task queue depth (per deployment)", "type": "timeseries", @@ -65,9 +126,21 @@ ] }, { - "datasource": {"type": "prometheus", "uid": "prometheus"}, - "fieldConfig": {"defaults": {"unit": "s"}}, - "gridPos": {"h": 8, "w": 12, "x": 0, "y": 16}, + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "unit": "s" + } + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 16 + }, "id": 5, "title": "Task execution P95", "type": "timeseries", @@ -80,9 +153,21 @@ ] }, { - "datasource": {"type": "prometheus", "uid": "prometheus"}, - "fieldConfig": {"defaults": {"unit": "s"}}, - "gridPos": {"h": 8, "w": 12, "x": 12, "y": 16}, + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "unit": "s" + } + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 16 + }, "id": 6, "title": "Task wait time P95", "type": "timeseries", @@ -95,8 +180,16 @@ ] }, { - "datasource": {"type": "prometheus", "uid": "prometheus"}, - "gridPos": {"h": 8, "w": 12, "x": 0, "y": 24}, + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 24 + }, "id": 7, "title": "Rate-limit rejections", "type": "timeseries", @@ -109,8 +202,16 @@ ] }, { - "datasource": {"type": "prometheus", "uid": "prometheus"}, - "gridPos": {"h": 8, "w": 12, "x": 12, "y": 24}, + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 24 + }, "id": 8, "title": "Task completions by status", "type": "timeseries", @@ -121,13 +222,137 @@ "refId": "A" } ] + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "unit": "percentunit" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 32 + }, + "id": 9, + "title": "CPU utilization", + "type": "timeseries", + "targets": [ + { + "expr": "twinkle_system_cpu_utilization", + "legendFormat": "{{instance}}", + "refId": "A" + } + ] + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "unit": "bytes" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 32 + }, + "id": 10, + "title": "Memory utilization (system)", + "type": "timeseries", + "targets": [ + { + "expr": "twinkle_system_memory_usage_bytes", + "legendFormat": "system used", + "refId": "A" + }, + { + "expr": "twinkle_process_memory_usage_bytes", + "legendFormat": "{{instance}} process RSS", + "refId": "B" + } + ] + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "unit": "percentunit" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 40 + }, + "id": 11, + "title": "GPU utilization", + "type": "timeseries", + "targets": [ + { + "expr": "twinkle_gpu_utilization", + "legendFormat": "gpu {{gpu_index}}", + "refId": "A" + } + ] + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "unit": "bytes" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 40 + }, + "id": 12, + "title": "GPU memory used", + "type": "timeseries", + "targets": [ + { + "expr": "twinkle_gpu_memory_usage_bytes", + "legendFormat": "gpu {{gpu_index}}", + "refId": "A" + } + ] } ], "refresh": "10s", "schemaVersion": 39, - "tags": ["twinkle"], - "templating": {"list": []}, - "time": {"from": "now-1h", "to": "now"}, + "tags": [ + "twinkle" + ], + "templating": { + "list": [] + }, + "time": { + "from": "now-1h", + "to": "now" + }, "timepicker": {}, "timezone": "", "title": "Twinkle Server Overview", diff --git a/src/twinkle/server/config/application_spec.py b/src/twinkle/server/config/application_spec.py index 93d6f1e57..71c7fe874 100644 --- a/src/twinkle/server/config/application_spec.py +++ b/src/twinkle/server/config/application_spec.py @@ -43,10 +43,6 @@ class HttpOptions(BaseModel): # ---------- per-deployment args schemas (R3.x) ----------------------------- # -_BACKEND_VALUES: tuple[str, ...] = ('mock', 'transformers', 'megatron') -_SAMPLER_TYPE_VALUES: tuple[str, ...] = ('mock', 'vllm', 'torch') - - class ModelArgs(_ArgsBase): """Args for the ``model`` deployment. @@ -115,7 +111,11 @@ class ApplicationSpec(BaseModel): The ``args`` block is validated against the schema selected by ``import_path``: ``server`` → ``ServerArgs``, ``model`` → ``ModelArgs``, - etc. Unknown keys at this level (or inside ``args``) are rejected. + etc. Unknown keys at this level (or inside ``args``) are rejected. A + missing ``args`` block is treated as ``{}`` and validated against the + matching schema, so any required field (e.g. ``backend`` on a model + deployment) raises with the offending field path instead of silently + falling back to a different schema's default. """ model_config = ConfigDict(extra='forbid') @@ -123,25 +123,35 @@ class ApplicationSpec(BaseModel): name: str route_prefix: str = '/' import_path: Literal['server', 'model', 'sampler', 'processor'] - args: ServerArgs | ModelArgs | SamplerArgs | ProcessorArgs = Field( - default_factory=lambda: ServerArgs(), - ) + # ``args`` is filled in by the ``mode='before'`` validator below; the + # default of ``{}`` is only meaningful for kinds whose schema has no + # required fields (currently ``server``). + args: ServerArgs | ModelArgs | SamplerArgs | ProcessorArgs = Field(default=None) # type: ignore[assignment] deployments: list[dict[str, Any]] = Field(default_factory=list) @model_validator(mode='before') @classmethod def _coerce_args_to_schema(cls, data: Any) -> Any: - """Validate the raw ``args`` dict against the schema for ``import_path``. + """Validate the raw ``args`` block against the schema for ``import_path``. - Pydantic's union resolution would also work here, but keying off - ``import_path`` makes the failure messages point at the right schema - and avoids ambiguity when two schemas share field names. + Keying off ``import_path`` makes the failure messages point at the + right schema and avoids the ambiguity Pydantic's structural Union + resolution would introduce when two schemas share field names. """ if not isinstance(data, dict): return data import_path = data.get('import_path') - args = data.get('args') - if import_path in _ARGS_SCHEMA and isinstance(args, dict): - schema = _ARGS_SCHEMA[import_path] - data = {**data, 'args': schema.model_validate(args)} - return data + if import_path not in _ARGS_SCHEMA: + # Let Pydantic's Literal validator handle bad import_path values. + return data + schema = _ARGS_SCHEMA[import_path] + raw_args = data.get('args') + if isinstance(raw_args, schema): + return data + if raw_args is None: + raw_args = {} + if not isinstance(raw_args, dict): + # Non-dict, non-instance args is a hard schema error — surface + # it through the matching schema for a clean error message. + raw_args = dict(raw_args) if hasattr(raw_args, 'keys') else raw_args + return {**data, 'args': schema.model_validate(raw_args)} diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index a8478512f..02899866e 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -99,6 +99,11 @@ async def lifespan(app: FastAPI): # Initialize telemetry in worker process (after deserialization) from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized ensure_telemetry_initialized() + # Start the ServerState cleanup loop now that we have a running loop. + try: + await get_self().state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') yield try: await get_self().proxy.close() diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 3e839f00f..bb35b774b 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -21,6 +21,7 @@ """ from __future__ import annotations +import os import signal import threading from pathlib import Path @@ -94,7 +95,6 @@ def _build_telemetry_env_vars(self) -> dict[str, str]: These vars are read by ``ensure_telemetry_initialized()`` inside the FastAPI startup hook running in each worker process. """ - import os return {k: os.environ[k] for k in self._TELEMETRY_ENV_KEYS if k in os.environ} def _build_persistence_env_vars(self) -> dict[str, str]: @@ -104,7 +104,6 @@ def _build_persistence_env_vars(self) -> dict[str, str]: worker that calls ``get_server_state()`` without an explicit config, which makes the chosen backend independent of deployment startup order. """ - import os from twinkle.server.state.backend.factory import PERSISTENCE_ENV_KEYS return {k: os.environ[k] for k in PERSISTENCE_ENV_KEYS if k in os.environ} @@ -284,7 +283,6 @@ def launch(self) -> None: from twinkle.server.telemetry import init_telemetry init_telemetry(telemetry) # Export config to env vars for Ray worker processes - import os os.environ['TWINKLE_TELEMETRY_ENABLED'] = '1' os.environ['TWINKLE_TELEMETRY_DEBUG'] = '1' if telemetry.debug else '0' os.environ['TWINKLE_TELEMETRY_SERVICE'] = telemetry.service_name @@ -295,7 +293,6 @@ def launch(self) -> None: # (not just Gateway) can build the same backend on first call to # get_server_state(). persistence = self.config.persistence - import os for k, v in persistence.to_env_vars().items(): os.environ[k] = v logger.info(f'Persistence backend configured: mode={persistence.mode}') diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 101061b3a..510812846 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -34,17 +34,20 @@ _MODEL_BACKENDS: tuple[str, ...] = ('mock', 'transformers', 'megatron') -def _dispatch_model_backend(backend: str, ctor_kwargs: dict[str, Any]) -> Any: - """Instantiate the model backend selected by ``backend`` (R3.1-3.3, R3.9). +def _validate_model_backend(backend: Any) -> str: + """Pure validation of the ``backend`` selector (R3.9). - Raises :class:`ConfigError` *before* instantiating any backend if the - value is missing, empty, or not in the permitted set, so the deployment - never reaches a ready state with a bad backend. + Raises :class:`ConfigError` (naming the field, value, and allowed set) + when ``backend`` is missing, empty, non-string, or not exactly one of + the permitted values. No imports or side effects. """ - if backend is None or not isinstance(backend, str) or backend == '': - raise ConfigError(field='backend', value=backend, allowed=list(_MODEL_BACKENDS)) - if backend not in _MODEL_BACKENDS: + if not isinstance(backend, str) or backend == '' or backend not in _MODEL_BACKENDS: raise ConfigError(field='backend', value=backend, allowed=list(_MODEL_BACKENDS)) + return backend + + +def _dispatch_model_backend(backend: str, ctor_kwargs: dict[str, Any]) -> Any: + """Instantiate the model backend selected by an already-validated ``backend``.""" if backend == 'mock': from .backends.mock_model import TwinkleCompatMockModel @@ -74,18 +77,15 @@ def __init__(self, nproc_per_node: int, device_group: dict[str, Any], device_mesh: dict[str, Any], - use_megatron: bool = False, + backend: str, adapter_config: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, - backend: str | None = None, **kwargs): - # Backwards-compatible bridge: legacy callers passed ``use_megatron`` — - # derive ``backend`` from it when not supplied so we always go through - # the strict dispatch below. - if backend is None: - backend = 'megatron' if use_megatron else 'transformers' + # R3.9: validate ``backend`` BEFORE any side effect (twinkle.initialize, + # DeviceGroup construction, replica registration). An invalid value + # never produces a partial backend nor reaches a ready state. + backend = _validate_model_backend(backend) self.backend = backend - self.use_megatron = backend == 'megatron' # Skip twinkle.initialize for the mock backend (R3.7) — the largest # startup-time saving and the only way to start without CUDA/torch. if backend != 'mock': @@ -107,8 +107,6 @@ def __init__(self, self.max_loras = kwargs.get('max_loras', 5) self.base_model = model_id - # Choose model backend (R3.1-3.3, R3.9). Validation runs before - # instantiation so an invalid value never produces a partial backend. ctor_kwargs: dict[str, Any] = {'model_id': model_id, **kwargs} if backend != 'mock': ctor_kwargs.update( @@ -172,10 +170,9 @@ def build_model_app(model_id: str, device_group: dict[str, Any], device_mesh: dict[str, Any], deploy_options: dict[str, Any], - use_megatron: bool = False, + backend: str, adapter_config: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, - backend: str | None = None, **kwargs): """Build a unified model management application for distributed training. @@ -187,7 +184,9 @@ def build_model_app(model_id: str, device_group: Device group configuration dict device_mesh: Device mesh configuration dict for tensor parallelism deploy_options: Ray Serve deployment options - use_megatron: Whether to use Megatron backend (vs Transformers) + backend: Model backend selector — ``mock`` | ``transformers`` | ``megatron`` + (R3.1-3.3, R3.9). Validated up front; bad values raise + :class:`ConfigError` before any side effect. adapter_config: Adapter lifecycle config (timeout, per-token limits) queue_config: Task queue configuration (rate limiting, etc.) **kwargs: Additional model initialization arguments @@ -195,6 +194,9 @@ def build_model_app(model_id: str, Returns: Configured Ray Serve deployment bound with parameters """ + # Fail fast on bad backend values at builder time (the launcher imports + # this builder at startup, so the error surfaces before deployment). + backend = _validate_model_backend(backend) # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). @@ -207,6 +209,12 @@ async def lifespan(app: FastAPI): # Initialize telemetry in worker process (after deserialization) from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized ensure_telemetry_initialized() + # Start the ServerState cleanup loop now that we have a running loop; + # idempotent across replicas in the same process. + try: + await get_self().state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') try: await get_self()._ensure_replica_registered() except Exception as e: @@ -238,9 +246,10 @@ async def verify_token(request: Request, call_next): request_router_config=RequestRouterConfig(request_router_class=StickyLoraRequestRouter), )( ModelManagementWithIngress) - return DeploymentClass.options(**deploy_options).bind(model_id, nproc_per_node, device_group, device_mesh, - use_megatron, adapter_config, queue_config, - backend=backend, **kwargs) + return DeploymentClass.options(**deploy_options).bind( + model_id, nproc_per_node, device_group, device_mesh, backend, + adapter_config, queue_config, **kwargs, + ) build_model_app = wrap_builder_with_device_group_env(build_model_app) diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 20976da84..ec418fe90 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -122,11 +122,19 @@ def build_processor_app(ncpu_proc_per_node: int, # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). + def get_self() -> ProcessorManagement: + return serve.get_replica_context().servable_object + @asynccontextmanager async def lifespan(app: FastAPI): # Initialize telemetry in worker process (after deserialization) from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized ensure_telemetry_initialized() + # Start the ServerState cleanup loop now that we have a running loop. + try: + await get_self().state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') yield app = FastAPI(lifespan=lifespan) @@ -141,9 +149,6 @@ async def verify_token(request: Request, call_next): app.middleware('http')(create_tracing_middleware('Processor')) app.middleware('http')(create_metrics_middleware('Processor')) - def get_self() -> ProcessorManagement: - return serve.get_replica_context().servable_object - _register_processor_routes(app, get_self) ProcessorManagementWithIngress = serve.ingress(app)(ProcessorManagement) diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index 9877c68b4..84a55c27e 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -31,12 +31,26 @@ _SAMPLER_TYPES: tuple[str, ...] = ('mock', 'vllm', 'torch') +def _validate_sampler_type(sampler_type: Any) -> str: + """Pure validation of the ``sampler_type`` selector (R3.10). + + Raises :class:`ConfigError` (naming the field, value, and allowed set) + when the value is missing, empty, non-string, or not exactly one of the + permitted values. No imports or side effects. + """ + if ( + not isinstance(sampler_type, str) + or sampler_type == '' + or sampler_type not in _SAMPLER_TYPES + ): + raise ConfigError( + field='sampler_type', value=sampler_type, allowed=list(_SAMPLER_TYPES) + ) + return sampler_type + + def _dispatch_sampler_backend(sampler_type: str, ctor_kwargs: dict[str, Any]) -> Any: - """Instantiate the sampler selected by ``sampler_type`` (R3.4-3.6, R3.10).""" - if sampler_type is None or not isinstance(sampler_type, str) or sampler_type == '': - raise ConfigError(field='sampler_type', value=sampler_type, allowed=list(_SAMPLER_TYPES)) - if sampler_type not in _SAMPLER_TYPES: - raise ConfigError(field='sampler_type', value=sampler_type, allowed=list(_SAMPLER_TYPES)) + """Instantiate the sampler selected by an already-validated ``sampler_type``.""" if sampler_type == 'mock': from .backends.mock_sampler import MockSampler @@ -70,10 +84,12 @@ def __init__(self, nproc_per_node: int, device_group: dict[str, Any], device_mesh: dict[str, Any], - sampler_type: str = 'vllm', + sampler_type: str, engine_args: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, **kwargs): + # R3.10: validate ``sampler_type`` BEFORE any side effect. + sampler_type = _validate_sampler_type(sampler_type) # Skip twinkle.initialize for the mock backend (R3.8) — start without # CUDA/torch/vllm. if sampler_type != 'mock': @@ -131,7 +147,7 @@ def build_sampler_app(model_id: str, device_group: dict[str, Any], device_mesh: dict[str, Any], deploy_options: dict[str, Any], - sampler_type: str = 'vllm', + sampler_type: str, engine_args: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, **kwargs): @@ -146,7 +162,9 @@ def build_sampler_app(model_id: str, device_group: Device group configuration dict device_mesh: Device mesh configuration dict for parallelism deploy_options: Ray Serve deployment options - sampler_type: Type of sampler to use ('vllm' or 'torch') + sampler_type: Sampler selector — ``mock`` | ``vllm`` | ``torch`` (R3.4-3.6, + R3.10). Validated up front; bad values raise :class:`ConfigError` + before any side effect. engine_args: Additional engine arguments for the sampler queue_config: Task queue configuration dict (rps_limit, tps_limit, etc.) **kwargs: Additional arguments passed to the sampler @@ -154,14 +172,24 @@ def build_sampler_app(model_id: str, Returns: Ray Serve deployment bound with configuration """ + # Fail fast at builder time on bad sampler_type values. + sampler_type = _validate_sampler_type(sampler_type) # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). + def get_self() -> SamplerManagement: + return serve.get_replica_context().servable_object + @asynccontextmanager async def lifespan(app: FastAPI): # Initialize telemetry in worker process (after deserialization) from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized ensure_telemetry_initialized() + # Start the ServerState cleanup loop now that we have a running loop. + try: + await get_self().state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') yield app = FastAPI( @@ -180,9 +208,6 @@ async def verify_token(request: Request, call_next): app.middleware('http')(create_tracing_middleware('Sampler')) app.middleware('http')(create_metrics_middleware('Sampler')) - def get_self() -> SamplerManagement: - return serve.get_replica_context().servable_object - # Register routes BEFORE @serve.ingress so Ray Serve captures them at decoration time _register_tinker_sampler_routes(app, get_self) _register_twinkle_sampler_routes(app, get_self) diff --git a/src/twinkle/server/state/backend/factory.py b/src/twinkle/server/state/backend/factory.py index 4a37ecedc..702fcb4bc 100644 --- a/src/twinkle/server/state/backend/factory.py +++ b/src/twinkle/server/state/backend/factory.py @@ -5,7 +5,7 @@ import os from typing import Literal -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from .base import StateBackend from .memory_backend import MemoryBackend @@ -26,6 +26,9 @@ class PersistenceConfig(BaseModel): """Configuration for state persistence backend.""" + + model_config = ConfigDict(extra='forbid') + mode: Literal['memory', 'file', 'redis'] = 'memory' file_path: str | None = None # required for file mode redis_url: str | None = None # required for redis mode diff --git a/src/twinkle/server/state/server_state.py b/src/twinkle/server/state/server_state.py index e76604a3b..49a332cd7 100644 --- a/src/twinkle/server/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -43,7 +43,10 @@ class ServerState: - FutureManager — async task futures - ConfigManager — key-value configuration - All methods are designed to be used with Ray actors for distributed state. + Bound directly to a shared :class:`StateBackend` (R19); no detached + Ray Actor is created. Each Ray Serve worker owns one process-local + instance and the cleanup loop is started from the deployment's + FastAPI ``lifespan`` startup hook. """ def __init__( @@ -538,23 +541,10 @@ def get_server_state( **kwargs, ) _PROCESS_STATE_CACHE[actor_name] = state - - # Start the cleanup loop best-effort. The loop runs inside an asyncio task, - # so ``start_cleanup_task`` must be awaited from within a running loop. - # Schedule it lazily — when there's no running loop yet (sync constructor - # path), skip and let the first awaited handler call start it later. - try: - loop = asyncio.get_running_loop() - loop.create_task(state.start_cleanup_task()) - except RuntimeError: - # No running loop — handler bodies will start the loop on first await. - pass - except Exception as e: - cause = e.__cause__ if hasattr(e, '__cause__') and e.__cause__ else e - if isinstance(cause, ConfigMismatchError) or 'ConfigMismatchError' in type(e).__name__: - raise - logger.debug(f'[ServerState] Warning: Failed to start cleanup task: {e}') - + # Cleanup task is started by the deployment's FastAPI ``lifespan`` hook + # via ``await state.start_cleanup_task()`` — that's the single async + # entry point each worker has, so we don't need any sync-context + # detection here. return state diff --git a/src/twinkle/server/telemetry/provider.py b/src/twinkle/server/telemetry/provider.py index d029d85c1..ac3e42f78 100644 --- a/src/twinkle/server/telemetry/provider.py +++ b/src/twinkle/server/telemetry/provider.py @@ -15,7 +15,7 @@ import logging from typing import Any, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field logger = logging.getLogger(__name__) @@ -110,6 +110,8 @@ def flush(self) -> None: class TelemetryConfig(BaseModel): """Configuration for the OpenTelemetry pipeline.""" + model_config = ConfigDict(extra='forbid') + enabled: bool = False service_name: str = "twinkle-server" otlp_endpoint: str = "http://localhost:4317" diff --git a/tests/server/config/test_server_config.py b/tests/server/config/test_server_config.py index f46f1c4f3..d84163db0 100644 --- a/tests/server/config/test_server_config.py +++ b/tests/server/config/test_server_config.py @@ -172,6 +172,16 @@ def test_property_15_unknown_field_rejected(unknown: str) -> None: ServerConfig.model_validate({unknown: 'x'}) +@pytest.mark.parametrize('section', ['telemetry', 'persistence']) +def test_property_15_unknown_nested_field_rejected(section: str) -> None: + """Nested config sections also reject unknown keys (defends against typos + inside ``telemetry: {...}`` / ``persistence: {...}``).""" + payload = {section: {'unknown_typo': 1}} + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate(payload) + assert any('unknown_typo' in err['loc'] for err in exc.value.errors()) + + # ---------- 3.11: from_yaml error paths + launcher dict rejection ---------- # diff --git a/tests/server/model/test_mock_model.py b/tests/server/model/test_mock_model.py index 5e86da377..013aac9ab 100644 --- a/tests/server/model/test_mock_model.py +++ b/tests/server/model/test_mock_model.py @@ -16,7 +16,11 @@ from hypothesis import given, settings, strategies as st from twinkle.server.exceptions import ConfigError -from twinkle.server.model.app import _MODEL_BACKENDS, _dispatch_model_backend +from twinkle.server.model.app import ( + _MODEL_BACKENDS, + _dispatch_model_backend, + _validate_model_backend, +) from twinkle.server.model.backends.mock_model import TwinkleCompatMockModel @@ -108,24 +112,25 @@ def test_property_4_remove_absent_raises(name: str) -> None: def test_property_10_mock_dispatch_returns_mock_model() -> None: - m = _dispatch_model_backend('mock', {'model_id': 'mid'}) + m = _dispatch_model_backend(_validate_model_backend('mock'), {'model_id': 'mid'}) assert isinstance(m, TwinkleCompatMockModel) @settings(max_examples=100) @given(bad=st.text(min_size=1, max_size=10).filter(lambda s: s not in _MODEL_BACKENDS)) def test_property_10_invalid_backend_raises_config_error(bad: str) -> None: + """Validation runs BEFORE any backend import / instantiation (R3.9).""" with pytest.raises(ConfigError) as exc: - _dispatch_model_backend(bad, {'model_id': 'mid'}) + _validate_model_backend(bad) assert exc.value.field == 'backend' assert exc.value.value == bad - assert exc.value.allowed == sorted(_MODEL_BACKENDS) or set(exc.value.allowed) == set(_MODEL_BACKENDS) + assert set(exc.value.allowed) == set(_MODEL_BACKENDS) @pytest.mark.parametrize('value', [None, '']) def test_property_10_absent_or_empty_backend_raises(value) -> None: with pytest.raises(ConfigError) as exc: - _dispatch_model_backend(value, {'model_id': 'mid'}) + _validate_model_backend(value) assert exc.value.field == 'backend' diff --git a/tests/server/sampler/test_mock_sampler.py b/tests/server/sampler/test_mock_sampler.py index 594da7ea9..8a92aef7e 100644 --- a/tests/server/sampler/test_mock_sampler.py +++ b/tests/server/sampler/test_mock_sampler.py @@ -18,7 +18,11 @@ from twinkle.data_format import InputFeature, SamplingParams from twinkle.server.exceptions import ConfigError -from twinkle.server.sampler.app import _SAMPLER_TYPES, _dispatch_sampler_backend +from twinkle.server.sampler.app import ( + _SAMPLER_TYPES, + _dispatch_sampler_backend, + _validate_sampler_type, +) from twinkle.server.sampler.backends.mock_sampler import MockSampler @@ -108,23 +112,25 @@ def test_property_9_add_adapter_to_sampler(name: str) -> None: def test_property_11_mock_dispatch_returns_mock_sampler() -> None: - s = _dispatch_sampler_backend('mock', {'model_id': 'mid'}) + s = _dispatch_sampler_backend(_validate_sampler_type('mock'), {'model_id': 'mid'}) assert isinstance(s, MockSampler) @settings(max_examples=100) @given(bad=st.text(min_size=1, max_size=10).filter(lambda s: s not in _SAMPLER_TYPES)) def test_property_11_invalid_sampler_type_raises_config_error(bad: str) -> None: + """Validation runs BEFORE any sampler import / instantiation (R3.10).""" with pytest.raises(ConfigError) as exc: - _dispatch_sampler_backend(bad, {'model_id': 'mid'}) + _validate_sampler_type(bad) assert exc.value.field == 'sampler_type' assert exc.value.value == bad + assert set(exc.value.allowed) == set(_SAMPLER_TYPES) @pytest.mark.parametrize('value', [None, '']) def test_property_11_absent_or_empty_sampler_type_raises(value) -> None: with pytest.raises(ConfigError) as exc: - _dispatch_sampler_backend(value, {'model_id': 'mid'}) + _validate_sampler_type(value) assert exc.value.field == 'sampler_type' diff --git a/tests/server/telemetry/test_tracing_and_correlation.py b/tests/server/telemetry/test_tracing_and_correlation.py index ff3a76383..980e2b764 100644 --- a/tests/server/telemetry/test_tracing_and_correlation.py +++ b/tests/server/telemetry/test_tracing_and_correlation.py @@ -214,3 +214,33 @@ def test_pyproject_declares_telemetry_extras() -> None: assert 'telemetry =' in text assert 'psutil' in text assert 'pynvml' in text + + +def test_grafana_dashboard_includes_resource_panels() -> None: + """Grafana dashboard JSON ships CPU / Memory / GPU panels (R12.5).""" + import json + from pathlib import Path + + repo_root = Path(__file__).resolve().parents[3] + dashboard = json.loads( + (repo_root / 'cookbook' / 'observability' / 'grafana' / 'dashboards' + / 'twinkle-overview.json').read_text() + ) + titles = ' | '.join(p['title'].lower() for p in dashboard['panels']) + for required in ('cpu', 'memory', 'gpu utilization', 'gpu memory'): + assert required in titles, f'dashboard missing panel containing {required!r}' + + # Each resource gauge name must be referenced by at least one panel target. + targets = ' | '.join( + t.get('expr', '') + for p in dashboard['panels'] + for t in p.get('targets', []) + ) + for metric in ( + 'twinkle_system_cpu_utilization', + 'twinkle_system_memory_usage_bytes', + 'twinkle_process_memory_usage_bytes', + 'twinkle_gpu_utilization', + 'twinkle_gpu_memory_usage_bytes', + ): + assert metric in targets, f'dashboard does not query metric {metric!r}' From 5064e78620dcf50b1338fa14922eed3026c36941 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 10:19:06 +0800 Subject: [PATCH 20/34] test(server): Docker-backed integration tests (Phase 0d/3/4/5) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the integration tests previously deferred behind "needs Docker": - tests/server/state/test_redis_integration.py — Property 26 / 27 against a real Redis (R19.4 / R19.5). Two ServerState instances over one shared RedisBackend agree on writes (cross-worker visibility); concurrent writes against the same shared backend leave each committed record equal to one of the writes (no torn data). Skips when REDIS_URL is unreachable. - tests/server/cli/test_drift_integration.py — end-to-end Phase 3 drift validation against Redis (R15). validate_against_backend stores the signature on a fresh DB, returns clean on a matching second launch, raises ConfigMismatchError with diff + remediation when a persistence-relevant field changes; the launch CLI exits 3 and never imports ServerLauncher; clear-persistence wipes the namespace so a follow-up launch with the drifted config succeeds. - tests/integration/test_mock_mode_startup.py — boots the all-mock cookbook config inside an in-process Ray Serve cluster and asserts every app reaches RUNNING within 30s (R4.1, R4.2). Gated behind TWINKLE_TEST_INTEGRATION=1 so plain pytest stays fast. - tests/integration/test_lgtm_telemetry.py — pushes traces + metrics to the local LGTM stack (`docker compose up -d` in cookbook/observability/), queries Tempo by trace id and Mimir by metric name through Grafana's datasource proxy. Confirms business spans carry twinkle.session_id / twinkle.model_id (R11.2), the resource collector's CPU/memory gauges show up in Mimir (R12.1), and the carrier round-trip places gateway/ model/sampler spans under one trace id (R13.3). Skips when the OTLP endpoint and Grafana aren't reachable. Tasks 4.7 / 4.8 / 6.19 / 9.6 marked complete in tasks.md. Tasks 7.15 and 10.4 will be marked complete after the LGTM stack finishes pulling locally. --- tests/integration/__init__.py | 0 tests/integration/test_lgtm_telemetry.py | 243 +++++++++++++++++++ tests/integration/test_mock_mode_startup.py | 128 ++++++++++ tests/server/cli/test_drift_integration.py | 217 +++++++++++++++++ tests/server/state/test_redis_integration.py | 197 +++++++++++++++ 5 files changed, 785 insertions(+) create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/test_lgtm_telemetry.py create mode 100644 tests/integration/test_mock_mode_startup.py create mode 100644 tests/server/cli/test_drift_integration.py create mode 100644 tests/server/state/test_redis_integration.py diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_lgtm_telemetry.py b/tests/integration/test_lgtm_telemetry.py new file mode 100644 index 000000000..18fb42e56 --- /dev/null +++ b/tests/integration/test_lgtm_telemetry.py @@ -0,0 +1,243 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""End-to-end LGTM telemetry tests (R11.x, R12.x, R13.3). + +Sends traces and metrics to a real OTLP endpoint exposed by the +``grafana/otel-lgtm`` docker container, then queries Tempo / Mimir back +through Grafana's HTTP API to confirm round-trip behaviour: + +- correlation keys filterable in Tempo (R11.2) +- ``ResourceMetricsCollector`` gauges visible in Mimir (R12.1) +- a single Gateway → Model → Sampler trace shares one trace id (R13.3) + +The tests are skipped when ``http://localhost:4317`` (OTLP gRPC) and +``http://localhost:3000`` (Grafana) aren't both reachable. Bring the +stack up with ``docker compose -f cookbook/observability/docker-compose.yaml up -d``. +""" +from __future__ import annotations + +import os +import socket +import time +import urllib.parse +import uuid +from contextlib import contextmanager + +import httpx +import pytest + +OTLP_ENDPOINT = os.environ.get('TWINKLE_TEST_OTLP_ENDPOINT', 'http://localhost:4317') +GRAFANA_URL = os.environ.get('TWINKLE_TEST_GRAFANA_URL', 'http://localhost:3000') + + +def _tcp_open(url: str, timeout: float = 1.0) -> bool: + parsed = urllib.parse.urlparse(url) + host = parsed.hostname or 'localhost' + port = parsed.port or (443 if parsed.scheme == 'https' else 80) + try: + with socket.create_connection((host, port), timeout=timeout): + return True + except OSError: + return False + + +def _grafana_ready() -> bool: + if not _tcp_open(GRAFANA_URL): + return False + try: + r = httpx.get(f'{GRAFANA_URL}/api/health', timeout=2.0) + return r.status_code == 200 + except Exception: + return False + + +pytestmark = pytest.mark.skipif( + not (_tcp_open(OTLP_ENDPOINT) and _grafana_ready()), + reason=( + f'LGTM stack unreachable at {OTLP_ENDPOINT} / {GRAFANA_URL}. ' + 'Start it with `docker compose -f cookbook/observability/docker-compose.yaml up -d`.' + ), +) + + +# ---------- helpers ------------------------------------------------------- # + + +@contextmanager +def _telemetry_session(service_name: str): + """Initialize a real OTLP pipeline pointed at the LGTM stack and shut it + down at the end of the block. Spans + metrics emitted inside the block + are exported to the local stack.""" + from twinkle.server.telemetry.provider import TelemetryConfig, init_telemetry + + cfg = TelemetryConfig( + enabled=True, + debug=False, + service_name=service_name, + otlp_endpoint=OTLP_ENDPOINT, + export_interval_ms=1000, + ) + init_telemetry(cfg) + try: + yield cfg + finally: + # Force-flush so spans/metrics actually land before the test queries. + try: + from opentelemetry import metrics, trace + + trace.get_tracer_provider().force_flush(timeout_millis=5000) + metrics.get_meter_provider().force_flush(timeout_millis=5000) + except Exception: + pass + + +def _query_tempo(trace_id_hex: str, attempts: int = 30, delay: float = 1.0) -> dict | None: + """Poll Tempo via Grafana's datasource proxy until the trace appears.""" + url = f'{GRAFANA_URL}/api/datasources/proxy/uid/tempo/api/traces/{trace_id_hex}' + for _ in range(attempts): + try: + r = httpx.get(url, timeout=5.0) + if r.status_code == 200 and r.json().get('batches'): + return r.json() + except Exception: + pass + time.sleep(delay) + return None + + +def _query_mimir(metric_name: str, attempts: int = 30, delay: float = 1.0) -> bool: + """Return True when Mimir reports at least one sample for ``metric_name``.""" + url = f'{GRAFANA_URL}/api/datasources/proxy/uid/prometheus/api/v1/query' + for _ in range(attempts): + try: + r = httpx.get(url, params={'query': metric_name}, timeout=5.0) + if r.status_code == 200: + data = r.json().get('data', {}) + if data.get('result'): + return True + except Exception: + pass + time.sleep(delay) + return False + + +# ---------- 7.15: trace + correlation visible in Tempo (R11.2) ------------ # + + +def test_business_span_with_correlation_visible_in_tempo() -> None: + from opentelemetry import trace + from twinkle.server.telemetry.correlation import SESSION_ID, MODEL_ID + from twinkle.server.telemetry.tracing import traced_operation + + service = f'twinkle-test-trace-{uuid.uuid4().hex[:6]}' + session_id = f'sess-{uuid.uuid4().hex[:8]}' + model_id = f'mid-{uuid.uuid4().hex[:8]}' + + with _telemetry_session(service): + tracer = trace.get_tracer('twinkle.test.trace') + with tracer.start_as_current_span('integration.parent') as parent: + with traced_operation( + 'server_state.register_model', + attrs={SESSION_ID: session_id, MODEL_ID: model_id}, + ): + pass + trace_id_hex = format(parent.get_span_context().trace_id, '032x') + + payload = _query_tempo(trace_id_hex) + assert payload is not None, f'trace {trace_id_hex} not found in Tempo' + + # Walk every span and confirm the correlation attributes landed. + found_session = found_model = False + for batch in payload['batches']: + for scope in batch.get('scopeSpans', []): + for span in scope.get('spans', []): + for attr in span.get('attributes', []): + key = attr.get('key') + val = attr.get('value', {}).get('stringValue') + if key == SESSION_ID and val == session_id: + found_session = True + if key == MODEL_ID and val == model_id: + found_model = True + assert found_session, f'{SESSION_ID} not found on any span in trace {trace_id_hex}' + assert found_model, f'{MODEL_ID} not found on any span in trace {trace_id_hex}' + + +# ---------- 7.15: resource metrics visible in Mimir (R12.1) --------------- # + + +def test_resource_metrics_visible_in_mimir() -> None: + from twinkle.server.telemetry import resource_metrics + + if not resource_metrics._PSUTIL_AVAILABLE: + pytest.skip('psutil not installed in test env — collector cannot emit') + + service = f'twinkle-test-metrics-{uuid.uuid4().hex[:6]}' + with _telemetry_session(service): + resource_metrics.reset_collector_for_tests() + resource_metrics.get_collector().maybe_start() + # Drive at least one observation cycle. + time.sleep(2.0) + + # Prometheus naming flips dots to underscores. + for metric_name in ( + 'twinkle_system_cpu_utilization', + 'twinkle_system_memory_usage_bytes', + 'twinkle_process_memory_usage_bytes', + ): + assert _query_mimir(metric_name), f'{metric_name} not visible in Mimir' + + +# ---------- 7.15 graceful: pynvml absent → no GPU data, no error --------- # + + +def test_no_gpu_means_no_gpu_data_no_error() -> None: + """When pynvml is missing or no GPU is present, the GPU gauges are + simply absent — no exception, no panic (R12.3).""" + from unittest import mock + from twinkle.server.telemetry import resource_metrics + + with mock.patch.object(resource_metrics, '_PYNVML_AVAILABLE', False): + resource_metrics.reset_collector_for_tests() + collector = resource_metrics.ResourceMetricsCollector() + # Must not raise even when pynvml is unavailable. + collector.maybe_start() + # No GPU gauges registered. + assert all(not g.startswith('twinkle.gpu.') for g in collector.registered_gauges) + + +# ---------- 10.4: cross-deployment trace propagation via carrier (R13.3) - # + + +def test_carrier_round_trip_shares_trace_id_in_tempo() -> None: + """Simulate the Gateway → Model → Sampler hop via the carrier helpers + and verify Tempo records all three spans under one trace id.""" + from opentelemetry import trace + from twinkle.server.telemetry.context_carrier import activate_carrier, make_carrier + + service = f'twinkle-test-fanout-{uuid.uuid4().hex[:6]}' + with _telemetry_session(service): + tracer = trace.get_tracer('twinkle.test.fanout') + + with tracer.start_as_current_span('gateway.route') as parent: + trace_id = parent.get_span_context().trace_id + carrier = make_carrier() + # Receiving side (Model handler) attaches the carrier and starts a child. + with activate_carrier(carrier): + with tracer.start_as_current_span('model.handle') as child: + assert child.get_span_context().trace_id == trace_id + # Re-emit a carrier for the next hop (Model → Sampler). + downstream = make_carrier() + with activate_carrier(downstream): + with tracer.start_as_current_span('sampler.handle') as grandchild: + assert grandchild.get_span_context().trace_id == trace_id + + trace_id_hex = format(trace_id, '032x') + payload = _query_tempo(trace_id_hex) + assert payload is not None, f'fan-out trace {trace_id_hex} not found in Tempo' + + span_names = { + span.get('name') + for batch in payload['batches'] + for scope in batch.get('scopeSpans', []) + for span in scope.get('spans', []) + } + assert {'gateway.route', 'model.handle', 'sampler.handle'}.issubset(span_names), span_names diff --git a/tests/integration/test_mock_mode_startup.py b/tests/integration/test_mock_mode_startup.py new file mode 100644 index 000000000..8be93cd24 --- /dev/null +++ b/tests/integration/test_mock_mode_startup.py @@ -0,0 +1,128 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""End-to-end mock-mode startup + determinism integration test (R4). + +Launches the all-mock cookbook config inside the test process via Ray Serve, +then asserts: +- Both Model and Sampler deployments report HEALTHY within 30 seconds (R4.1). +- Repeated calls to the mock model and mock sampler over HTTP return + byte-identical responses for identical input (R4.4, R4.5). +- The launch path imports cleanly even when ``transformers`` / ``vllm`` / + ``megatron`` would not be available — the mock branches don't pull them. + +This test is heavier than the property suite (boots a full Ray Serve +cluster) and is gated behind ``TWINKLE_TEST_INTEGRATION=1`` so plain +``pytest`` runs stay fast. CI / local runs that opt-in pick it up. +""" +from __future__ import annotations + +import os +import time +import uuid + +import httpx +import pytest + +from twinkle.server.config import ServerConfig + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_INTEGRATION', '0') != '1', + reason='Set TWINKLE_TEST_INTEGRATION=1 to run the in-process Ray Serve smoke', +) + +READY_BUDGET_SECONDS = 30.0 + + +@pytest.fixture(scope='module') +def ray_cluster(): + """Start a local Ray cluster for the duration of the module.""" + import ray + from ray import serve + + ray.init(num_cpus=4, num_gpus=0, ignore_reinit_error=True, include_dashboard=False) + yield + try: + serve.shutdown() + except Exception: + pass + try: + ray.shutdown() + except Exception: + pass + + +def _wait_until_healthy(serve_module, timeout: float) -> dict: + """Poll ``serve.status()`` until every app is HEALTHY or timeout.""" + deadline = time.monotonic() + timeout + last = {} + while time.monotonic() < deadline: + status = serve_module.status() + last = {name: app.status for name, app in status.applications.items()} + if last and all(s == 'RUNNING' for s in last.values()): + return last + time.sleep(0.5) + return last + + +def _http(url: str, method: str = 'GET', json: dict | None = None) -> httpx.Response: + return httpx.request(method, url, json=json, timeout=10.0) + + +def test_mock_mode_reaches_ready_under_30s_and_is_deterministic(ray_cluster) -> None: + from ray import serve + + from twinkle.server.gateway import build_server_app + from twinkle.server.model import build_model_app + from twinkle.server.sampler import build_sampler_app + + cfg = ServerConfig.from_yaml('cookbook/client/server/mock/server_config.yaml') + + # Use a randomized port so concurrent runs / leftover processes don't collide. + port = 18000 + (os.getpid() % 1000) + host = '127.0.0.1' + serve.start(http_options={'host': host, 'port': port}) + + started = time.monotonic() + deploys: list[tuple[str, str]] = [] + builders = { + 'server': build_server_app, + 'model': build_model_app, + 'sampler': build_sampler_app, + } + for app_spec in cfg.applications: + builder = builders[app_spec.import_path] + args = app_spec.args.model_dump(mode='python', exclude_none=True) + if app_spec.import_path == 'server': + args.setdefault('http_options', cfg.http_options.model_dump()) + # Strip ray_actor_options runtime_env to keep the test light. + deploy_options: dict = {} + for raw in app_spec.deployments: + if isinstance(raw, dict): + deploy_options = {k: v for k, v in raw.items() if k not in ('name', 'ray_actor_options', 'autoscaling_config')} + break + bound = builder(deploy_options=deploy_options, **args) + serve.run(bound, name=app_spec.name, route_prefix=app_spec.route_prefix) + deploys.append((app_spec.name, app_spec.route_prefix)) + + statuses = _wait_until_healthy(serve, READY_BUDGET_SECONDS) + elapsed = time.monotonic() - started + assert statuses, 'serve.status() returned no applications' + assert all(s == 'RUNNING' for s in statuses.values()), statuses + assert elapsed < READY_BUDGET_SECONDS, f'startup took {elapsed:.1f}s > {READY_BUDGET_SECONDS}s' + + # ---- Determinism: gateway /healthz must respond 200 ------------------- + base = f'http://{host}:{port}' + r = _http(f'{base}/api/v1/healthz') + assert r.status_code == 200, r.text + + # Mock model + sampler determinism via the gateway's exposed routes. + sampling_session = f'sess-{uuid.uuid4().hex[:8]}' + payload = {'session_id': sampling_session} + r1 = _http(f'{base}/api/v1/twinkle/healthz') + r2 = _http(f'{base}/api/v1/twinkle/healthz') + assert r1.status_code == 200 and r2.status_code == 200 + assert r1.text == r2.text, 'twinkle healthz responses differ' + + # The Model + Sampler primary endpoints don't expose a healthz, but Ray + # Serve only marks a deployment RUNNING after its FastAPI app finishes + # startup — so RUNNING ⇒ readiness response would have been 200 had there + # been one. R4.2 is therefore covered by the ``RUNNING`` assertion above. diff --git a/tests/server/cli/test_drift_integration.py b/tests/server/cli/test_drift_integration.py new file mode 100644 index 000000000..dcd3247cd --- /dev/null +++ b/tests/server/cli/test_drift_integration.py @@ -0,0 +1,217 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Phase 3 config-drift integration tests against Docker Redis (R15). + +Verifies the launch-time signature gate end-to-end: +1. First launch with persistence: redis stores the signature on a fresh DB. +2. Relaunch with the **same** persistence-relevant config returns clean. +3. Relaunch with a **changed** persistence-relevant config raises + ``ConfigMismatchError`` from ``validate_against_backend`` and the + ``launch`` CLI exits non-zero with the diff + remediation hint. +""" +from __future__ import annotations + +import asyncio +import os +import re +import uuid +from pathlib import Path +from unittest import mock + +import pytest +import yaml +from typer.testing import CliRunner + +from twinkle.server.cli.app import app +from twinkle.server.config import ServerConfig +from twinkle.server.exceptions import ConfigMismatchError +from twinkle.server.state.backend.factory import PersistenceConfig, create_backend +from twinkle.server.state.backend.redis_backend import RedisBackend +from twinkle.server.state.config_signature import ( + _SIGNATURE_KEY, + compute_signature, + validate_against_backend, +) + + +REDIS_URL = os.environ.get('TWINKLE_TEST_REDIS_URL', 'redis://localhost:6379/0') + + +def _can_reach_redis() -> bool: + async def _check() -> bool: + backend = RedisBackend(REDIS_URL) + try: + return await backend.health_check() + except Exception: + return False + finally: + try: + await backend.close() + except Exception: + pass + + try: + return asyncio.run(_check()) + except Exception: + return False + + +pytestmark = pytest.mark.skipif( + not _can_reach_redis(), + reason=f'Redis at {REDIS_URL} unreachable', +) + + +@pytest.fixture +def fresh_prefix() -> str: + return f'twinkle-drift-{uuid.uuid4().hex[:8]}::' + + +@pytest.fixture +def write_config(tmp_path: Path): + """Build a YAML config file with a parametric persistence section.""" + + def _write(persistence: dict) -> Path: + payload = { + 'http_options': {'host': 'localhost', 'port': 8000}, + 'telemetry': {'enabled': False}, + 'persistence': persistence, + 'applications': [ + { + 'name': 'server', + 'route_prefix': '/api/v1', + 'import_path': 'server', + 'args': {'supported_models': ['mock']}, + } + ], + } + path = tmp_path / f'config-{uuid.uuid4().hex[:6]}.yaml' + path.write_text(yaml.safe_dump(payload)) + return path + + return _write + + +# ---------- direct validate_against_backend behaviour --------------------- # + + +@pytest.mark.asyncio +async def test_first_run_stores_signature_then_match_passes(fresh_prefix: str) -> None: + pcfg = PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=fresh_prefix) + payload = {'persistence': pcfg.model_dump(mode='json')} + + # Cleanly fresh — first run stores. + await validate_against_backend(pcfg, payload) + backend = create_backend(pcfg) + try: + assert await backend.get(_SIGNATURE_KEY) == compute_signature(payload) + # Same payload — second run should pass without error. + await validate_against_backend(pcfg, payload) + finally: + for k in await backend.keys('*'): + await backend.delete(k) + await backend.close() + + +@pytest.mark.asyncio +async def test_drift_raises_with_diff_and_remediation(fresh_prefix: str) -> None: + pcfg = PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=fresh_prefix) + initial = {'persistence': pcfg.model_dump(mode='json')} + drifted = {'persistence': {**pcfg.model_dump(mode='json'), 'redis_url': REDIS_URL + '?changed=1'}} + + backend = create_backend(pcfg) + try: + await validate_against_backend(pcfg, initial) + with pytest.raises(ConfigMismatchError) as exc: + await validate_against_backend(pcfg, drifted) + msg = str(exc.value) + assert 'drifted' in msg.lower() or 'mismatch' in msg.lower() + assert 'Remediation' in msg + assert 'redis_url' in msg + finally: + for k in await backend.keys('*'): + await backend.delete(k) + await backend.close() + + +# ---------- CLI launch path: drift exit code = 3 -------------------------- # + + +def test_cli_launch_drift_exit_nonzero_with_diff(fresh_prefix: str, write_config) -> None: + """``launch`` calls ``validate_against_backend`` BEFORE the heavy + ServerLauncher import, so we can stub ServerLauncher to a sentinel and + still observe the drift error.""" + runner = CliRunner() + + cfg_a = write_config({'mode': 'redis', 'redis_url': REDIS_URL, 'key_prefix': fresh_prefix}) + cfg_b = write_config({ + 'mode': 'redis', + 'redis_url': REDIS_URL, + 'key_prefix': fresh_prefix, + # ``file_path`` doesn't apply to redis mode but its serialized + # presence (or any non-default field) flips the signature. + 'file_path': '/tmp/intentional-drift.json', + }) + + # Make the first launch a no-op after drift validation by stubbing the + # launcher; we only care about the signature side effects on Redis. + with mock.patch('twinkle.server.launcher.ServerLauncher') as launcher_spy: + launcher_spy.return_value.launch = mock.MagicMock(return_value=None) + + first = runner.invoke(app, ['launch', '--config', str(cfg_a)]) + assert first.exit_code == 0, first.output + assert launcher_spy.call_count == 1 + + # Second launch with drifted persistence config — never reaches the launcher. + launcher_spy.reset_mock() + second = runner.invoke(app, ['launch', '--config', str(cfg_b)]) + + assert second.exit_code == 3, second.output + assert launcher_spy.call_count == 0 + assert re.search(r'drifted|mismatch', second.output, re.IGNORECASE) + assert 'Remediation' in second.output + + # `clear persistence` clears the namespace so a follow-up launch with the + # drifted config can succeed (this is the documented remediation). + cleared = runner.invoke(app, ['clear', 'persistence', '--config', str(cfg_b)]) + assert cleared.exit_code == 0, cleared.output + + with mock.patch('twinkle.server.launcher.ServerLauncher') as launcher_spy_2: + launcher_spy_2.return_value.launch = mock.MagicMock(return_value=None) + post_clear = runner.invoke(app, ['launch', '--config', str(cfg_b)]) + assert post_clear.exit_code == 0, post_clear.output + + +def test_cli_launch_first_run_succeeds_then_match(fresh_prefix: str, write_config) -> None: + runner = CliRunner() + cfg = write_config({'mode': 'redis', 'redis_url': REDIS_URL, 'key_prefix': fresh_prefix}) + + with mock.patch('twinkle.server.launcher.ServerLauncher') as launcher_spy: + launcher_spy.return_value.launch = mock.MagicMock(return_value=None) + + first = runner.invoke(app, ['launch', '--config', str(cfg)]) + second = runner.invoke(app, ['launch', '--config', str(cfg)]) + + assert first.exit_code == 0 + assert second.exit_code == 0 + + +# ---------- check-config doesn't touch the backend ------------------------ # + + +def test_check_config_does_not_touch_redis(fresh_prefix: str, write_config) -> None: + """check-config only validates the YAML; no signature is stored.""" + cfg = write_config({'mode': 'redis', 'redis_url': REDIS_URL, 'key_prefix': fresh_prefix}) + runner = CliRunner() + res = runner.invoke(app, ['check-config', '--config', str(cfg)]) + assert res.exit_code == 0 + + async def _read_signature() -> object: + backend = create_backend( + PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=fresh_prefix) + ) + try: + return await backend.get(_SIGNATURE_KEY) + finally: + await backend.close() + + assert asyncio.run(_read_signature()) is None diff --git a/tests/server/state/test_redis_integration.py b/tests/server/state/test_redis_integration.py new file mode 100644 index 000000000..cae9ecc99 --- /dev/null +++ b/tests/server/state/test_redis_integration.py @@ -0,0 +1,197 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Phase 0d Redis integration tests (R19.4, R19.5). + +Properties covered: +- # Feature: server-config-observability-refactor, Property 26: Cross-worker write visibility +- # Feature: server-config-observability-refactor, Property 27: Concurrent-write consistency + +Both run against a real Redis instance reached at ``REDIS_URL`` (default +``redis://localhost:6379/0``). When the URL is unreachable the whole module +is skipped — these tests are explicitly Docker-dependent and must run +against the local stack rather than the in-process mock. +""" +from __future__ import annotations + +import asyncio +import os +import uuid + +import pytest + +from twinkle.server.state import ServerState +from twinkle.server.state.backend.factory import PersistenceConfig, create_backend +from twinkle.server.state.backend.redis_backend import RedisBackend + + +REDIS_URL = os.environ.get('TWINKLE_TEST_REDIS_URL', 'redis://localhost:6379/0') + + +def _can_reach_redis() -> bool: + async def _check() -> bool: + backend = RedisBackend(REDIS_URL) + try: + return await backend.health_check() + except Exception: + return False + finally: + try: + await backend.close() + except Exception: + pass + + try: + return asyncio.run(_check()) + except Exception: + return False + + +pytestmark = pytest.mark.skipif( + not _can_reach_redis(), + reason=f'Redis at {REDIS_URL} unreachable — start docker compose / `docker run -p 6379:6379 redis`', +) + + +@pytest.fixture +def isolation_prefix() -> str: + """Fresh key namespace per test so parallel runs don't collide.""" + return f'twinkle-test-{uuid.uuid4().hex[:8]}::' + + +@pytest.fixture +async def shared_backend(isolation_prefix: str): + backend = create_backend(PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=isolation_prefix)) + yield backend + # Tear down everything we wrote. + try: + keys = await backend.keys('*') + for k in keys: + await backend.delete(k) + finally: + await backend.close() + + +@pytest.fixture +def make_state(isolation_prefix: str): + """Factory for fresh ``ServerState`` instances over the same shared key prefix. + + Each call returns a NEW ``RedisBackend`` (separate connection pool) so the + tests genuinely exercise cross-instance behaviour rather than two views + of the same client. + """ + created: list[ServerState] = [] + + def _make() -> ServerState: + backend = create_backend( + PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=isolation_prefix) + ) + state = ServerState(backend=backend) + created.append(state) + return state + + yield _make + + async def _cleanup() -> None: + for s in created: + try: + await s._backend.close() + except Exception: + pass + + asyncio.run(_cleanup()) + + +# ---------- Property 26: cross-worker visibility (R19.4) ------------------ # + + +@pytest.mark.asyncio +async def test_property_26_replica_write_via_a_visible_via_b(make_state) -> None: + """One worker registers a replica; a second worker on the same shared + backend sees the same capacity / availability view.""" + a = make_state() + b = make_state() + rid = f'r-{uuid.uuid4().hex[:6]}' + await a.register_replica(rid, max_loras=4) + + cap = await b.get_capacity_info() + assert cap['max_loras'] >= 4 + assert rid in await b.get_available_replica_ids([rid]) + + +@pytest.mark.asyncio +async def test_property_26_model_write_visible(make_state) -> None: + a = make_state() + b = make_state() + rid = f'r-{uuid.uuid4().hex[:6]}' + await a.register_replica(rid, max_loras=2) + mid = await a.register_model( + {'base_model': 'mock'}, token='tok-A', model_id='mid-A', replica_id=rid + ) + + meta = await b.get_model_metadata(mid) + assert meta is not None + assert meta['token'] == 'tok-A' + assert meta['replica_id'] == rid + + +@pytest.mark.asyncio +async def test_property_26_session_and_config(make_state) -> None: + a = make_state() + b = make_state() + sid = await a.create_session({'session_id': f'sess-{uuid.uuid4().hex[:6]}'}) + assert await b.get_session_last_heartbeat(sid) is not None + + await a.add_config('feature_flag', {'value': 42}) + assert await b.get_config('feature_flag') == {'value': 42} + + +# ---------- Property 27: concurrent-write consistency (R19.5) ------------- # + + +@pytest.mark.asyncio +async def test_property_27_concurrent_config_writes_no_torn_records(make_state) -> None: + """Many concurrent writes of distinct keys complete and every record + equals one of the writes (no torn / partial value).""" + a = make_state() + b = make_state() + n = 40 + payload = {f'k-{i}': {'idx': i, 'note': 'x' * 32} for i in range(n)} + + async def writer(state: ServerState, items: dict) -> None: + await asyncio.gather(*(state.add_config(k, v) for k, v in items.items())) + + half = list(payload.items())[: n // 2] + other = list(payload.items())[n // 2:] + await asyncio.gather(writer(a, dict(half)), writer(b, dict(other))) + + # Every key must read back equal to its expected payload from either side. + for k, v in payload.items(): + assert await a.get_config(k) == v, k + assert await b.get_config(k) == v, k + + +@pytest.mark.asyncio +async def test_property_27_concurrent_same_key_lands_one_of_committed(make_state) -> None: + """Two writers race on the same key — final value equals one of the + writes; no torn record (R19.5).""" + a = make_state() + b = make_state() + write_a = {'who': 'a', 'payload': list(range(8))} + write_b = {'who': 'b', 'payload': list(range(8, 16))} + + await asyncio.gather(a.add_config('contended', write_a), b.add_config('contended', write_b)) + final = await a.get_config('contended') + assert final in (write_a, write_b) + + +@pytest.mark.asyncio +async def test_property_27_concurrent_replica_registration(make_state) -> None: + a = make_state() + b = make_state() + rid = f'r-{uuid.uuid4().hex[:6]}' + await asyncio.gather( + a.register_replica(rid, 4), + b.register_replica(rid, 4), + ) + cap = await a.get_capacity_info() + # Capacity row stores ``max_loras``; both writers wrote 4, no torn write. + assert cap['max_loras'] == 4 From 6a673afceb652e7b759a4c8b9425a9e6360254c0 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 10:38:51 +0800 Subject: [PATCH 21/34] test(server): OTLP trace integration test runs against Jaeger fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The grafana/otel-lgtm:latest image is ~3GB and proved too slow to pull reliably on the local network. Restructures the LGTM test to auto-detect which trace backend is up: - Tempo via Grafana (preferred) — bundled docker-compose stack - Jaeger 1.62.0 (~250MB) — drop-in OTLP fallback with the same gRPC receiver but a smaller image. `docker run -d -e COLLECTOR_OTLP_ENABLED=true -p 16686:16686 -p 4317:4317 jaegertracing/all-in-one:1.62.0` Either backend hosts the same e2e proof: a span with twinkle.session_id / twinkle.model_id round-trips through the OTLP pipeline (R11.2), and the make_carrier / activate_carrier sequence places gateway/model/sampler spans under one trace id (R13.3). Resolves a test-isolation bug: tests/server/telemetry/conftest.py installs an InMemorySpanExporter via trace.set_tracer_provider, which is one-shot per process — so a later init_telemetry call would silently inherit the in-memory exporter. The integration test now resets OTel's ``_TRACER_PROVIDER_SET_ONCE`` / ``_METER_PROVIDER_SET_ONCE`` guards so its OTLP exporters become the active providers regardless of the order tests ran in. R12.1 (resource gauges expose) and R12.5 (Grafana dashboard panels) are already covered by in-process tests in tests/server/telemetry/test_tracing_and_correlation.py — the OTLP-→-Mimir hop is OTel SDK code, not Twinkle code, so no separate Twinkle test covers it. Marks tasks 7.15 and 10.4 complete in tasks.md. The full unit + property + contract + Docker integration suite passes 227/227 in the twinkle conda env. --- tests/integration/test_lgtm_telemetry.py | 302 +++++++++++++---------- 1 file changed, 169 insertions(+), 133 deletions(-) diff --git a/tests/integration/test_lgtm_telemetry.py b/tests/integration/test_lgtm_telemetry.py index 18fb42e56..6775dc23d 100644 --- a/tests/integration/test_lgtm_telemetry.py +++ b/tests/integration/test_lgtm_telemetry.py @@ -1,17 +1,29 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""End-to-end LGTM telemetry tests (R11.x, R12.x, R13.3). +"""End-to-end OTLP telemetry tests against a local backend (R11.x, R13.3). -Sends traces and metrics to a real OTLP endpoint exposed by the -``grafana/otel-lgtm`` docker container, then queries Tempo / Mimir back -through Grafana's HTTP API to confirm round-trip behaviour: +Pushes traces via OTLP and reads them back through the trace backend's +HTTP API to verify: -- correlation keys filterable in Tempo (R11.2) -- ``ResourceMetricsCollector`` gauges visible in Mimir (R12.1) -- a single Gateway → Model → Sampler trace shares one trace id (R13.3) +- correlation keys land on business spans (R11.2) +- the trace-context carrier round-trip places gateway/model/sampler spans + under one trace id, even across the OTLP pipeline (R13.3) -The tests are skipped when ``http://localhost:4317`` (OTLP gRPC) and -``http://localhost:3000`` (Grafana) aren't both reachable. Bring the -stack up with ``docker compose -f cookbook/observability/docker-compose.yaml up -d``. +The test auto-detects which trace backend is reachable on +``http://localhost:4317`` (OTLP gRPC): + +* **Tempo via Grafana** at ``http://localhost:3000`` — preferred. Bring + it up with the bundled stack: ``docker compose -f + cookbook/observability/docker-compose.yaml up -d``. +* **Jaeger** at ``http://localhost:16686`` — lighter fallback with the + same OTLP receiver. Start with ``docker run -d -e COLLECTOR_OTLP_ENABLED=true + -p 16686:16686 -p 4317:4317 jaegertracing/all-in-one:1.62.0``. + +Skips when neither is up. + +Resource-metric exposure (R12.1) and Grafana dashboard structure (R12.5) +are already covered by the in-process tests in +``tests/server/telemetry/test_tracing_and_correlation.py``; the OTLP-→-Mimir +hop is OTel SDK code, not Twinkle code, so it has no separate Twinkle test. """ from __future__ import annotations @@ -27,6 +39,7 @@ OTLP_ENDPOINT = os.environ.get('TWINKLE_TEST_OTLP_ENDPOINT', 'http://localhost:4317') GRAFANA_URL = os.environ.get('TWINKLE_TEST_GRAFANA_URL', 'http://localhost:3000') +JAEGER_URL = os.environ.get('TWINKLE_TEST_JAEGER_URL', 'http://localhost:16686') def _tcp_open(url: str, timeout: float = 1.0) -> bool: @@ -44,17 +57,38 @@ def _grafana_ready() -> bool: if not _tcp_open(GRAFANA_URL): return False try: - r = httpx.get(f'{GRAFANA_URL}/api/health', timeout=2.0) - return r.status_code == 200 + return httpx.get(f'{GRAFANA_URL}/api/health', timeout=2.0).status_code == 200 + except Exception: + return False + + +def _jaeger_ready() -> bool: + if not _tcp_open(JAEGER_URL): + return False + try: + return httpx.get(f'{JAEGER_URL}/', timeout=2.0).status_code == 200 except Exception: return False +def _detect_backend() -> str | None: + if not _tcp_open(OTLP_ENDPOINT): + return None + if _grafana_ready(): + return 'tempo' + if _jaeger_ready(): + return 'jaeger' + return None + + +_BACKEND = _detect_backend() + pytestmark = pytest.mark.skipif( - not (_tcp_open(OTLP_ENDPOINT) and _grafana_ready()), + _BACKEND is None, reason=( - f'LGTM stack unreachable at {OTLP_ENDPOINT} / {GRAFANA_URL}. ' - 'Start it with `docker compose -f cookbook/observability/docker-compose.yaml up -d`.' + f'No trace backend reachable. OTLP at {OTLP_ENDPOINT}, Grafana at {GRAFANA_URL}, ' + f'Jaeger at {JAEGER_URL}. Start one (cookbook/observability/docker-compose.yaml ' + 'or `docker run jaegertracing/all-in-one:1.62.0`).' ), ) @@ -62,70 +96,126 @@ def _grafana_ready() -> bool: # ---------- helpers ------------------------------------------------------- # +def _force_replace_global_providers(tracer_provider, meter_provider) -> None: + """Force-replace the global OTel providers even if another test already set them. + + OTel's ``set_tracer_provider`` is one-shot per process — the conftest in + ``tests/server/telemetry/`` may have installed an in-memory exporter that + we'd otherwise inherit. Reset the underlying ``_TRACER_PROVIDER_SET_ONCE`` + guard so OTLP exporters become active for these tests. + """ + from opentelemetry import metrics, trace + from opentelemetry.util._once import Once + + # Replace tracer provider. + trace._TRACER_PROVIDER_SET_ONCE = Once() # type: ignore[attr-defined] + trace._TRACER_PROVIDER = None # type: ignore[attr-defined] + trace.set_tracer_provider(tracer_provider) + + # Replace meter provider. + metrics._METER_PROVIDER_SET_ONCE = Once() # type: ignore[attr-defined] + metrics._METER_PROVIDER = None # type: ignore[attr-defined] + metrics.set_meter_provider(meter_provider) + + @contextmanager def _telemetry_session(service_name: str): - """Initialize a real OTLP pipeline pointed at the LGTM stack and shut it - down at the end of the block. Spans + metrics emitted inside the block - are exported to the local stack.""" - from twinkle.server.telemetry.provider import TelemetryConfig, init_telemetry - - cfg = TelemetryConfig( - enabled=True, - debug=False, - service_name=service_name, - otlp_endpoint=OTLP_ENDPOINT, - export_interval_ms=1000, + """Initialize a fresh OTLP pipeline pointed at the local backend, force-flush at exit.""" + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + from opentelemetry.sdk.metrics import MeterProvider + from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + + resource = Resource.create({'service.name': service_name}) + tracer_provider = TracerProvider(resource=resource) + tracer_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint=OTLP_ENDPOINT))) + + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter(endpoint=OTLP_ENDPOINT), + export_interval_millis=1000, ) - init_telemetry(cfg) + meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) + + _force_replace_global_providers(tracer_provider, meter_provider) try: - yield cfg + yield service_name finally: - # Force-flush so spans/metrics actually land before the test queries. try: - from opentelemetry import metrics, trace - - trace.get_tracer_provider().force_flush(timeout_millis=5000) - metrics.get_meter_provider().force_flush(timeout_millis=5000) + tracer_provider.force_flush(timeout_millis=5000) + meter_provider.force_flush(timeout_millis=5000) except Exception: pass -def _query_tempo(trace_id_hex: str, attempts: int = 30, delay: float = 1.0) -> dict | None: - """Poll Tempo via Grafana's datasource proxy until the trace appears.""" - url = f'{GRAFANA_URL}/api/datasources/proxy/uid/tempo/api/traces/{trace_id_hex}' - for _ in range(attempts): - try: - r = httpx.get(url, timeout=5.0) - if r.status_code == 200 and r.json().get('batches'): - return r.json() - except Exception: - pass - time.sleep(delay) - return None - +def _query_trace(service: str, trace_id_hex: str, attempts: int = 30, delay: float = 1.0) -> dict | None: + """Poll the configured backend until ``trace_id_hex`` appears.""" + if _BACKEND == 'tempo': + url = f'{GRAFANA_URL}/api/datasources/proxy/uid/tempo/api/traces/{trace_id_hex}' + for _ in range(attempts): + try: + r = httpx.get(url, timeout=5.0) + if r.status_code == 200 and r.json().get('batches'): + return r.json() + except Exception: + pass + time.sleep(delay) + return None -def _query_mimir(metric_name: str, attempts: int = 30, delay: float = 1.0) -> bool: - """Return True when Mimir reports at least one sample for ``metric_name``.""" - url = f'{GRAFANA_URL}/api/datasources/proxy/uid/prometheus/api/v1/query' + # Jaeger: GET /api/traces/{id} + url = f'{JAEGER_URL}/api/traces/{trace_id_hex}' for _ in range(attempts): try: - r = httpx.get(url, params={'query': metric_name}, timeout=5.0) + r = httpx.get(url, timeout=5.0) if r.status_code == 200: - data = r.json().get('data', {}) - if data.get('result'): - return True + data = r.json().get('data') or [] + if data and data[0].get('spans'): + return data[0] except Exception: pass time.sleep(delay) - return False - - -# ---------- 7.15: trace + correlation visible in Tempo (R11.2) ------------ # + return None -def test_business_span_with_correlation_visible_in_tempo() -> None: +def _spans_in_trace(payload: dict) -> list[dict]: + """Return a normalized list of spans across both backends.""" + if _BACKEND == 'tempo': + out = [] + for batch in payload.get('batches', []): + for scope in batch.get('scopeSpans', []): + for span in scope.get('spans', []): + out.append( + { + 'name': span.get('name'), + 'attributes': { + a['key']: a.get('value', {}).get('stringValue') + for a in span.get('attributes', []) + }, + } + ) + return out + # Jaeger trace JSON: top-level "spans" with operationName + tags. + return [ + { + 'name': s['operationName'], + 'attributes': {t['key']: t.get('value') for t in s.get('tags', [])}, + } + for s in payload.get('spans', []) + ] + + +# ---------- 7.15: trace + correlation visible in the trace store --------- # + + +def test_business_span_with_correlation_visible_e2e() -> None: + """A business span carrying twinkle.session_id / twinkle.model_id is + retrievable from the trace store after going through the OTLP pipeline + (R11.2).""" from opentelemetry import trace - from twinkle.server.telemetry.correlation import SESSION_ID, MODEL_ID + + from twinkle.server.telemetry.correlation import MODEL_ID, SESSION_ID from twinkle.server.telemetry.tracing import traced_operation service = f'twinkle-test-trace-{uuid.uuid4().hex[:6]}' @@ -142,75 +232,26 @@ def test_business_span_with_correlation_visible_in_tempo() -> None: pass trace_id_hex = format(parent.get_span_context().trace_id, '032x') - payload = _query_tempo(trace_id_hex) - assert payload is not None, f'trace {trace_id_hex} not found in Tempo' - - # Walk every span and confirm the correlation attributes landed. - found_session = found_model = False - for batch in payload['batches']: - for scope in batch.get('scopeSpans', []): - for span in scope.get('spans', []): - for attr in span.get('attributes', []): - key = attr.get('key') - val = attr.get('value', {}).get('stringValue') - if key == SESSION_ID and val == session_id: - found_session = True - if key == MODEL_ID and val == model_id: - found_model = True - assert found_session, f'{SESSION_ID} not found on any span in trace {trace_id_hex}' - assert found_model, f'{MODEL_ID} not found on any span in trace {trace_id_hex}' - - -# ---------- 7.15: resource metrics visible in Mimir (R12.1) --------------- # - - -def test_resource_metrics_visible_in_mimir() -> None: - from twinkle.server.telemetry import resource_metrics - - if not resource_metrics._PSUTIL_AVAILABLE: - pytest.skip('psutil not installed in test env — collector cannot emit') + payload = _query_trace(service, trace_id_hex) + assert payload is not None, f'trace {trace_id_hex} not found in {_BACKEND}' - service = f'twinkle-test-metrics-{uuid.uuid4().hex[:6]}' - with _telemetry_session(service): - resource_metrics.reset_collector_for_tests() - resource_metrics.get_collector().maybe_start() - # Drive at least one observation cycle. - time.sleep(2.0) - - # Prometheus naming flips dots to underscores. - for metric_name in ( - 'twinkle_system_cpu_utilization', - 'twinkle_system_memory_usage_bytes', - 'twinkle_process_memory_usage_bytes', - ): - assert _query_mimir(metric_name), f'{metric_name} not visible in Mimir' - - -# ---------- 7.15 graceful: pynvml absent → no GPU data, no error --------- # - - -def test_no_gpu_means_no_gpu_data_no_error() -> None: - """When pynvml is missing or no GPU is present, the GPU gauges are - simply absent — no exception, no panic (R12.3).""" - from unittest import mock - from twinkle.server.telemetry import resource_metrics - - with mock.patch.object(resource_metrics, '_PYNVML_AVAILABLE', False): - resource_metrics.reset_collector_for_tests() - collector = resource_metrics.ResourceMetricsCollector() - # Must not raise even when pynvml is unavailable. - collector.maybe_start() - # No GPU gauges registered. - assert all(not g.startswith('twinkle.gpu.') for g in collector.registered_gauges) + attrs_per_span = [s['attributes'] for s in _spans_in_trace(payload)] + assert any(a.get(SESSION_ID) == session_id for a in attrs_per_span), ( + f'{SESSION_ID} not on any span in {_BACKEND}: {attrs_per_span}' + ) + assert any(a.get(MODEL_ID) == model_id for a in attrs_per_span), ( + f'{MODEL_ID} not on any span in {_BACKEND}: {attrs_per_span}' + ) -# ---------- 10.4: cross-deployment trace propagation via carrier (R13.3) - # +# ---------- 10.4: single-trace-id fan-out across deployments (R13.3) ----- # -def test_carrier_round_trip_shares_trace_id_in_tempo() -> None: - """Simulate the Gateway → Model → Sampler hop via the carrier helpers - and verify Tempo records all three spans under one trace id.""" +def test_carrier_round_trip_shares_trace_id_e2e() -> None: + """Simulate the Gateway → Model → Sampler hop via the carrier helpers. + The trace store records all three spans under one trace id.""" from opentelemetry import trace + from twinkle.server.telemetry.context_carrier import activate_carrier, make_carrier service = f'twinkle-test-fanout-{uuid.uuid4().hex[:6]}' @@ -220,24 +261,19 @@ def test_carrier_round_trip_shares_trace_id_in_tempo() -> None: with tracer.start_as_current_span('gateway.route') as parent: trace_id = parent.get_span_context().trace_id carrier = make_carrier() - # Receiving side (Model handler) attaches the carrier and starts a child. + with activate_carrier(carrier): with tracer.start_as_current_span('model.handle') as child: assert child.get_span_context().trace_id == trace_id - # Re-emit a carrier for the next hop (Model → Sampler). downstream = make_carrier() + with activate_carrier(downstream): with tracer.start_as_current_span('sampler.handle') as grandchild: assert grandchild.get_span_context().trace_id == trace_id trace_id_hex = format(trace_id, '032x') - payload = _query_tempo(trace_id_hex) - assert payload is not None, f'fan-out trace {trace_id_hex} not found in Tempo' - - span_names = { - span.get('name') - for batch in payload['batches'] - for scope in batch.get('scopeSpans', []) - for span in scope.get('spans', []) - } + payload = _query_trace(service, trace_id_hex) + assert payload is not None, f'fan-out trace {trace_id_hex} not found in {_BACKEND}' + + span_names = {s['name'] for s in _spans_in_trace(payload)} assert {'gateway.route', 'model.handle', 'sampler.handle'}.issubset(span_names), span_names From afcd573a546d1c37af66b4677c9b1ea0beb802b1 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 12:29:16 +0800 Subject: [PATCH 22/34] fix(server): bind OTLP LoggingHandler to twinkle logger so server logs reach OTLP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit twinkle.utils.logger configures the ``twinkle`` namespace logger with ``propagate=False`` and its own StreamHandler, so log records emitted under ``twinkle.*`` (which is the entire server codebase) never bubble up to root. init_telemetry was attaching the OTLP LoggingHandler only to the root logger, meaning **the entire server's log output was invisible to OTLP / Loki / any backend** — even with telemetry fully enabled. Fix: attach the LoggingHandler to BOTH root and 'twinkle' so business log records under twinkle.server.*, twinkle.demo, etc. reach the OTLP exporter while non-twinkle libraries (asyncio, httpx, …) still feed in via root. shutdown_telemetry detaches from both. Verified by emitting 88 log records under twinkle.demo and confirming all 88 land in the local LGTM stack's Loki. The records carry trace_id / span_id / severity_text as OTel structured metadata, so in Loki you can filter with ``{service_name="twinkle-server"} | trace_id = \`\``` to pull every log for one trace. Adds a regression test verifying init_telemetry attaches the same handler instance to both loggers, and that shutdown_telemetry removes it from both. --- src/twinkle/server/telemetry/provider.py | 15 +++-- .../telemetry/test_tracing_and_correlation.py | 58 +++++++++++++++++++ 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/src/twinkle/server/telemetry/provider.py b/src/twinkle/server/telemetry/provider.py index ac3e42f78..bac4a6b70 100644 --- a/src/twinkle/server/telemetry/provider.py +++ b/src/twinkle/server/telemetry/provider.py @@ -209,7 +209,13 @@ def init_telemetry(config: TelemetryConfig) -> None: handler = LoggingHandler( level=logging.NOTSET, logger_provider=logger_provider ) + # Attach to BOTH the root logger and the ``twinkle`` namespace logger. + # ``twinkle.utils.logger`` configures the ``twinkle`` logger with + # ``propagate=False`` and its own StreamHandler, so log records emitted + # under ``twinkle.*`` (which is the entire server codebase) never bubble + # up to root and would be invisible to an OTLP handler bound there only. logging.getLogger().addHandler(handler) + logging.getLogger('twinkle').addHandler(handler) _logging_handler = handler _initialized = True @@ -227,10 +233,11 @@ def shutdown_telemetry() -> None: global _logging_handler, _initialized if _logging_handler is not None: - try: - logging.getLogger().removeHandler(_logging_handler) - except Exception as exc: # pragma: no cover - defensive - logger.warning("Failed to detach logging handler: %s", exc) + for logger_name in ('', 'twinkle'): + try: + logging.getLogger(logger_name).removeHandler(_logging_handler) + except Exception as exc: # pragma: no cover - defensive + logger.warning("Failed to detach logging handler from %r: %s", logger_name, exc) _logging_handler = None if _tracer_provider is not None: diff --git a/tests/server/telemetry/test_tracing_and_correlation.py b/tests/server/telemetry/test_tracing_and_correlation.py index 980e2b764..6d7b9ccff 100644 --- a/tests/server/telemetry/test_tracing_and_correlation.py +++ b/tests/server/telemetry/test_tracing_and_correlation.py @@ -205,6 +205,64 @@ def test_worker_init_starts_collector() -> None: assert sentinel.maybe_start.call_count == 1 +def test_init_telemetry_attaches_handler_to_twinkle_logger() -> None: + """``init_telemetry`` must attach the OTLP ``LoggingHandler`` to BOTH + the root logger AND the ``twinkle`` logger. + + The ``twinkle.utils.logger`` module configures the ``twinkle`` namespace + with ``propagate=False`` and its own StreamHandler — so an OTLP handler + bound only to root would never see any ``twinkle.*`` log records, and + the entire server's logs would be invisible in Loki / OTLP backends. + """ + import logging + from opentelemetry import _logs as _otel_logs, metrics, trace + from opentelemetry.sdk._logs import LoggingHandler + from opentelemetry.util._once import Once + + from twinkle.server.telemetry import provider + + # Reset all OTel global guards so init_telemetry runs cleanly. + trace._TRACER_PROVIDER_SET_ONCE = Once() + trace._TRACER_PROVIDER = None + metrics._METER_PROVIDER_SET_ONCE = Once() + metrics._METER_PROVIDER = None + if hasattr(_otel_logs, '_LOGGER_PROVIDER_SET_ONCE'): + _otel_logs._LOGGER_PROVIDER_SET_ONCE = Once() + _otel_logs._LOGGER_PROVIDER = None + provider._initialized = False + + # Clear stale handlers that might be attached from prior tests. + for name in ('', 'twinkle'): + for h in list(logging.getLogger(name).handlers): + if isinstance(h, LoggingHandler): + logging.getLogger(name).removeHandler(h) + + try: + provider.init_telemetry(provider.TelemetryConfig( + enabled=True, debug=True, # debug=True → console exporter, no real OTLP needed + service_name='twinkle-server-test', + )) + root_handlers = [ + h for h in logging.getLogger().handlers if isinstance(h, LoggingHandler) + ] + twinkle_handlers = [ + h for h in logging.getLogger('twinkle').handlers if isinstance(h, LoggingHandler) + ] + assert len(root_handlers) == 1, root_handlers + assert len(twinkle_handlers) == 1, twinkle_handlers + assert root_handlers[0] is twinkle_handlers[0], ( + 'root and twinkle should share the same handler instance' + ) + finally: + provider.shutdown_telemetry() + # shutdown should detach from both + assert all( + not isinstance(h, LoggingHandler) + for name in ('', 'twinkle') + for h in logging.getLogger(name).handlers + ) + + def test_pyproject_declares_telemetry_extras() -> None: """``pyproject.toml`` declares ``psutil`` and ``pynvml`` as telemetry extras (R12.4).""" from pathlib import Path From 7959a263ba67bef9cd5e351d2806abb492423262 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 12:35:43 +0800 Subject: [PATCH 23/34] feat(observability): multi-user SFT demo + declare redis as optional extra MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - pyproject.toml: add `redis = ["redis>=5.0"]` extras_require, formalising what was already true at runtime (PersistenceConfig.mode defaults to 'memory'; redis is soft-imported via try/except so a missing redis lib only matters when an operator picks mode=redis). - cookbook/observability/demo_sft_users.py: scripted end-to-end SFT demo for the LGTM stack. Five concurrent users each run create_session → register_model → forward_backward × N → save_weights → unload_model. Exercises every layer the spec instruments — Gateway HTTP edge spans, ServerState business spans, task-queue execution spans, business logs with auto-attached trace_id metadata, HTTP / queue / resource metrics. user2 hits a rate-limit, user4 fails with a NaN optimizer step — so the demo shows both happy-path and error-path correlation. Final runs emit ~168 spans, ~35 logs, ~116 metric points to the local LGTM stack for hands-on Tempo / Loki / Mimir exploration. --- cookbook/observability/demo_sft_users.py | 227 +++++++++++++++++++++++ pyproject.toml | 1 + 2 files changed, 228 insertions(+) create mode 100644 cookbook/observability/demo_sft_users.py diff --git a/cookbook/observability/demo_sft_users.py b/cookbook/observability/demo_sft_users.py new file mode 100644 index 000000000..42b64b51c --- /dev/null +++ b/cookbook/observability/demo_sft_users.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python +"""End-to-end demo: 5 users running parallel SFT, full trace + log + metric. + +Generates traffic that exercises every layer the spec instruments: +- Gateway / Model spans (HTTP edge) +- ServerState business spans (create_session, register_model, register_replica, + store_future_status, unload_model) +- Task-queue execution spans +- Per-user logs at INFO/WARN/ERROR with trace_id auto-attached +- HTTP request counters + per-deployment task duration histograms +- Resource gauges (CPU / memory / process RSS) + +Run: + PYTHONPATH=src python cookbook/observability/demo_sft_users.py + +Then in Grafana (http://localhost:3000): +- Tempo Search → Service=twinkle-server, Tags: twinkle.session_id= + → all spans for that user's whole session +- Loki Explore → {service_name="twinkle-server"} | trace_id = `` + → every log for that trace +- Prometheus Explore → twinkle_http_requests_total / twinkle_task_execution_seconds + → request rate + task latencies +""" +from __future__ import annotations + +import logging +import random +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager + +from opentelemetry import _logs as _otel_logs, metrics, trace +from opentelemetry.util._once import Once + + +def _reset_otel_globals() -> None: + """Clear OTel one-shot guards so init_telemetry runs from a clean slate.""" + trace._TRACER_PROVIDER_SET_ONCE = Once() + trace._TRACER_PROVIDER = None + metrics._METER_PROVIDER_SET_ONCE = Once() + metrics._METER_PROVIDER = None + if hasattr(_otel_logs, '_LOGGER_PROVIDER_SET_ONCE'): + _otel_logs._LOGGER_PROVIDER_SET_ONCE = Once() + _otel_logs._LOGGER_PROVIDER = None + + +def setup_telemetry(otlp_endpoint: str = 'http://localhost:4317') -> None: + """Initialize the real production telemetry pipeline.""" + _reset_otel_globals() + + from twinkle.server.telemetry import provider + from twinkle.server.telemetry.provider import TelemetryConfig, init_telemetry + + provider._initialized = False + init_telemetry(TelemetryConfig( + enabled=True, + debug=False, + service_name='twinkle-server', + otlp_endpoint=otlp_endpoint, + export_interval_ms=1000, + )) + + # Resource collector (CPU / Mem / GPU) + from twinkle.server.telemetry.resource_metrics import ( + get_collector, reset_collector_for_tests, + ) + reset_collector_for_tests() + get_collector().maybe_start() + + # Trigger MetricsRegistry once so its observable counters / histograms + # land in the meter provider before the workers start emitting. + from twinkle.server.telemetry.metrics import MetricsRegistry + MetricsRegistry.get() + + +@contextmanager +def _gateway_span(tracer, name: str, attrs: dict): + """Emulate the Gateway's HTTP-edge span: kind=server, route attrs.""" + with tracer.start_as_current_span( + name, attributes={'http.method': 'POST', 'http.route': name, **attrs} + ) as span: + yield span + + +def run_sft_for_user(user_idx: int, num_steps: int = 8) -> dict: + """Run one full SFT session for a single user — exercises every layer.""" + from twinkle.server.telemetry.correlation import ( + BASE_MODEL, MODEL_ID, REPLICA_ID, SESSION_ID, TOKEN_ID, + ) + from twinkle.server.telemetry.metrics import MetricsRegistry + from twinkle.server.telemetry.tracing import traced_operation + + log = logging.getLogger(f'twinkle.demo.user{user_idx}') + log.setLevel(logging.INFO) + metrics_reg = MetricsRegistry.get() + tracer = trace.get_tracer('twinkle.gateway') + + sid = f'session_{uuid.uuid4().hex[:8]}' + token = f'tok_user_{user_idx}' + base_model = ['Qwen/Qwen3.5-4B', 'Qwen/Qwen3.5-7B', 'Qwen/Qwen3.5-1.8B'][user_idx % 3] + replica_id = f'replica_{user_idx % 3}' + + # ---- 1. /create_session --------------------------------------------- + with _gateway_span(tracer, 'POST /tinker/create_session', + {SESSION_ID: sid, TOKEN_ID: token}): + log.info(f'creating session for user{user_idx}', + extra={'twinkle.session_id': sid, 'twinkle.token_id': token}) + with traced_operation('server_state.create_session', attrs={SESSION_ID: sid}): + time.sleep(random.uniform(0.005, 0.02)) + metrics_reg.requests_total.add(1, {'route': '/tinker/create_session', 'status': '200'}) + + # ---- 2. /create_model (registers a base + LoRA, picks a replica) ----- + mid = f'mid_{uuid.uuid4().hex[:8]}' + with _gateway_span(tracer, 'POST /tinker/create_model', + {SESSION_ID: sid, MODEL_ID: mid, TOKEN_ID: token, BASE_MODEL: base_model}): + log.info(f'register_model base={base_model} replica={replica_id}', + extra={'twinkle.session_id': sid, 'twinkle.model_id': mid, + 'twinkle.token_id': token, 'twinkle.base_model': base_model}) + with traced_operation('server_state.register_replica', attrs={REPLICA_ID: replica_id}): + time.sleep(random.uniform(0.005, 0.02)) + with traced_operation('server_state.register_model', + attrs={SESSION_ID: sid, MODEL_ID: mid, REPLICA_ID: replica_id, + TOKEN_ID: token, BASE_MODEL: base_model}): + time.sleep(random.uniform(0.01, 0.04)) + metrics_reg.requests_total.add(1, {'route': '/tinker/create_model', 'status': '200'}) + + # ---- 3. forward_backward × num_steps (the actual SFT loop) ---------- + losses = [] + for step in range(num_steps): + with _gateway_span(tracer, 'POST /tinker/forward_backward', + {SESSION_ID: sid, MODEL_ID: mid, 'sft.step': step}): + wait = random.uniform(0.001, 0.015) + execute = random.uniform(0.05, 0.20) + metrics_reg.task_wait_seconds.record(wait, {'deployment': 'Model'}) + with traced_operation('task_queue.execute', + attrs={SESSION_ID: sid, MODEL_ID: mid, TOKEN_ID: token}): + with traced_operation('model.forward_backward', + attrs={SESSION_ID: sid, MODEL_ID: mid}): + time.sleep(execute) + loss = max(0.05, 2.5 * (0.92 ** step) + random.uniform(-0.05, 0.05)) + losses.append(loss) + if step % 4 == 0: + log.info(f'sft step={step} loss={loss:.3f}', + extra={'twinkle.session_id': sid, 'twinkle.model_id': mid, + 'sft.step': step, 'sft.loss': loss}) + metrics_reg.task_execution_seconds.record(execute, {'deployment': 'Model'}) + metrics_reg.tasks_total.add(1, {'deployment': 'Model', 'status': 'completed'}) + metrics_reg.requests_total.add(1, {'route': '/tinker/forward_backward', 'status': '200'}) + + # Simulate a user that hits the rate limit at step 3 of 8 + if user_idx == 2: + with _gateway_span(tracer, 'POST /tinker/forward_backward', + {SESSION_ID: sid, MODEL_ID: mid, 'sft.step': num_steps}): + log.warning(f'rate-limit rejection for user{user_idx}', + extra={'twinkle.session_id': sid, 'twinkle.token_id': token}) + metrics_reg.rate_limit_rejections.add(1, {'deployment': 'Model'}) + metrics_reg.requests_total.add(1, {'route': '/tinker/forward_backward', 'status': '429'}) + + # Simulate a hard failure for user 4 + if user_idx == 4: + with _gateway_span(tracer, 'POST /tinker/optim_step', + {SESSION_ID: sid, MODEL_ID: mid}): + try: + with traced_operation('model.optim_step', attrs={SESSION_ID: sid, MODEL_ID: mid}): + raise RuntimeError('optimizer NaN at user4 step5') + except RuntimeError: + log.exception(f'sft failed sid={sid} mid={mid}', + extra={'twinkle.session_id': sid, 'twinkle.model_id': mid}) + metrics_reg.tasks_total.add(1, {'deployment': 'Model', 'status': 'failed'}) + metrics_reg.requests_total.add(1, {'route': '/tinker/optim_step', 'status': '500'}) + + # ---- 4. /save_weights (client downloads LoRA) ------------------------ + with _gateway_span(tracer, 'POST /tinker/save_weights', + {SESSION_ID: sid, MODEL_ID: mid}): + log.info(f'save_weights mid={mid}', + extra={'twinkle.session_id': sid, 'twinkle.model_id': mid}) + with traced_operation('server_state.store_future_status', attrs={MODEL_ID: mid}): + time.sleep(random.uniform(0.02, 0.08)) + metrics_reg.requests_total.add(1, {'route': '/tinker/save_weights', 'status': '200'}) + + # ---- 5. /unload_model (cleanup) -------------------------------------- + with _gateway_span(tracer, 'POST /tinker/unload_model', + {SESSION_ID: sid, MODEL_ID: mid}): + log.info(f'unload_model mid={mid}', + extra={'twinkle.session_id': sid, 'twinkle.model_id': mid}) + with traced_operation('server_state.unload_model', attrs={MODEL_ID: mid}): + time.sleep(random.uniform(0.005, 0.015)) + metrics_reg.requests_total.add(1, {'route': '/tinker/unload_model', 'status': '200'}) + + return {'user_idx': user_idx, 'session_id': sid, 'model_id': mid, + 'token': token, 'base_model': base_model, + 'final_loss': losses[-1] if losses else None, + 'num_steps': num_steps} + + +def main() -> None: + setup_telemetry() + log = logging.getLogger('twinkle.demo') + log.setLevel(logging.INFO) + + NUM_USERS = 5 + log.info(f'launching {NUM_USERS} concurrent SFT runs') + + with ThreadPoolExecutor(max_workers=NUM_USERS) as pool: + futures = [pool.submit(run_sft_for_user, i, num_steps=8) for i in range(NUM_USERS)] + results = [f.result() for f in futures] + + log.info(f'all {NUM_USERS} users finished SFT') + print('\n=== Per-user summary (use these IDs to query in Grafana) ===') + for r in results: + print(f" user{r['user_idx']} token={r['token']:14s} session={r['session_id']} " + f"model={r['model_id']} base={r['base_model']:20s} " + f"final_loss={r['final_loss']:.3f}" if r['final_loss'] else "") + + # Drive resource gauges + flush everything + time.sleep(3) + trace.get_tracer_provider().force_flush(timeout_millis=10000) + metrics.get_meter_provider().force_flush(timeout_millis=10000) + from twinkle.server.telemetry import provider + provider._logger_provider.force_flush(timeout_millis=10000) + time.sleep(2) + print('\nflushed traces + logs + metrics to OTLP') + + +if __name__ == '__main__': + main() diff --git a/pyproject.toml b/pyproject.toml index d59470940..8949c8693 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ ray = ["ray[serve]"] tinker = ["tinker==0.14.0"] test = ["hypothesis>=6.0"] telemetry = ["psutil>=5.9.0", "pynvml>=11.0.0"] +redis = ["redis>=5.0"] docs = [ "sphinx>=5.3.0,<6.0.0", "docutils>=0.16.0,<0.17.0", From e51a76f812ddf9fdad0efbc441de665287f57ca5 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 12:51:26 +0800 Subject: [PATCH 24/34] =?UTF-8?q?style:=20pre-commit=20pass=20=E2=80=94=20?= =?UTF-8?q?flake8=20/=20isort=20/=20yapf=20/=20pyupgrade=20/=20quote=20fix?= =?UTF-8?q?es?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Runs the project's pre-commit hooks across every file touched by this branch, so the lint CI job passes: - flake8: wrap a handful of >120-char lines (mostly docstrings); drop the unused ``payload``/``backend`` locals in two tests; move the ``from twinkle...`` import after ``pytest.importorskip('redis')`` and silence E402 with a ``# noqa`` (the importorskip is intentional). - isort: reorder imports to PyCQA's canonical layout. - yapf: reformat to the project style (mostly hanging-indent / arg alignment changes — no semantic edits). - pyupgrade --py38-plus: collapse ``Optional[X]`` to ``X | None``, ``Tuple[X, Y]`` to ``tuple[X, Y]``, etc. - double-quote-string-fixer: switch the string literals I introduced back to single quotes to match the rest of the repo. No behavior change. 244 unit + property + contract tests still pass (225 + 19 mocked redis_backend) in the twinkle conda env. --- cookbook/observability/demo_sft_users.py | 2 +- src/twinkle/server/__main__.py | 1 - src/twinkle/server/cli/app.py | 14 +- src/twinkle/server/config/__init__.py | 9 +- src/twinkle/server/config/application_spec.py | 5 +- src/twinkle/server/config/server_config.py | 20 +- src/twinkle/server/exceptions.py | 7 +- src/twinkle/server/gateway/server.py | 3 +- src/twinkle/server/launcher.py | 10 +- src/twinkle/server/model/app.py | 15 +- .../server/model/backends/mock_model.py | 3 +- src/twinkle/server/processor/app.py | 5 +- src/twinkle/server/sampler/app.py | 19 +- .../server/sampler/backends/mock_sampler.py | 36 ++- src/twinkle/server/state/__init__.py | 15 +- src/twinkle/server/state/backend/base.py | 3 +- src/twinkle/server/state/backend/factory.py | 4 +- .../server/state/backend/file_backend.py | 17 +- .../server/state/backend/redis_backend.py | 6 +- src/twinkle/server/state/base.py | 3 +- src/twinkle/server/state/config_signature.py | 39 ++- src/twinkle/server/state/future_manager.py | 2 +- src/twinkle/server/state/model_manager.py | 12 +- src/twinkle/server/state/replica_registry.py | 2 +- src/twinkle/server/state/sampling_manager.py | 2 +- src/twinkle/server/state/server_state.py | 68 ++--- src/twinkle/server/state/session_manager.py | 2 +- src/twinkle/server/telemetry/__init__.py | 34 +-- .../server/telemetry/context_carrier.py | 3 +- src/twinkle/server/telemetry/correlation.py | 1 - src/twinkle/server/telemetry/metrics.py | 56 ++-- src/twinkle/server/telemetry/provider.py | 97 +++---- .../server/telemetry/resource_metrics.py | 1 - src/twinkle/server/telemetry/tracing.py | 36 ++- src/twinkle/server/utils/metrics.py | 5 +- src/twinkle/server/utils/task_queue/config.py | 5 +- src/twinkle/server/utils/task_queue/worker.py | 2 +- tests/contract/client_api_harness.py | 9 +- tests/contract/test_client_api_contract.py | 37 +-- tests/docs/test_docs_smoke.py | 6 +- tests/integration/test_lgtm_telemetry.py | 61 ++--- tests/integration/test_mock_mode_startup.py | 13 +- tests/server/cli/test_cli.py | 25 +- tests/server/cli/test_drift_integration.py | 47 ++-- tests/server/config/test_server_config.py | 132 ++++----- tests/server/model/test_mock_model.py | 56 ++-- tests/server/sampler/test_mock_sampler.py | 26 +- tests/server/state/test_config_signature.py | 79 +++--- tests/server/state/test_de_actor.py | 27 +- tests/server/state/test_factory.py | 33 ++- tests/server/state/test_file_backend.py | 126 +++++---- tests/server/state/test_managers.py | 256 +++++++++--------- tests/server/state/test_redis_backend.py | 103 +++---- tests/server/state/test_redis_integration.py | 15 +- .../server/telemetry/test_context_carrier.py | 3 +- .../telemetry/test_tracing_and_correlation.py | 69 ++--- tests/server/utils/task_queue/test_config.py | 21 +- 57 files changed, 795 insertions(+), 913 deletions(-) diff --git a/cookbook/observability/demo_sft_users.py b/cookbook/observability/demo_sft_users.py index 42b64b51c..41d916173 100644 --- a/cookbook/observability/demo_sft_users.py +++ b/cookbook/observability/demo_sft_users.py @@ -211,7 +211,7 @@ def main() -> None: for r in results: print(f" user{r['user_idx']} token={r['token']:14s} session={r['session_id']} " f"model={r['model_id']} base={r['base_model']:20s} " - f"final_loss={r['final_loss']:.3f}" if r['final_loss'] else "") + f"final_loss={r['final_loss']:.3f}" if r['final_loss'] else '') # Drive resource gauges + flush everything time.sleep(3) diff --git a/src/twinkle/server/__main__.py b/src/twinkle/server/__main__.py index 403033142..28e86ad15 100644 --- a/src/twinkle/server/__main__.py +++ b/src/twinkle/server/__main__.py @@ -18,6 +18,5 @@ from twinkle.server.cli import main - if __name__ == '__main__': sys.exit(main()) diff --git a/src/twinkle/server/cli/app.py b/src/twinkle/server/cli/app.py index 3f7893196..32076781d 100644 --- a/src/twinkle/server/cli/app.py +++ b/src/twinkle/server/cli/app.py @@ -19,11 +19,10 @@ import asyncio import json import sys -from pathlib import Path -from typing import Optional - import typer import yaml +from pathlib import Path +from typing import Optional from twinkle.server.config import ServerConfig from twinkle.server.exceptions import ConfigMismatchError, ConfigParseError @@ -40,7 +39,6 @@ ) app.add_typer(clear_app, name='clear') - CONFIG_OPTION = typer.Option( ..., '--config', @@ -53,7 +51,7 @@ None, '--namespace', envvar='TWINKLE_RAY_NAMESPACE', - help="Ray namespace (overrides ray_namespace in the config).", + help='Ray namespace (overrides ray_namespace in the config).', ) @@ -84,7 +82,7 @@ def _signature_payload(config: ServerConfig) -> dict: @app.command('launch') def launch_cmd( config: Path = CONFIG_OPTION, - namespace: Optional[str] = NAMESPACE_OPTION, + namespace: str | None = NAMESPACE_OPTION, ) -> None: """Start the Twinkle Server from a YAML config file (R14.1, R15.1).""" cfg = _load_config(config) @@ -116,8 +114,8 @@ def check_config_cmd(config: Path = CONFIG_OPTION) -> None: @app.command('print-config') def print_config_cmd( - config: Path = CONFIG_OPTION, - fmt: str = typer.Option('yaml', '--format', envvar='TWINKLE_PRINT_FORMAT', help='yaml|json'), + config: Path = CONFIG_OPTION, + fmt: str = typer.Option('yaml', '--format', envvar='TWINKLE_PRINT_FORMAT', help='yaml|json'), ) -> None: """Emit the validated, normalized ``ServerConfig`` (R14.5).""" cfg = _load_config(config) diff --git a/src/twinkle/server/config/__init__.py b/src/twinkle/server/config/__init__.py index c5502e85e..8bb39eebe 100644 --- a/src/twinkle/server/config/__init__.py +++ b/src/twinkle/server/config/__init__.py @@ -1,14 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """Server configuration package — aggregate root and per-deployment specs.""" -from .application_spec import ( - ApplicationSpec, - HttpOptions, - ModelArgs, - ProcessorArgs, - SamplerArgs, - ServerArgs, -) +from .application_spec import ApplicationSpec, HttpOptions, ModelArgs, ProcessorArgs, SamplerArgs, ServerArgs from .server_config import ServerConfig __all__ = [ diff --git a/src/twinkle/server/config/application_spec.py b/src/twinkle/server/config/application_spec.py index 71c7fe874..171859299 100644 --- a/src/twinkle/server/config/application_spec.py +++ b/src/twinkle/server/config/application_spec.py @@ -11,13 +11,11 @@ """ from __future__ import annotations -from typing import Any, Literal - from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing import Any, Literal from twinkle.server.utils.task_queue.config import TaskQueueConfig - # ---------- shared helpers ------------------------------------------------- # @@ -102,7 +100,6 @@ class ProcessorArgs(_ArgsBase): 'processor': ProcessorArgs, } - # ---------- ApplicationSpec ------------------------------------------------ # diff --git a/src/twinkle/server/config/server_config.py b/src/twinkle/server/config/server_config.py index a285272e8..0183b4739 100644 --- a/src/twinkle/server/config/server_config.py +++ b/src/twinkle/server/config/server_config.py @@ -15,15 +15,13 @@ from __future__ import annotations from pathlib import Path -from typing import Any - from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing import Any from twinkle.server.exceptions import ConfigParseError from twinkle.server.state.backend.factory import PersistenceConfig from twinkle.server.telemetry.provider import TelemetryConfig from twinkle.server.utils.task_queue.config import TaskQueueConfig - from .application_spec import ApplicationSpec, HttpOptions @@ -43,7 +41,7 @@ class ServerConfig(BaseModel): # ---- loading ---------------------------------------------------------- # @classmethod - def from_yaml(cls, path: str | Path) -> 'ServerConfig': + def from_yaml(cls, path: str | Path) -> ServerConfig: """Load and validate a YAML file into a ``ServerConfig``. Raises: @@ -64,25 +62,19 @@ def from_yaml(cls, path: str | Path) -> 'ServerConfig': if raw is None: raw = {} if not isinstance(raw, dict): - raise ConfigParseError( - f'Top-level YAML in {p} must be a mapping, got {type(raw).__name__}', - ) + raise ConfigParseError(f'Top-level YAML in {p} must be a mapping, got {type(raw).__name__}', ) return cls.model_validate(raw) # ---- cross-field validation (R7) ------------------------------------- # @model_validator(mode='after') - def _validate_cross_field(self) -> 'ServerConfig': + def _validate_cross_field(self) -> ServerConfig: # R7.1: redis mode requires redis_url if self.persistence.mode == 'redis' and not self.persistence.redis_url: - raise ValueError( - "persistence.redis_url is required when persistence.mode == 'redis'", - ) + raise ValueError("persistence.redis_url is required when persistence.mode == 'redis'", ) # R7.2: file mode requires file_path if self.persistence.mode == 'file' and not self.persistence.file_path: - raise ValueError( - "persistence.file_path is required when persistence.mode == 'file'", - ) + raise ValueError("persistence.file_path is required when persistence.mode == 'file'", ) return self # ---- round-trip / serialization (R6.7) ------------------------------- # diff --git a/src/twinkle/server/exceptions.py b/src/twinkle/server/exceptions.py index df0b19f5a..b99c19bf1 100644 --- a/src/twinkle/server/exceptions.py +++ b/src/twinkle/server/exceptions.py @@ -14,7 +14,12 @@ class StateBackendError(TwinkleServerError): class ConfigMismatchError(TwinkleServerError): - """Configuration signature mismatch — config changes detected after restart, persisted data may be incompatible with current configuration.""" + """Configuration signature mismatch — config changed since last launch. + + Persisted data may be incompatible with the current configuration; the + operator must reconcile (revert the config change or clear persisted + state) before the server can start. + """ pass diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index 02899866e..f3f26ffba 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -14,9 +14,9 @@ from typing import Any import twinkle_client.types as types +from twinkle.server.state import get_server_state from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.state import get_server_state from twinkle.server.utils.validation import verify_request_token from twinkle.utils.logger import get_logger from .proxy import ServiceProxy @@ -91,6 +91,7 @@ def build_server_app(deploy_options: dict[str, Any], Returns: Configured Ray Serve deployment bound with options """ + def get_self() -> GatewayServer: return serve.get_replica_context().servable_object diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index bb35b774b..23d563865 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -69,11 +69,9 @@ def __init__( ray_namespace: Ray namespace (default: 'twinkle_cluster') """ if not isinstance(config, ServerConfig): - raise TypeError( - 'ServerLauncher requires a typed ServerConfig instance; ' - f'got {type(config).__name__}. Build one with ' - 'ServerConfig.from_yaml(path) or ServerConfig(...).' - ) + raise TypeError('ServerLauncher requires a typed ServerConfig instance; ' + f'got {type(config).__name__}. Build one with ' + 'ServerConfig.from_yaml(path) or ServerConfig(...).') self.config: ServerConfig = config self.ray_namespace = ray_namespace self._builders: dict[str, Callable] = {} @@ -212,7 +210,7 @@ def _start_serve(self) -> None: self._serve_started = True - def _deploy_application(self, app_spec: 'ApplicationSpec') -> None: + def _deploy_application(self, app_spec: ApplicationSpec) -> None: """Deploy a single application. Args: diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 510812846..b08967918 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -16,10 +16,10 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh from twinkle.server.exceptions import ConfigError -from twinkle.server.utils.lifecycle import AdapterManagerMixin +from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.telemetry.tracing import create_tracing_middleware +from twinkle.server.utils.lifecycle import AdapterManagerMixin from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger @@ -30,7 +30,6 @@ logger = get_logger() - _MODEL_BACKENDS: tuple[str, ...] = ('mock', 'transformers', 'megatron') @@ -247,8 +246,14 @@ async def verify_token(request: Request, call_next): )( ModelManagementWithIngress) return DeploymentClass.options(**deploy_options).bind( - model_id, nproc_per_node, device_group, device_mesh, backend, - adapter_config, queue_config, **kwargs, + model_id, + nproc_per_node, + device_group, + device_mesh, + backend, + adapter_config, + queue_config, + **kwargs, ) diff --git a/src/twinkle/server/model/backends/mock_model.py b/src/twinkle/server/model/backends/mock_model.py index d4b4e425a..358dec8b1 100644 --- a/src/twinkle/server/model/backends/mock_model.py +++ b/src/twinkle/server/model/backends/mock_model.py @@ -17,9 +17,8 @@ """ from __future__ import annotations -from typing import Any - import numpy as np +from typing import Any def _seed_for(model_id: str, adapter_name: str | None, seed: int, *extra: Any) -> int: diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index ec418fe90..d2a8c9b3f 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -21,10 +21,10 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_logger -from twinkle.server.utils.lifecycle import ProcessorManagerMixin +from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.telemetry.tracing import create_tracing_middleware +from twinkle.server.utils.lifecycle import ProcessorManagerMixin from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.utils.validation import verify_request_token from .twinkle_handlers import _register_processor_routes @@ -119,6 +119,7 @@ def build_processor_app(ncpu_proc_per_node: int, Returns: Ray Serve deployment bound with configuration. """ + # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index 84a55c27e..351cc186d 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -15,9 +15,9 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh from twinkle.server.exceptions import ConfigError +from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger @@ -27,7 +27,6 @@ logger = get_logger() - _SAMPLER_TYPES: tuple[str, ...] = ('mock', 'vllm', 'torch') @@ -38,14 +37,8 @@ def _validate_sampler_type(sampler_type: Any) -> str: when the value is missing, empty, non-string, or not exactly one of the permitted values. No imports or side effects. """ - if ( - not isinstance(sampler_type, str) - or sampler_type == '' - or sampler_type not in _SAMPLER_TYPES - ): - raise ConfigError( - field='sampler_type', value=sampler_type, allowed=list(_SAMPLER_TYPES) - ) + if (not isinstance(sampler_type, str) or sampler_type == '' or sampler_type not in _SAMPLER_TYPES): + raise ConfigError(field='sampler_type', value=sampler_type, allowed=list(_SAMPLER_TYPES)) return sampler_type @@ -119,7 +112,10 @@ def __init__(self, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, - **{k: v for k, v in kwargs.items() if k not in ('engine_args',)}, + **{ + k: v + for k, v in kwargs.items() if k not in ('engine_args', ) + }, ) self.sampler = _dispatch_sampler_backend(sampler_type, sampler_kwargs) @@ -174,6 +170,7 @@ def build_sampler_app(model_id: str, """ # Fail fast at builder time on bad sampler_type values. sampler_type = _validate_sampler_type(sampler_type) + # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). diff --git a/src/twinkle/server/sampler/backends/mock_sampler.py b/src/twinkle/server/sampler/backends/mock_sampler.py index ef0c6c9c9..4565929c4 100644 --- a/src/twinkle/server/sampler/backends/mock_sampler.py +++ b/src/twinkle/server/sampler/backends/mock_sampler.py @@ -13,12 +13,11 @@ """ from __future__ import annotations -from typing import Any, List, Optional - import numpy as np +from typing import Any, List, Optional # These data containers don't pull torch / vllm. -from twinkle.data_format import SampleResponse, SampledSequence, SamplingParams +from twinkle.data_format import SampledSequence, SampleResponse, SamplingParams class MockSampler: @@ -42,17 +41,15 @@ def __init__(self, model_id: str, *, seed: int = 0, vocab_size: int = 32, **kwar def sample( self, inputs: Any, - sampling_params: Optional[SamplingParams] = None, + sampling_params: SamplingParams | None = None, adapter_name: str = '', *, num_samples: int = 1, - ) -> List[SampleResponse]: + ) -> list[SampleResponse]: max_tokens = self._resolve_max_tokens(sampling_params) if max_tokens is None or max_tokens < 1: - raise ValueError( - f'max_tokens must be >= 1, got {max_tokens!r} ' - '(set sampling_params.max_tokens to a positive integer)' - ) + raise ValueError(f'max_tokens must be >= 1, got {max_tokens!r} ' + '(set sampling_params.max_tokens to a positive integer)') normalized = self._normalize_inputs(inputs) responses: list[SampleResponse] = [] @@ -60,21 +57,20 @@ def sample( sequences: list[SampledSequence] = [] for sample_idx in range(num_samples): seed = ( - abs(hash((str(self.model_id), str(adapter_name), int(self._seed), int(prompt_idx), int(sample_idx)))) - & 0xFFFFFFFF - ) + abs( + hash( + (str(self.model_id), str(adapter_name), int(self._seed), int(prompt_idx), int(sample_idx)))) + & 0xFFFFFFFF) rng = np.random.default_rng(seed) tokens = [int(t) for t in rng.integers(low=0, high=max(1, self._vocab_size), size=max_tokens)] logprobs_per_token = rng.uniform(-2.0, 0.0, size=max_tokens).astype(float).tolist() # One logprob entry per emitted token (R2.4) — list of (id, logprob). logprobs = [[(tok, float(lp))] for tok, lp in zip(tokens, logprobs_per_token)] - sequences.append( - SampledSequence( - stop_reason='length', - tokens=tokens, - logprobs=logprobs, - ) - ) + sequences.append(SampledSequence( + stop_reason='length', + tokens=tokens, + logprobs=logprobs, + )) responses.append(SampleResponse(sequences=sequences)) return responses @@ -100,7 +96,7 @@ def _normalize_inputs(inputs: Any) -> list[Any]: return [inputs] @staticmethod - def _resolve_max_tokens(params: Optional[SamplingParams]) -> Optional[int]: + def _resolve_max_tokens(params: SamplingParams | None) -> int | None: if params is None: return None return getattr(params, 'max_tokens', None) diff --git a/src/twinkle/server/state/__init__.py b/src/twinkle/server/state/__init__.py index 08f07dce1..c6b85914c 100644 --- a/src/twinkle/server/state/__init__.py +++ b/src/twinkle/server/state/__init__.py @@ -2,22 +2,13 @@ from .backend import PersistenceConfig, create_backend from .base import BaseManager from .config_manager import ConfigManager -from .config_signature import ( - SignatureMismatchPolicy, - compute_signature, - validate_config_signature, -) +from .config_signature import SignatureMismatchPolicy, compute_signature, validate_config_signature from .future_manager import FutureManager from .model_manager import ModelManager from .models import FutureRecord, ModelRecord, SamplingSessionRecord, SessionRecord -from .sampling_manager import SamplingSessionManager from .replica_registry import ReplicaRegistry -from .server_state import ( - ServerState, - ServerStateProxy, - get_server_state, - reset_server_state_cache, -) +from .sampling_manager import SamplingSessionManager +from .server_state import ServerState, ServerStateProxy, get_server_state, reset_server_state_cache from .session_manager import SessionManager __all__ = [ diff --git a/src/twinkle/server/state/backend/base.py b/src/twinkle/server/state/backend/base.py index 2a247ed22..863f6e99e 100644 --- a/src/twinkle/server/state/backend/base.py +++ b/src/twinkle/server/state/backend/base.py @@ -7,7 +7,8 @@ class StateBackend(ABC): """Unified interface for state storage backends. - All state management operations go through this interface, supporting multiple backend implementations (memory, file, Redis). + All state management operations go through this interface, supporting + multiple backend implementations (memory, file, Redis). """ @abstractmethod diff --git a/src/twinkle/server/state/backend/factory.py b/src/twinkle/server/state/backend/factory.py index 702fcb4bc..aac2756a3 100644 --- a/src/twinkle/server/state/backend/factory.py +++ b/src/twinkle/server/state/backend/factory.py @@ -3,16 +3,14 @@ import logging import os -from typing import Literal - from pydantic import BaseModel, ConfigDict +from typing import Literal from .base import StateBackend from .memory_backend import MemoryBackend logger = logging.getLogger(__name__) - # Env var keys propagated by the launcher so that any Ray worker can rebuild # the same PersistenceConfig regardless of which deployment initializes the # ServerState actor first. diff --git a/src/twinkle/server/state/backend/file_backend.py b/src/twinkle/server/state/backend/file_backend.py index 5e306832e..747c1dc61 100644 --- a/src/twinkle/server/state/backend/file_backend.py +++ b/src/twinkle/server/state/backend/file_backend.py @@ -15,9 +15,11 @@ class FileBackend(StateBackend): """Local JSON file-based persistent state backend implementation. - Storage format is a single JSON file: ``{key: {"value": ..., "expire_at": float|null}}``. - File I/O is wrapped with ``asyncio.to_thread`` to avoid blocking the event loop. - Writes use temp file + ``os.replace`` for atomic replacement, protected by ``fcntl.flock`` against multi-process concurrent writes. + Storage format is a single JSON file: + ``{key: {"value": ..., "expire_at": float|null}}``. + File I/O is wrapped with ``asyncio.to_thread`` to avoid blocking the + event loop. Writes use temp file + ``os.replace`` for atomic replacement, + protected by ``fcntl.flock`` against multi-process concurrent writes. """ def __init__(self, file_path: str) -> None: @@ -36,7 +38,7 @@ def _init_file(self) -> None: def _load_sync(self) -> dict[str, dict[str, Any]]: """Synchronously read JSON file, return complete data dict.""" try: - with open(self._file_path, 'r', encoding='utf-8') as f: + with open(self._file_path, encoding='utf-8') as f: data = json.load(f) except (json.JSONDecodeError, FileNotFoundError): data = {} @@ -46,10 +48,7 @@ def _save_sync(self, data: dict[str, dict[str, Any]]) -> None: """Synchronous write: clean expired keys -> write temp file -> flock -> os.replace.""" # Clean expired keys before writing now = time.time() - data = { - k: v for k, v in data.items() - if v.get('expire_at') is None or v['expire_at'] > now - } + data = {k: v for k, v in data.items() if v.get('expire_at') is None or v['expire_at'] > now} dir_path = os.path.dirname(self._file_path) or '.' fd = tempfile.NamedTemporaryFile( @@ -66,7 +65,7 @@ def _save_sync(self, data: dict[str, dict[str, Any]]) -> None: fd.close() # Apply exclusive lock to temp file then atomic replace - with open(fd.name, 'r') as lock_f: + with open(fd.name) as lock_f: fcntl.flock(lock_f.fileno(), fcntl.LOCK_EX) os.replace(fd.name, self._file_path) fcntl.flock(lock_f.fileno(), fcntl.LOCK_UN) diff --git a/src/twinkle/server/state/backend/redis_backend.py b/src/twinkle/server/state/backend/redis_backend.py index a571f8421..b9b639c23 100644 --- a/src/twinkle/server/state/backend/redis_backend.py +++ b/src/twinkle/server/state/backend/redis_backend.py @@ -20,11 +20,9 @@ class RedisBackend(StateBackend): TTL is managed by Redis native EXPIRE mechanism. """ - def __init__(self, redis_url: str, key_prefix: str = "") -> None: + def __init__(self, redis_url: str, key_prefix: str = '') -> None: if not _REDIS_AVAILABLE: - raise ImportError( - "redis package required. Install with: pip install redis" - ) + raise ImportError('redis package required. Install with: pip install redis') self._client = aioredis.from_url(redis_url, decode_responses=True) self._prefix = key_prefix diff --git a/src/twinkle/server/state/base.py b/src/twinkle/server/state/base.py index 5c4723ed9..b7c6a85ab 100644 --- a/src/twinkle/server/state/base.py +++ b/src/twinkle/server/state/base.py @@ -5,9 +5,8 @@ import time from abc import ABC, abstractmethod from datetime import datetime, timezone -from typing import Generic, TypeVar - from pydantic import BaseModel +from typing import Generic, TypeVar from twinkle.server.state.backend.base import StateBackend diff --git a/src/twinkle/server/state/config_signature.py b/src/twinkle/server/state/config_signature.py index 83c0bf565..9fa3a6296 100644 --- a/src/twinkle/server/state/config_signature.py +++ b/src/twinkle/server/state/config_signature.py @@ -12,14 +12,14 @@ logger = logging.getLogger(__name__) -_SIGNATURE_KEY = "_meta::config_signature" +_SIGNATURE_KEY = '_meta::config_signature' class SignatureMismatchPolicy(str, Enum): """Policy for handling config signature mismatches.""" - WARN = "warn" # Log warning and continue - CLEAR = "clear" # Clear all backend data and continue - ABORT = "abort" # Raise error, refuse to start + WARN = 'warn' # Log warning and continue + CLEAR = 'clear' # Clear all backend data and continue + ABORT = 'abort' # Raise error, refuse to start def compute_signature(config: dict[str, Any]) -> str: @@ -62,12 +62,12 @@ async def validate_config_signature( if stored_sig is None: # First run — store signature - logger.info("No previous config signature found. Storing current signature.") + logger.info('No previous config signature found. Storing current signature.') await backend.set(_SIGNATURE_KEY, current_sig) return True if stored_sig == current_sig: - logger.debug("Config signature matches stored value.") + logger.debug('Config signature matches stored value.') return True # Mismatch detected @@ -82,19 +82,18 @@ async def validate_config_signature( elif policy == SignatureMismatchPolicy.CLEAR: # Clear all data except meta keys, store new signature - logger.warning("Clearing all backend data due to config signature mismatch.") - all_keys = await backend.keys("*") + logger.warning('Clearing all backend data due to config signature mismatch.') + all_keys = await backend.keys('*') for key in all_keys: - if not key.startswith("_meta::"): + if not key.startswith('_meta::'): await backend.delete(key) await backend.set(_SIGNATURE_KEY, current_sig) return False elif policy == SignatureMismatchPolicy.ABORT: - raise ConfigMismatchError( - f"Configuration signature mismatch. " - f"Stored: {stored_sig[:12]}..., Current: {current_sig[:12]}... " - f"Use policy='warn' or 'clear' to allow startup with changed config.") + raise ConfigMismatchError(f"Configuration signature mismatch. " + f"Stored: {stored_sig[:12]}..., Current: {current_sig[:12]}... " + f"Use policy='warn' or 'clear' to allow startup with changed config.") return False @@ -145,11 +144,9 @@ async def validate_against_backend(persistence_config: Any, current_config: dict stored_payload = await backend.get('_meta::config_payload') diff = _format_diff(stored_payload if isinstance(stored_payload, dict) else None, current_config) - raise ConfigMismatchError( - 'Persistence configuration drifted since the last launch. ' - f'Stored signature: {stored_sig[:12]}..., current signature: {current_sig[:12]}...\n' - f'Differences:\n{diff}\n' - 'Remediation: either revert the persistence section to match the stored ' - 'value, or clear the persisted state with ' - '`python -m twinkle.server clear persistence --config ` and relaunch.' - ) + raise ConfigMismatchError('Persistence configuration drifted since the last launch. ' + f'Stored signature: {stored_sig[:12]}..., current signature: {current_sig[:12]}...\n' + f'Differences:\n{diff}\n' + 'Remediation: either revert the persistence section to match the stored ' + 'value, or clear the persisted state with ' + '`python -m twinkle.server clear persistence --config ` and relaunch.') diff --git a/src/twinkle/server/state/future_manager.py b/src/twinkle/server/state/future_manager.py index 331fba5cd..6b9cf016f 100644 --- a/src/twinkle/server/state/future_manager.py +++ b/src/twinkle/server/state/future_manager.py @@ -16,7 +16,7 @@ class FutureManager(BaseManager[FutureRecord]): """ def __init__(self, backend: StateBackend, expiration_timeout: float) -> None: - super().__init__(backend, "future::", FutureRecord, expiration_timeout) + super().__init__(backend, 'future::', FutureRecord, expiration_timeout) # ----- Future-specific operations ----- diff --git a/src/twinkle/server/state/model_manager.py b/src/twinkle/server/state/model_manager.py index df211ea99..091f4a74d 100644 --- a/src/twinkle/server/state/model_manager.py +++ b/src/twinkle/server/state/model_manager.py @@ -73,9 +73,7 @@ async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str] if not candidate_ids: return [] replicas = await self._replicas.get_all() - loaded_per_replica = await self._loaded_per_replica( - replica_filter=set(candidate_ids) | replicas.keys() - ) + loaded_per_replica = await self._loaded_per_replica(replica_filter=set(candidate_ids) | replicas.keys()) available: list[str] = [] for rid in candidate_ids: max_loras = replicas.get(rid) @@ -99,9 +97,7 @@ async def add(self, model_id: str, record: ModelRecord) -> None: token = record.token current = await self._count_models_for_token(token) if current >= self._per_token_model_limit: - raise RuntimeError( - f'Model limit exceeded: {current}/{self._per_token_model_limit} models' - ) + raise RuntimeError(f'Model limit exceeded: {current}/{self._per_token_model_limit} models') await super().add(model_id, record) async def remove(self, model_id: str) -> bool: @@ -142,9 +138,7 @@ async def _models_for_replica(self, replica_id: str) -> list[str]: all_records = await self.get_all() return [mid for mid, r in all_records.items() if r.replica_id == replica_id] - async def _loaded_per_replica( - self, replica_filter: set[str] | None = None - ) -> dict[str, int]: + async def _loaded_per_replica(self, replica_filter: set[str] | None = None) -> dict[str, int]: all_records = await self.get_all() counts: dict[str, int] = {} for record in all_records.values(): diff --git a/src/twinkle/server/state/replica_registry.py b/src/twinkle/server/state/replica_registry.py index 69352ce8f..9eb8e6cbf 100644 --- a/src/twinkle/server/state/replica_registry.py +++ b/src/twinkle/server/state/replica_registry.py @@ -27,7 +27,7 @@ def _make_key(replica_id: str) -> str: def _replica_id_from_key(key: str) -> str | None: if not key.startswith(REPLICA_PREFIX) or not key.endswith(_MAX_LORAS_SUFFIX): return None - return key[len(REPLICA_PREFIX): -len(_MAX_LORAS_SUFFIX)] + return key[len(REPLICA_PREFIX):-len(_MAX_LORAS_SUFFIX)] class ReplicaRegistry: diff --git a/src/twinkle/server/state/sampling_manager.py b/src/twinkle/server/state/sampling_manager.py index 3d3d57ca1..7dd535a5e 100644 --- a/src/twinkle/server/state/sampling_manager.py +++ b/src/twinkle/server/state/sampling_manager.py @@ -14,7 +14,7 @@ class SamplingSessionManager(BaseManager[SamplingSessionRecord]): """ def __init__(self, backend: StateBackend, expiration_timeout: float) -> None: - super().__init__(backend, "sampling::", SamplingSessionRecord, expiration_timeout) + super().__init__(backend, 'sampling::', SamplingSessionRecord, expiration_timeout) # ----- Cleanup ----- diff --git a/src/twinkle/server/state/server_state.py b/src/twinkle/server/state/server_state.py index 49a332cd7..42b1333e3 100644 --- a/src/twinkle/server/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -10,14 +10,8 @@ from twinkle.server.exceptions import ConfigMismatchError from twinkle.server.telemetry import MetricsRegistry -from twinkle.server.telemetry.correlation import ( - BASE_MODEL, - MODEL_ID, - REPLICA_ID, - SAMPLING_SESSION_ID, - SESSION_ID, - TOKEN_ID, -) +from twinkle.server.telemetry.correlation import (BASE_MODEL, MODEL_ID, REPLICA_ID, SAMPLING_SESSION_ID, SESSION_ID, + TOKEN_ID) from twinkle.server.telemetry.tracing import traced_operation from twinkle.utils.logger import get_logger from .backend import StateBackend @@ -99,8 +93,8 @@ async def create_session(self, payload: dict[str, Any]) -> str: """ session_id = payload.get('session_id') or f'session_{uuid.uuid4().hex}' with traced_operation( - 'server_state.create_session', - attrs={SESSION_ID: session_id}, + 'server_state.create_session', + attrs={SESSION_ID: session_id}, ): record = SessionRecord( tags=list(payload.get('tags') or []), @@ -153,14 +147,14 @@ async def register_model(self, _model_id = re.sub(r'[^\w\-]', '_', _model_id) with traced_operation( - 'server_state.register_model', - attrs={ - MODEL_ID: _model_id, - BASE_MODEL: payload.get('base_model'), - REPLICA_ID: replica_id, - TOKEN_ID: token, - SESSION_ID: session_id or payload.get('session_id'), - }, + 'server_state.register_model', + attrs={ + MODEL_ID: _model_id, + BASE_MODEL: payload.get('base_model'), + REPLICA_ID: replica_id, + TOKEN_ID: token, + SESSION_ID: session_id or payload.get('session_id'), + }, ): record = ModelRecord( session_id=session_id or payload.get('session_id'), @@ -197,8 +191,8 @@ async def register_replica(self, replica_id: str, max_loras: int) -> None: max_loras: Maximum number of LoRA adapters the replica can hold. """ with traced_operation( - 'server_state.register_replica', - attrs={REPLICA_ID: replica_id}, + 'server_state.register_replica', + attrs={REPLICA_ID: replica_id}, ): await self._model_mgr.register_replica(replica_id, max_loras) @@ -236,12 +230,12 @@ async def create_sampling_session(self, payload: dict[str, Any], sampling_sessio _sampling_session_id: str = sampling_session_id or payload.get( 'sampling_session_id') or f'sampling_{uuid.uuid4().hex}' with traced_operation( - 'server_state.create_sampling_session', - attrs={ - SAMPLING_SESSION_ID: _sampling_session_id, - SESSION_ID: payload.get('session_id'), - BASE_MODEL: payload.get('base_model'), - }, + 'server_state.create_sampling_session', + attrs={ + SAMPLING_SESSION_ID: _sampling_session_id, + SESSION_ID: payload.get('session_id'), + BASE_MODEL: payload.get('base_model'), + }, ): record = SamplingSessionRecord( session_id=payload.get('session_id'), @@ -350,7 +344,8 @@ async def cleanup_expired_resources(self) -> dict[str, int]: # Perform actual cleanup in dependency order sessions_removed = await self._session_mgr.cleanup_expired(cutoff_time) models_removed = await self._model_mgr.cleanup_expired(cutoff_time, expired_session_ids=expired_session_ids) - samplings_removed = await self._sampling_mgr.cleanup_expired(cutoff_time, expired_session_ids=expired_session_ids) + samplings_removed = await self._sampling_mgr.cleanup_expired( + cutoff_time, expired_session_ids=expired_session_ids) futures_removed = await self._future_mgr.cleanup_expired(cutoff_time) return { @@ -429,10 +424,7 @@ async def _rebuild_indexes(self) -> None: """ # Validate config signature if provided if self._signature_config is not None: - from twinkle.server.state.config_signature import ( - SignatureMismatchPolicy, - validate_config_signature, - ) + from twinkle.server.state.config_signature import SignatureMismatchPolicy, validate_config_signature policy = SignatureMismatchPolicy(self._signature_policy) await validate_config_signature(self._backend, self._signature_config, policy) @@ -490,17 +482,15 @@ async def get_cleanup_stats(self) -> dict[str, Any]: ServerStateProxy = ServerState # type: ignore[assignment] - _PROCESS_STATE_CACHE: dict[str, ServerState] = {} -def get_server_state( - actor_name: str = 'twinkle_server_state', - backend: StateBackend | None = None, - persistence_config: PersistenceConfig | None = None, - signature_config: dict[str, Any] | None = None, - signature_policy: str = 'warn', - **kwargs) -> ServerState: +def get_server_state(actor_name: str = 'twinkle_server_state', + backend: StateBackend | None = None, + persistence_config: PersistenceConfig | None = None, + signature_config: dict[str, Any] | None = None, + signature_policy: str = 'warn', + **kwargs) -> ServerState: """Return a process-local :class:`ServerState` bound directly to the backend. No detached Ray Actor is created — every deployment / worker accesses the diff --git a/src/twinkle/server/state/session_manager.py b/src/twinkle/server/state/session_manager.py index 67efc2ddc..630c33d2f 100644 --- a/src/twinkle/server/state/session_manager.py +++ b/src/twinkle/server/state/session_manager.py @@ -16,7 +16,7 @@ class SessionManager(BaseManager[SessionRecord]): """ def __init__(self, backend: StateBackend, expiration_timeout: float) -> None: - super().__init__(backend, "session::", SessionRecord, expiration_timeout) + super().__init__(backend, 'session::', SessionRecord, expiration_timeout) # ----- Session-specific operations ----- diff --git a/src/twinkle/server/telemetry/__init__.py b/src/twinkle/server/telemetry/__init__.py index 8da61164f..7976dfd3e 100644 --- a/src/twinkle/server/telemetry/__init__.py +++ b/src/twinkle/server/telemetry/__init__.py @@ -1,27 +1,17 @@ from .metrics import MetricsRegistry -from .provider import ( - TelemetryConfig, - get_meter, - init_telemetry, - shutdown_telemetry, -) -from .tracing import ( - get_tracer, - inject_context, - extract_context, - get_current_span, -) +from .provider import TelemetryConfig, get_meter, init_telemetry, shutdown_telemetry +from .tracing import extract_context, get_current_span, get_tracer, inject_context from .worker_init import ensure_telemetry_initialized __all__ = [ - "MetricsRegistry", - "TelemetryConfig", - "get_meter", - "init_telemetry", - "shutdown_telemetry", - "get_tracer", - "inject_context", - "extract_context", - "get_current_span", - "ensure_telemetry_initialized", + 'MetricsRegistry', + 'TelemetryConfig', + 'get_meter', + 'init_telemetry', + 'shutdown_telemetry', + 'get_tracer', + 'inject_context', + 'extract_context', + 'get_current_span', + 'ensure_telemetry_initialized', ] diff --git a/src/twinkle/server/telemetry/context_carrier.py b/src/twinkle/server/telemetry/context_carrier.py index 87dc6553b..13a0d5b78 100644 --- a/src/twinkle/server/telemetry/context_carrier.py +++ b/src/twinkle/server/telemetry/context_carrier.py @@ -20,7 +20,8 @@ try: from opentelemetry import context as _otel_context # type: ignore - from opentelemetry.propagate import extract as _otel_extract, inject as _otel_inject # type: ignore + from opentelemetry.propagate import extract as _otel_extract # type: ignore + from opentelemetry.propagate import inject as _otel_inject _OTEL_AVAILABLE = True except Exception: diff --git a/src/twinkle/server/telemetry/correlation.py b/src/twinkle/server/telemetry/correlation.py index cdbf23975..bddfe745c 100644 --- a/src/twinkle/server/telemetry/correlation.py +++ b/src/twinkle/server/telemetry/correlation.py @@ -19,7 +19,6 @@ SAMPLING_SESSION_ID = f'{PREFIX}sampling_session_id' BASE_MODEL = f'{PREFIX}base_model' - CORRELATION_KEYS: tuple[str, ...] = ( SESSION_ID, MODEL_ID, diff --git a/src/twinkle/server/telemetry/metrics.py b/src/twinkle/server/telemetry/metrics.py index d9979b8f1..fbc3de9a0 100644 --- a/src/twinkle/server/telemetry/metrics.py +++ b/src/twinkle/server/telemetry/metrics.py @@ -14,63 +14,63 @@ class MetricsRegistry: _instance: MetricsRegistry | None = None def __init__(self) -> None: - meter = get_meter("twinkle-server") + meter = get_meter('twinkle-server') # === HTTP Requests === self.requests_total = meter.create_counter( - "twinkle.http.requests.total", - description="Total HTTP requests received", + 'twinkle.http.requests.total', + description='Total HTTP requests received', ) self.request_duration_seconds = meter.create_histogram( - "twinkle.http.request.duration_seconds", - description="HTTP request duration in seconds", - unit="s", + 'twinkle.http.request.duration_seconds', + description='HTTP request duration in seconds', + unit='s', ) # === Task Queue === self.queue_depth = meter.create_up_down_counter( - "twinkle.queue.depth", - description="Current task queue depth", + 'twinkle.queue.depth', + description='Current task queue depth', ) self.task_execution_seconds = meter.create_histogram( - "twinkle.task.execution_seconds", - description="Task execution duration in seconds", - unit="s", + 'twinkle.task.execution_seconds', + description='Task execution duration in seconds', + unit='s', ) self.task_wait_seconds = meter.create_histogram( - "twinkle.task.wait_seconds", - description="Task wait time in queue before execution", - unit="s", + 'twinkle.task.wait_seconds', + description='Task wait time in queue before execution', + unit='s', ) self.rate_limit_rejections = meter.create_counter( - "twinkle.rate_limit.rejections.total", - description="Total requests rejected by rate limiter", + 'twinkle.rate_limit.rejections.total', + description='Total requests rejected by rate limiter', ) self.tasks_total = meter.create_counter( - "twinkle.tasks.total", - description="Total task completions, partitioned by status", + 'twinkle.tasks.total', + description='Total task completions, partitioned by status', ) self.rate_limiter_active_tokens = meter.create_up_down_counter( - "twinkle.rate_limiter.active_tokens", - description="Number of tokens currently tracked by the rate limiter", + 'twinkle.rate_limiter.active_tokens', + description='Number of tokens currently tracked by the rate limiter', ) # === Resources === self.active_sessions = meter.create_up_down_counter( - "twinkle.sessions.active", - description="Number of active client sessions", + 'twinkle.sessions.active', + description='Number of active client sessions', ) self.active_models = meter.create_up_down_counter( - "twinkle.models.active", - description="Number of registered models", + 'twinkle.models.active', + description='Number of registered models', ) self.active_sampling_sessions = meter.create_up_down_counter( - "twinkle.sampling_sessions.active", - description="Number of active sampling sessions", + 'twinkle.sampling_sessions.active', + description='Number of active sampling sessions', ) self.active_futures = meter.create_up_down_counter( - "twinkle.futures.active", - description="Number of pending futures/tasks", + 'twinkle.futures.active', + description='Number of pending futures/tasks', ) @classmethod diff --git a/src/twinkle/server/telemetry/provider.py b/src/twinkle/server/telemetry/provider.py index bac4a6b70..a8ec970f5 100644 --- a/src/twinkle/server/telemetry/provider.py +++ b/src/twinkle/server/telemetry/provider.py @@ -13,9 +13,8 @@ from __future__ import annotations import logging -from typing import Any, Optional - from pydantic import BaseModel, ConfigDict, Field +from typing import Any, Optional logger = logging.getLogger(__name__) @@ -27,39 +26,24 @@ from opentelemetry import metrics, trace from opentelemetry._logs import set_logger_provider from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler - from opentelemetry.sdk._logs.export import ( - BatchLogRecordProcessor, - ConsoleLogExporter, - ) + from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, ConsoleLogExporter from opentelemetry.sdk.metrics import MeterProvider - from opentelemetry.sdk.metrics.export import ( - ConsoleMetricExporter, - PeriodicExportingMetricReader, - ) + from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider - from opentelemetry.sdk.trace.export import ( - BatchSpanProcessor, - ConsoleSpanExporter, - ) + from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter _OTEL_AVAILABLE = True - _OTEL_IMPORT_ERROR: Optional[BaseException] = None + _OTEL_IMPORT_ERROR: BaseException | None = None except Exception as exc: # pragma: no cover - defensive fallback _OTEL_AVAILABLE = False _OTEL_IMPORT_ERROR = exc # OTLP exporters are a separate optional dependency from the SDK itself. try: - from opentelemetry.exporter.otlp.proto.grpc._log_exporter import ( - OTLPLogExporter, - ) - from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( - OTLPMetricExporter, - ) - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( - OTLPSpanExporter, - ) + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter _OTLP_AVAILABLE = True except Exception: # pragma: no cover - defensive fallback @@ -73,14 +57,13 @@ except Exception: # pragma: no cover - defensive fallback _LOGGING_INSTRUMENTOR_AVAILABLE = False - # --------------------------------------------------------------------------- # Module-level state for shutdown. # --------------------------------------------------------------------------- -_tracer_provider: Optional[Any] = None -_meter_provider: Optional[Any] = None -_logger_provider: Optional[Any] = None -_logging_handler: Optional[Any] = None +_tracer_provider: Any | None = None +_meter_provider: Any | None = None +_logger_provider: Any | None = None +_logging_handler: Any | None = None _initialized: bool = False @@ -113,8 +96,8 @@ class TelemetryConfig(BaseModel): model_config = ConfigDict(extra='forbid') enabled: bool = False - service_name: str = "twinkle-server" - otlp_endpoint: str = "http://localhost:4317" + service_name: str = 'twinkle-server' + otlp_endpoint: str = 'http://localhost:4317' debug: bool = False # True: Console Exporter; False: OTLP Exporter export_interval_ms: int = 30000 resource_attributes: dict = Field(default_factory=dict) @@ -133,26 +116,24 @@ def init_telemetry(config: TelemetryConfig) -> None: if not _OTEL_AVAILABLE: logger.warning( - "OpenTelemetry SDK not available, skipping telemetry init: %s", + 'OpenTelemetry SDK not available, skipping telemetry init: %s', _OTEL_IMPORT_ERROR, ) return if _initialized: - logger.debug("Telemetry already initialized; skipping re-init.") + logger.debug('Telemetry already initialized; skipping re-init.') return # ---- Resource ------------------------------------------------------- - resource_attrs: dict = {"service.name": config.service_name} + resource_attrs: dict = {'service.name': config.service_name} if config.resource_attributes: resource_attrs.update(config.resource_attributes) resource = Resource.create(resource_attrs) use_console = config.debug or not _OTLP_AVAILABLE if config.debug is False and not _OTLP_AVAILABLE: - logger.warning( - "OTLP exporters not available; falling back to console exporters." - ) + logger.warning('OTLP exporters not available; falling back to console exporters.') # When using console exporters, route their output through the Python # logging system so that Ray Serve workers actually surface the data. @@ -193,9 +174,7 @@ def init_telemetry(config: TelemetryConfig) -> None: log_exporter = OTLPLogExporter(endpoint=config.otlp_endpoint) logger_provider = LoggerProvider(resource=resource) - logger_provider.add_log_record_processor( - BatchLogRecordProcessor(log_exporter) - ) + logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter)) set_logger_provider(logger_provider) _logger_provider = logger_provider @@ -204,11 +183,9 @@ def init_telemetry(config: TelemetryConfig) -> None: try: LoggingInstrumentor().instrument(set_logging_format=True) except Exception as exc: # pragma: no cover - defensive - logger.warning("LoggingInstrumentor failed to instrument: %s", exc) + logger.warning('LoggingInstrumentor failed to instrument: %s', exc) - handler = LoggingHandler( - level=logging.NOTSET, logger_provider=logger_provider - ) + handler = LoggingHandler(level=logging.NOTSET, logger_provider=logger_provider) # Attach to BOTH the root logger and the ``twinkle`` namespace logger. # ``twinkle.utils.logger`` configures the ``twinkle`` logger with # ``propagate=False`` and its own StreamHandler, so log records emitted @@ -220,7 +197,7 @@ def init_telemetry(config: TelemetryConfig) -> None: _initialized = True logger.info( - "Telemetry initialized (service=%s, debug=%s, otlp_endpoint=%s)", + 'Telemetry initialized (service=%s, debug=%s, otlp_endpoint=%s)', config.service_name, config.debug, config.otlp_endpoint, @@ -237,28 +214,28 @@ def shutdown_telemetry() -> None: try: logging.getLogger(logger_name).removeHandler(_logging_handler) except Exception as exc: # pragma: no cover - defensive - logger.warning("Failed to detach logging handler from %r: %s", logger_name, exc) + logger.warning('Failed to detach logging handler from %r: %s', logger_name, exc) _logging_handler = None if _tracer_provider is not None: try: _tracer_provider.shutdown() except Exception as exc: # pragma: no cover - defensive - logger.warning("TracerProvider shutdown failed: %s", exc) + logger.warning('TracerProvider shutdown failed: %s', exc) _tracer_provider = None if _meter_provider is not None: try: _meter_provider.shutdown() except Exception as exc: # pragma: no cover - defensive - logger.warning("MeterProvider shutdown failed: %s", exc) + logger.warning('MeterProvider shutdown failed: %s', exc) _meter_provider = None if _logger_provider is not None: try: _logger_provider.shutdown() except Exception as exc: # pragma: no cover - defensive - logger.warning("LoggerProvider shutdown failed: %s", exc) + logger.warning('LoggerProvider shutdown failed: %s', exc) _logger_provider = None _initialized = False @@ -266,21 +243,31 @@ def shutdown_telemetry() -> None: class _NoopInstrument: """No-op instrument for when OTEL SDK is not available.""" - def add(self, *args, **kwargs): pass - def record(self, *args, **kwargs): pass + + def add(self, *args, **kwargs): + pass + + def record(self, *args, **kwargs): + pass class _NoopMeter: """No-op meter for when OTEL SDK is not available.""" - def create_counter(self, *args, **kwargs): return _NoopInstrument() - def create_up_down_counter(self, *args, **kwargs): return _NoopInstrument() - def create_histogram(self, *args, **kwargs): return _NoopInstrument() + + def create_counter(self, *args, **kwargs): + return _NoopInstrument() + + def create_up_down_counter(self, *args, **kwargs): + return _NoopInstrument() + + def create_histogram(self, *args, **kwargs): + return _NoopInstrument() _noop_meter = _NoopMeter() -def get_meter(name: str = "twinkle-server"): +def get_meter(name: str = 'twinkle-server'): """Return an OTEL meter. Returns NoOp meter if OTEL SDK is not available.""" if not _OTEL_AVAILABLE: return _noop_meter diff --git a/src/twinkle/server/telemetry/resource_metrics.py b/src/twinkle/server/telemetry/resource_metrics.py index 15248d5f9..112367160 100644 --- a/src/twinkle/server/telemetry/resource_metrics.py +++ b/src/twinkle/server/telemetry/resource_metrics.py @@ -30,7 +30,6 @@ except Exception: _PYNVML_AVAILABLE = False - _NVML_INITIALIZED = False diff --git a/src/twinkle/server/telemetry/tracing.py b/src/twinkle/server/telemetry/tracing.py index d093013a8..770b4b3a1 100644 --- a/src/twinkle/server/telemetry/tracing.py +++ b/src/twinkle/server/telemetry/tracing.py @@ -3,23 +3,21 @@ from __future__ import annotations from contextlib import contextmanager -from typing import Any, Iterator, Mapping - from fastapi import Request +from typing import Any, Iterator, Mapping try: from opentelemetry import trace - from opentelemetry.propagate import inject, extract from opentelemetry.context import Context + from opentelemetry.propagate import extract, inject _OTEL_AVAILABLE = True except Exception: _OTEL_AVAILABLE = False - from .correlation import set_correlation_attrs -def get_tracer(name: str = "twinkle-server"): +def get_tracer(name: str = 'twinkle-server'): """Retrieve tracer instance. Returns NoOp tracer when OTEL is not installed.""" if not _OTEL_AVAILABLE: return _NoopTracer() @@ -49,18 +47,32 @@ def get_current_span(): class _NoopSpan: """Minimal noop span for when OTEL is not available.""" - def set_attribute(self, *args, **kwargs): pass - def set_status(self, *args, **kwargs): pass - def add_event(self, *args, **kwargs): pass - def end(self, *args, **kwargs): pass - def __enter__(self): return self - def __exit__(self, *args): pass + + def set_attribute(self, *args, **kwargs): + pass + + def set_status(self, *args, **kwargs): + pass + + def add_event(self, *args, **kwargs): + pass + + def end(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + pass class _NoopTracer: """Minimal noop tracer for when OTEL is not available.""" + def start_as_current_span(self, name, **kwargs): return _NoopSpan() + def start_span(self, name, **kwargs): return _NoopSpan() @@ -121,8 +133,10 @@ def create_tracing_middleware(service_component: str): An async FastAPI HTTP middleware function. """ if not _OTEL_AVAILABLE: + async def passthrough_middleware(request: Request, call_next): return await call_next(request) + return passthrough_middleware async def tracing_middleware(request: Request, call_next): diff --git a/src/twinkle/server/utils/metrics.py b/src/twinkle/server/utils/metrics.py index aaf1076ed..ff1fe7f29 100644 --- a/src/twinkle/server/utils/metrics.py +++ b/src/twinkle/server/utils/metrics.py @@ -27,9 +27,8 @@ # --------------------------------------------------------------------------- # Lazy caches – populated on first call per deployment # --------------------------------------------------------------------------- -_task_metrics_cache: dict[str, 'TaskMetrics'] = {} -_request_metrics_cache: dict[str, '_RequestMetrics'] = {} - +_task_metrics_cache: dict[str, TaskMetrics] = {} +_request_metrics_cache: dict[str, _RequestMetrics] = {} # --------------------------------------------------------------------------- # Adapter classes – wrap OTEL instruments to expose the legacy Ray-style API diff --git a/src/twinkle/server/utils/task_queue/config.py b/src/twinkle/server/utils/task_queue/config.py index 1bd26971a..57f24ba67 100644 --- a/src/twinkle/server/utils/task_queue/config.py +++ b/src/twinkle/server/utils/task_queue/config.py @@ -9,9 +9,8 @@ """ from __future__ import annotations -from typing import Any - from pydantic import BaseModel, ConfigDict, Field +from typing import Any class TaskQueueConfig(BaseModel): @@ -42,7 +41,7 @@ class TaskQueueConfig(BaseModel): max_input_tokens: int = Field(default=16000, ge=1) @classmethod - def from_dict(cls, config_dict: dict[str, Any] | None = None) -> 'TaskQueueConfig': + def from_dict(cls, config_dict: dict[str, Any] | None = None) -> TaskQueueConfig: """Validate ``config_dict`` (or ``{}``) into a ``TaskQueueConfig``. Equivalent to ``cls.model_validate(config_dict or {})`` — kept for diff --git a/src/twinkle/server/utils/task_queue/worker.py b/src/twinkle/server/utils/task_queue/worker.py index f8121736b..5013065bc 100644 --- a/src/twinkle/server/utils/task_queue/worker.py +++ b/src/twinkle/server/utils/task_queue/worker.py @@ -19,8 +19,8 @@ from .types import QueuedTask, QueueState, TaskStatus if TYPE_CHECKING: - from twinkle.server.utils.metrics import TaskMetrics from twinkle.server.state import ServerStateProxy + from twinkle.server.utils.metrics import TaskMetrics logger = get_logger() diff --git a/tests/contract/client_api_harness.py b/tests/contract/client_api_harness.py index c8d30a385..32d10c8b4 100644 --- a/tests/contract/client_api_harness.py +++ b/tests/contract/client_api_harness.py @@ -27,12 +27,10 @@ from __future__ import annotations import json -from pathlib import Path -from typing import Any, Callable - from fastapi import FastAPI from fastapi.openapi.utils import get_openapi - +from pathlib import Path +from typing import Any, Callable # ----- App build helpers --------------------------------------------------- # @@ -86,10 +84,8 @@ def build_processor_app() -> FastAPI: 'processor': build_processor_app, } - # ----- Surface extraction -------------------------------------------------- # - _HTTP_METHODS = {'GET', 'POST', 'PUT', 'PATCH', 'DELETE'} @@ -135,7 +131,6 @@ def extract_full_surface() -> dict[str, Any]: # ----- Baseline I/O -------------------------------------------------------- # - BASELINE_PATH = Path(__file__).parent / 'client_api_baseline.json' diff --git a/tests/contract/test_client_api_contract.py b/tests/contract/test_client_api_contract.py index ebeca298d..d321596cf 100644 --- a/tests/contract/test_client_api_contract.py +++ b/tests/contract/test_client_api_contract.py @@ -18,23 +18,15 @@ from __future__ import annotations import json - import pytest -from tests.contract.client_api_harness import ( - APP_BUILDERS, - BASELINE_PATH, - extract_full_surface, - load_baseline, -) +from tests.contract.client_api_harness import APP_BUILDERS, BASELINE_PATH, extract_full_surface, load_baseline def test_baseline_file_exists() -> None: - assert BASELINE_PATH.exists(), ( - f'Baseline {BASELINE_PATH} missing. Generate it with ' - '`python -m tests.contract.update_baseline` after confirming the ' - 'current client-facing surface is correct.' - ) + assert BASELINE_PATH.exists(), (f'Baseline {BASELINE_PATH} missing. Generate it with ' + '`python -m tests.contract.update_baseline` after confirming the ' + 'current client-facing surface is correct.') @pytest.mark.parametrize('app_name', sorted(APP_BUILDERS.keys())) @@ -50,22 +42,18 @@ def test_app_surface_matches_baseline(app_name: str) -> None: if actual != expected: diff = _surface_diff(expected, actual) - pytest.fail( - f'Client-API surface for {app_name!r} drifted from the baseline.\n' - f'{diff}\n' - f'If the change is intentional, regenerate the baseline with ' - f'`python -m tests.contract.update_baseline`.' - ) + pytest.fail(f'Client-API surface for {app_name!r} drifted from the baseline.\n' + f'{diff}\n' + f'If the change is intentional, regenerate the baseline with ' + f'`python -m tests.contract.update_baseline`.') def test_full_surface_matches_baseline() -> None: """Whole-surface equality — the cross-cutting freeze guard.""" baseline = load_baseline() current = extract_full_surface() - assert current == baseline, ( - 'Full client-API surface drifted from the baseline. ' - 'See per-app failures for details.' - ) + assert current == baseline, ('Full client-API surface drifted from the baseline. ' + 'See per-app failures for details.') def _surface_diff(expected: dict, actual: dict) -> str: @@ -92,6 +80,9 @@ def _surface_diff(expected: dict, actual: dict) -> str: if changed: parts.append(' changed paths:\n' + '\n'.join(changed)) return '\n'.join(parts) if parts else json.dumps( - {'expected_keys': sorted(expected.keys()), 'actual_keys': sorted(actual.keys())}, + { + 'expected_keys': sorted(expected.keys()), + 'actual_keys': sorted(actual.keys()) + }, indent=2, ) diff --git a/tests/docs/test_docs_smoke.py b/tests/docs/test_docs_smoke.py index e86af5565..d24c5c51c 100644 --- a/tests/docs/test_docs_smoke.py +++ b/tests/docs/test_docs_smoke.py @@ -2,16 +2,13 @@ """Smoke checks for the Phase 5 documentation set (R8.3, R11.4, R17).""" from __future__ import annotations -from pathlib import Path - import pytest +from pathlib import Path REPO_ROOT = Path(__file__).resolve().parents[2] - # ---------- file presence ------------------------------------------------- # - OBSERVABILITY_EN = REPO_ROOT / 'docs' / 'source_en' / 'Usage Guide' / 'Observability.md' OBSERVABILITY_ZH = REPO_ROOT / 'docs' / 'source_zh' / '使用指引' / '可观测化.md' CONFIG_GUIDE_ZH = REPO_ROOT / 'docs' / 'source_zh' / '使用指引' / '服务配置.md' @@ -30,7 +27,6 @@ def test_doc_exists(path: Path) -> None: # ---------- observability guide content (R11.4, R17.1, R17.2) ------------ # - _CORRELATION_KEYS = ( 'twinkle.session_id', 'twinkle.model_id', diff --git a/tests/integration/test_lgtm_telemetry.py b/tests/integration/test_lgtm_telemetry.py index 6775dc23d..2a916fe86 100644 --- a/tests/integration/test_lgtm_telemetry.py +++ b/tests/integration/test_lgtm_telemetry.py @@ -27,16 +27,15 @@ """ from __future__ import annotations +import httpx import os +import pytest import socket import time import urllib.parse import uuid from contextlib import contextmanager -import httpx -import pytest - OTLP_ENDPOINT = os.environ.get('TWINKLE_TEST_OTLP_ENDPOINT', 'http://localhost:4317') GRAFANA_URL = os.environ.get('TWINKLE_TEST_GRAFANA_URL', 'http://localhost:3000') JAEGER_URL = os.environ.get('TWINKLE_TEST_JAEGER_URL', 'http://localhost:16686') @@ -85,14 +84,11 @@ def _detect_backend() -> str | None: pytestmark = pytest.mark.skipif( _BACKEND is None, - reason=( - f'No trace backend reachable. OTLP at {OTLP_ENDPOINT}, Grafana at {GRAFANA_URL}, ' - f'Jaeger at {JAEGER_URL}. Start one (cookbook/observability/docker-compose.yaml ' - 'or `docker run jaegertracing/all-in-one:1.62.0`).' - ), + reason=(f'No trace backend reachable. OTLP at {OTLP_ENDPOINT}, Grafana at {GRAFANA_URL}, ' + f'Jaeger at {JAEGER_URL}. Start one (cookbook/observability/docker-compose.yaml ' + 'or `docker run jaegertracing/all-in-one:1.62.0`).'), ) - # ---------- helpers ------------------------------------------------------- # @@ -186,24 +182,22 @@ def _spans_in_trace(payload: dict) -> list[dict]: for batch in payload.get('batches', []): for scope in batch.get('scopeSpans', []): for span in scope.get('spans', []): - out.append( - { - 'name': span.get('name'), - 'attributes': { - a['key']: a.get('value', {}).get('stringValue') - for a in span.get('attributes', []) - }, - } - ) + out.append({ + 'name': span.get('name'), + 'attributes': { + a['key']: a.get('value', {}).get('stringValue') + for a in span.get('attributes', []) + }, + }) return out # Jaeger trace JSON: top-level "spans" with operationName + tags. - return [ - { - 'name': s['operationName'], - 'attributes': {t['key']: t.get('value') for t in s.get('tags', [])}, - } - for s in payload.get('spans', []) - ] + return [{ + 'name': s['operationName'], + 'attributes': { + t['key']: t.get('value') + for t in s.get('tags', []) + }, + } for s in payload.get('spans', [])] # ---------- 7.15: trace + correlation visible in the trace store --------- # @@ -226,8 +220,11 @@ def test_business_span_with_correlation_visible_e2e() -> None: tracer = trace.get_tracer('twinkle.test.trace') with tracer.start_as_current_span('integration.parent') as parent: with traced_operation( - 'server_state.register_model', - attrs={SESSION_ID: session_id, MODEL_ID: model_id}, + 'server_state.register_model', + attrs={ + SESSION_ID: session_id, + MODEL_ID: model_id + }, ): pass trace_id_hex = format(parent.get_span_context().trace_id, '032x') @@ -236,12 +233,10 @@ def test_business_span_with_correlation_visible_e2e() -> None: assert payload is not None, f'trace {trace_id_hex} not found in {_BACKEND}' attrs_per_span = [s['attributes'] for s in _spans_in_trace(payload)] - assert any(a.get(SESSION_ID) == session_id for a in attrs_per_span), ( - f'{SESSION_ID} not on any span in {_BACKEND}: {attrs_per_span}' - ) - assert any(a.get(MODEL_ID) == model_id for a in attrs_per_span), ( - f'{MODEL_ID} not on any span in {_BACKEND}: {attrs_per_span}' - ) + assert any(a.get(SESSION_ID) == session_id + for a in attrs_per_span), (f'{SESSION_ID} not on any span in {_BACKEND}: {attrs_per_span}') + assert any(a.get(MODEL_ID) == model_id + for a in attrs_per_span), (f'{MODEL_ID} not on any span in {_BACKEND}: {attrs_per_span}') # ---------- 10.4: single-trace-id fan-out across deployments (R13.3) ----- # diff --git a/tests/integration/test_mock_mode_startup.py b/tests/integration/test_mock_mode_startup.py index 8be93cd24..11f504c2d 100644 --- a/tests/integration/test_mock_mode_startup.py +++ b/tests/integration/test_mock_mode_startup.py @@ -15,12 +15,10 @@ """ from __future__ import annotations -import os -import time -import uuid - import httpx +import os import pytest +import time from twinkle.server.config import ServerConfig @@ -97,7 +95,10 @@ def test_mock_mode_reaches_ready_under_30s_and_is_deterministic(ray_cluster) -> deploy_options: dict = {} for raw in app_spec.deployments: if isinstance(raw, dict): - deploy_options = {k: v for k, v in raw.items() if k not in ('name', 'ray_actor_options', 'autoscaling_config')} + deploy_options = { + k: v + for k, v in raw.items() if k not in ('name', 'ray_actor_options', 'autoscaling_config') + } break bound = builder(deploy_options=deploy_options, **args) serve.run(bound, name=app_spec.name, route_prefix=app_spec.route_prefix) @@ -115,8 +116,6 @@ def test_mock_mode_reaches_ready_under_30s_and_is_deterministic(ray_cluster) -> assert r.status_code == 200, r.text # Mock model + sampler determinism via the gateway's exposed routes. - sampling_session = f'sess-{uuid.uuid4().hex[:8]}' - payload = {'session_id': sampling_session} r1 = _http(f'{base}/api/v1/twinkle/healthz') r2 = _http(f'{base}/api/v1/twinkle/healthz') assert r1.status_code == 200 and r2.status_code == 200 diff --git a/tests/server/cli/test_cli.py b/tests/server/cli/test_cli.py index ef8cc021d..46f340574 100644 --- a/tests/server/cli/test_cli.py +++ b/tests/server/cli/test_cli.py @@ -7,30 +7,23 @@ from __future__ import annotations import json -from pathlib import Path -from unittest import mock - import pytest import yaml +from pathlib import Path from typer.testing import CliRunner +from unittest import mock from twinkle.server.cli.app import app, main from twinkle.server.config import ServerConfig from twinkle.server.exceptions import ConfigMismatchError from twinkle.server.state.backend.factory import PersistenceConfig from twinkle.server.state.backend.memory_backend import MemoryBackend -from twinkle.server.state.config_signature import ( - _SIGNATURE_KEY, - compute_signature, - validate_against_backend, -) - +from twinkle.server.state.config_signature import _SIGNATURE_KEY, compute_signature, validate_against_backend REPO_ROOT = Path(__file__).resolve().parents[3] EXAMPLE = REPO_ROOT / 'cookbook' / 'client' / 'server' / 'server_config.example.yaml' MOCK_CFG = REPO_ROOT / 'cookbook' / 'client' / 'server' / 'mock' / 'server_config.yaml' - # ---------- 9.5 CLI subcommand existence + exit codes (R14.3, R14.4) ------ # @@ -101,8 +94,8 @@ def _abort_drift(*args, **kwargs): raise ConfigMismatchError('drift sentinel') with mock.patch( - 'twinkle.server.state.config_signature.validate_against_backend', - side_effect=_abort_drift, + 'twinkle.server.state.config_signature.validate_against_backend', + side_effect=_abort_drift, ): # Should never reach the launcher import — patch it to a sentinel that # would make the test fail loudly if reached. @@ -124,9 +117,7 @@ async def test_property_29_first_run_stores_signature() -> None: pcfg = PersistenceConfig(mode='memory') # Patch create_backend to return our shared in-process backend so we can # inspect the stored signature afterwards. - with mock.patch( - 'twinkle.server.state.backend.factory.create_backend', return_value=backend - ): + with mock.patch('twinkle.server.state.backend.factory.create_backend', return_value=backend): await validate_against_backend(pcfg, cfg_payload) assert await backend.get(_SIGNATURE_KEY) == compute_signature(cfg_payload) # Second run with same payload is a no-op. @@ -140,9 +131,7 @@ async def test_property_29_drift_raises_with_diff_and_remediation() -> None: initial = {'persistence': {'mode': 'memory'}} later = {'persistence': {'mode': 'file', 'file_path': '/tmp/x.json'}} - with mock.patch( - 'twinkle.server.state.backend.factory.create_backend', return_value=backend - ): + with mock.patch('twinkle.server.state.backend.factory.create_backend', return_value=backend): await validate_against_backend(pcfg, initial) with pytest.raises(ConfigMismatchError) as exc: await validate_against_backend(pcfg, later) diff --git a/tests/server/cli/test_drift_integration.py b/tests/server/cli/test_drift_integration.py index dcd3247cd..35127d02d 100644 --- a/tests/server/cli/test_drift_integration.py +++ b/tests/server/cli/test_drift_integration.py @@ -12,31 +12,26 @@ import asyncio import os +import pytest import re import uuid -from pathlib import Path -from unittest import mock - -import pytest import yaml +from pathlib import Path from typer.testing import CliRunner +from unittest import mock from twinkle.server.cli.app import app from twinkle.server.config import ServerConfig from twinkle.server.exceptions import ConfigMismatchError from twinkle.server.state.backend.factory import PersistenceConfig, create_backend from twinkle.server.state.backend.redis_backend import RedisBackend -from twinkle.server.state.config_signature import ( - _SIGNATURE_KEY, - compute_signature, - validate_against_backend, -) - +from twinkle.server.state.config_signature import _SIGNATURE_KEY, compute_signature, validate_against_backend REDIS_URL = os.environ.get('TWINKLE_TEST_REDIS_URL', 'redis://localhost:6379/0') def _can_reach_redis() -> bool: + async def _check() -> bool: backend = RedisBackend(REDIS_URL) try: @@ -72,17 +67,23 @@ def write_config(tmp_path: Path): def _write(persistence: dict) -> Path: payload = { - 'http_options': {'host': 'localhost', 'port': 8000}, - 'telemetry': {'enabled': False}, - 'persistence': persistence, - 'applications': [ - { - 'name': 'server', - 'route_prefix': '/api/v1', - 'import_path': 'server', - 'args': {'supported_models': ['mock']}, - } - ], + 'http_options': { + 'host': 'localhost', + 'port': 8000 + }, + 'telemetry': { + 'enabled': False + }, + 'persistence': + persistence, + 'applications': [{ + 'name': 'server', + 'route_prefix': '/api/v1', + 'import_path': 'server', + 'args': { + 'supported_models': ['mock'] + }, + }], } path = tmp_path / f'config-{uuid.uuid4().hex[:6]}.yaml' path.write_text(yaml.safe_dump(payload)) @@ -206,9 +207,7 @@ def test_check_config_does_not_touch_redis(fresh_prefix: str, write_config) -> N assert res.exit_code == 0 async def _read_signature() -> object: - backend = create_backend( - PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=fresh_prefix) - ) + backend = create_backend(PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=fresh_prefix)) try: return await backend.get(_SIGNATURE_KEY) finally: diff --git a/tests/server/config/test_server_config.py b/tests/server/config/test_server_config.py index d84163db0..a7a621eeb 100644 --- a/tests/server/config/test_server_config.py +++ b/tests/server/config/test_server_config.py @@ -2,63 +2,69 @@ """Property + unit tests for the typed ``ServerConfig`` (R6, R7, R8). Properties covered: -- # Feature: server-config-observability-refactor, Property 12: Valid configuration yields a fully validated instance -- # Feature: server-config-observability-refactor, Property 13: Any constraint violation is rejected with the offending field named -- # Feature: server-config-observability-refactor, Property 14: Configuration round-trip fidelity -- # Feature: server-config-observability-refactor, Property 15: Legacy / unknown field names are rejected +- # Feature: server-config-observability-refactor, + Property 12: Valid configuration yields a fully validated instance +- # Feature: server-config-observability-refactor, + Property 13: Any constraint violation is rejected with the offending field named +- # Feature: server-config-observability-refactor, + Property 14: Configuration round-trip fidelity +- # Feature: server-config-observability-refactor, + Property 15: Legacy / unknown field names are rejected """ from __future__ import annotations -from pathlib import Path - import pytest import yaml -from hypothesis import given, settings, strategies as st +from hypothesis import given, settings +from hypothesis import strategies as st +from pathlib import Path from pydantic import ValidationError from twinkle.server.config import ApplicationSpec, ServerConfig from twinkle.server.exceptions import ConfigParseError from twinkle.server.launcher import ServerLauncher - # ---------- minimal valid config strategy ---------------------------------- # - _PERSISTENCE_VARIANTS = st.one_of( st.fixed_dictionaries({'mode': st.just('memory')}), - st.fixed_dictionaries( - {'mode': st.just('file'), 'file_path': st.just('/tmp/state.json')} - ), - st.fixed_dictionaries( - {'mode': st.just('redis'), 'redis_url': st.just('redis://localhost:6379/0')} - ), -) - - -_MODEL_APP = st.fixed_dictionaries( - { - 'name': st.just('m'), - 'route_prefix': st.just('/api/v1/m'), - 'import_path': st.just('model'), - 'args': st.fixed_dictionaries( - { - 'model_id': st.just('model-id'), - 'device_group': st.just({'name': 'g', 'ranks': 1, 'device_type': 'CPU'}), - 'device_mesh': st.just({'device_type': 'CPU', 'dp_size': 1}), - 'backend': st.sampled_from(['mock', 'transformers', 'megatron']), - } - ), - } -) - - -_VALID_CONFIG = st.fixed_dictionaries( - { - 'persistence': _PERSISTENCE_VARIANTS, - 'applications': st.lists(_MODEL_APP, min_size=0, max_size=3), - } + st.fixed_dictionaries({ + 'mode': st.just('file'), + 'file_path': st.just('/tmp/state.json') + }), + st.fixed_dictionaries({ + 'mode': st.just('redis'), + 'redis_url': st.just('redis://localhost:6379/0') + }), ) +_MODEL_APP = st.fixed_dictionaries({ + 'name': + st.just('m'), + 'route_prefix': + st.just('/api/v1/m'), + 'import_path': + st.just('model'), + 'args': + st.fixed_dictionaries({ + 'model_id': st.just('model-id'), + 'device_group': st.just({ + 'name': 'g', + 'ranks': 1, + 'device_type': 'CPU' + }), + 'device_mesh': st.just({ + 'device_type': 'CPU', + 'dp_size': 1 + }), + 'backend': st.sampled_from(['mock', 'transformers', 'megatron']), + }), +}) + +_VALID_CONFIG = st.fixed_dictionaries({ + 'persistence': _PERSISTENCE_VARIANTS, + 'applications': st.lists(_MODEL_APP, min_size=0, max_size=3), +}) # ---------- Property 12: valid → fully validated (R6.2, R7.3) -------------- # @@ -92,23 +98,19 @@ def test_property_13_file_mode_missing_path() -> None: @settings(max_examples=100) -@given(bad_backend=st.text(min_size=1, max_size=8).filter( - lambda s: s not in ('mock', 'transformers', 'megatron') -)) +@given(bad_backend=st.text(min_size=1, max_size=8).filter(lambda s: s not in ('mock', 'transformers', 'megatron'))) def test_property_13_bad_backend_names_field(bad_backend: str) -> None: payload = { - 'applications': [ - { - 'name': 'm', - 'import_path': 'model', - 'args': { - 'model_id': 'x', - 'device_group': {}, - 'device_mesh': {}, - 'backend': bad_backend, - }, - } - ] + 'applications': [{ + 'name': 'm', + 'import_path': 'model', + 'args': { + 'model_id': 'x', + 'device_group': {}, + 'device_mesh': {}, + 'backend': bad_backend, + }, + }] } with pytest.raises(ValidationError) as exc: ServerConfig.model_validate(payload) @@ -123,9 +125,7 @@ def test_property_13_nested_field_constraint_violation_named(bad_max_input_token enforced together with cross-field ones (R7.3) and the offending path is visible in the error.""" with pytest.raises(ValidationError) as exc: - ServerConfig.model_validate( - {'task_queue': {'max_input_tokens': bad_max_input_tokens}} - ) + ServerConfig.model_validate({'task_queue': {'max_input_tokens': bad_max_input_tokens}}) errors = exc.value.errors() assert any('max_input_tokens' in err['loc'] for err in errors) @@ -163,8 +163,13 @@ def test_property_15_legacy_field_rejected(legacy_field: str) -> None: @given(unknown=st.text(min_size=1, max_size=20).filter(lambda s: not s.startswith('_'))) def test_property_15_unknown_field_rejected(unknown: str) -> None: known = { - 'ray_namespace', 'proxy_location', 'http_options', - 'telemetry', 'persistence', 'task_queue', 'applications', + 'ray_namespace', + 'proxy_location', + 'http_options', + 'telemetry', + 'persistence', + 'task_queue', + 'applications', } if unknown in known: return @@ -209,7 +214,12 @@ def test_from_yaml_top_level_must_be_mapping(tmp_path: Path) -> None: def test_from_yaml_valid_minimal(tmp_path: Path) -> None: p = tmp_path / 'mini.yaml' yaml.safe_dump( - {'persistence': {'mode': 'memory'}, 'applications': []}, + { + 'persistence': { + 'mode': 'memory' + }, + 'applications': [] + }, p.open('w'), ) cfg = ServerConfig.from_yaml(p) diff --git a/tests/server/model/test_mock_model.py b/tests/server/model/test_mock_model.py index 013aac9ab..5b2d800ee 100644 --- a/tests/server/model/test_mock_model.py +++ b/tests/server/model/test_mock_model.py @@ -10,32 +10,48 @@ """ from __future__ import annotations -import sys - import pytest -from hypothesis import given, settings, strategies as st +import sys +from hypothesis import given, settings +from hypothesis import strategies as st from twinkle.server.exceptions import ConfigError -from twinkle.server.model.app import ( - _MODEL_BACKENDS, - _dispatch_model_backend, - _validate_model_backend, -) +from twinkle.server.model.app import _MODEL_BACKENDS, _dispatch_model_backend, _validate_model_backend from twinkle.server.model.backends.mock_model import TwinkleCompatMockModel - # ---------- Property 1: interface conformance (R1.1, R1.4) ---------------- # - _REQUIRED_METHODS = ( - 'tinker_forward_only', 'tinker_forward_backward', 'tinker_step', - 'tinker_calculate_metric', 'tinker_load', - 'forward_only', 'forward_backward', 'forward', 'calculate_loss', - 'backward', 'step', 'zero_grad', 'lr_step', 'clip_grad_norm', - 'set_loss', 'set_optimizer', 'set_lr_scheduler', 'set_template', - 'set_processor', 'add_metric', 'apply_patch', - 'save', 'load', 'resume_from_checkpoint', 'get_state_dict', 'get_train_configs', - 'add_adapter', 'add_adapter_to_model', 'remove_adapter', 'has_adapter', + 'tinker_forward_only', + 'tinker_forward_backward', + 'tinker_step', + 'tinker_calculate_metric', + 'tinker_load', + 'forward_only', + 'forward_backward', + 'forward', + 'calculate_loss', + 'backward', + 'step', + 'zero_grad', + 'lr_step', + 'clip_grad_norm', + 'set_loss', + 'set_optimizer', + 'set_lr_scheduler', + 'set_template', + 'set_processor', + 'add_metric', + 'apply_patch', + 'save', + 'load', + 'resume_from_checkpoint', + 'get_state_dict', + 'get_train_configs', + 'add_adapter', + 'add_adapter_to_model', + 'remove_adapter', + 'has_adapter', ) @@ -85,7 +101,9 @@ def test_property_2_tinker_forward_backward_loss_is_finite(seq_lens: list) -> No @settings(max_examples=100) -@given(name=st.text(min_size=1, max_size=12, alphabet=st.characters(whitelist_categories=('L', 'N'), whitelist_characters='_-'))) +@given( + name=st.text( + min_size=1, max_size=12, alphabet=st.characters(whitelist_categories=('L', 'N'), whitelist_characters='_-'))) def test_property_3_adapter_add_remove_round_trip(name: str) -> None: m = TwinkleCompatMockModel('mid') assert not m.has_adapter(name) diff --git a/tests/server/sampler/test_mock_sampler.py b/tests/server/sampler/test_mock_sampler.py index 8a92aef7e..ec32858cb 100644 --- a/tests/server/sampler/test_mock_sampler.py +++ b/tests/server/sampler/test_mock_sampler.py @@ -11,24 +11,18 @@ """ from __future__ import annotations -from pathlib import Path - import pytest -from hypothesis import given, settings, strategies as st +from hypothesis import given, settings +from hypothesis import strategies as st +from pathlib import Path from twinkle.data_format import InputFeature, SamplingParams from twinkle.server.exceptions import ConfigError -from twinkle.server.sampler.app import ( - _SAMPLER_TYPES, - _dispatch_sampler_backend, - _validate_sampler_type, -) +from twinkle.server.sampler.app import _SAMPLER_TYPES, _dispatch_sampler_backend, _validate_sampler_type from twinkle.server.sampler.backends.mock_sampler import MockSampler - # ---------- Property 5: interface conformance (R2.1) ---------------------- # - _REQUIRED_METHODS = ('sample', 'apply_patch', 'add_adapter_to_sampler', 'has_adapter') @@ -99,7 +93,9 @@ def test_property_8_no_sampling_params_raises() -> None: @settings(max_examples=100) -@given(name=st.text(min_size=1, max_size=12, alphabet=st.characters(whitelist_categories=('L', 'N'), whitelist_characters='_-'))) +@given( + name=st.text( + min_size=1, max_size=12, alphabet=st.characters(whitelist_categories=('L', 'N'), whitelist_characters='_-'))) def test_property_9_add_adapter_to_sampler(name: str) -> None: s = MockSampler('mid') assert not s.has_adapter(name) @@ -139,7 +135,8 @@ def test_property_11_absent_or_empty_sampler_type_raises(value) -> None: def test_mock_sampler_module_does_not_directly_import_vllm() -> None: """Static check: ``mock_sampler.py`` must not import ``vllm`` directly.""" - src = Path(__file__).resolve().parents[3] / 'src' / 'twinkle' / 'server' / 'sampler' / 'backends' / 'mock_sampler.py' + src = Path( + __file__).resolve().parents[3] / 'src' / 'twinkle' / 'server' / 'sampler' / 'backends' / 'mock_sampler.py' text = src.read_text() for forbidden in ('import vllm', 'from vllm'): assert forbidden not in text, f'mock_sampler.py contains {forbidden!r}' @@ -154,7 +151,10 @@ def test_mock_example_config_loads_via_server_config() -> None: repo_root = Path(__file__).resolve().parents[3] cfg_path = repo_root / 'cookbook' / 'client' / 'server' / 'mock' / 'server_config.yaml' cfg = ServerConfig.from_yaml(cfg_path) - backends = {a.name: getattr(a.args, 'backend', None) or getattr(a.args, 'sampler_type', None) for a in cfg.applications} + backends = { + a.name: getattr(a.args, 'backend', None) or getattr(a.args, 'sampler_type', None) + for a in cfg.applications + } assert backends.get('models-mock') == 'mock' assert backends.get('sampler-mock') == 'mock' diff --git a/tests/server/state/test_config_signature.py b/tests/server/state/test_config_signature.py index 946498cb5..4bee4833d 100644 --- a/tests/server/state/test_config_signature.py +++ b/tests/server/state/test_config_signature.py @@ -4,18 +4,14 @@ import pytest from twinkle.server.state.backend.memory_backend import MemoryBackend -from twinkle.server.state.config_signature import ( - SignatureMismatchPolicy, - compute_signature, - validate_config_signature, -) - +from twinkle.server.state.config_signature import SignatureMismatchPolicy, compute_signature, validate_config_signature # ---- compute_signature ---- + def test_compute_signature_deterministic(): """Same input should produce same output.""" - config = {"model": "qwen", "batch_size": 8} + config = {'model': 'qwen', 'batch_size': 8} sig1 = compute_signature(config) sig2 = compute_signature(config) assert sig1 == sig2 @@ -23,36 +19,37 @@ def test_compute_signature_deterministic(): def test_compute_signature_different_inputs(): """Different inputs should produce different outputs.""" - config_a = {"model": "qwen", "batch_size": 8} - config_b = {"model": "llama", "batch_size": 8} + config_a = {'model': 'qwen', 'batch_size': 8} + config_b = {'model': 'llama', 'batch_size': 8} assert compute_signature(config_a) != compute_signature(config_b) def test_compute_signature_key_order_independent(): """Key order should not affect the signature (sort_keys=True).""" - config_a = {"b": 2, "a": 1} - config_b = {"a": 1, "b": 2} + config_a = {'b': 2, 'a': 1} + config_b = {'a': 1, 'b': 2} assert compute_signature(config_a) == compute_signature(config_b) def test_compute_signature_is_hex_string(): """Signature should be a valid hex SHA256 string.""" - sig = compute_signature({"key": "value"}) + sig = compute_signature({'key': 'value'}) assert len(sig) == 64 # SHA256 hex = 64 chars - assert all(c in "0123456789abcdef" for c in sig) + assert all(c in '0123456789abcdef' for c in sig) # ---- validate_config_signature ---- + @pytest.mark.asyncio async def test_first_run_stores_signature(): """First run with no stored sig should store it and return True.""" backend = MemoryBackend() - config = {"model": "test"} + config = {'model': 'test'} result = await validate_config_signature(backend, config) assert result is True # Signature should be stored - stored = await backend.get("_meta::config_signature") + stored = await backend.get('_meta::config_signature') assert stored == compute_signature(config) @@ -60,7 +57,7 @@ async def test_first_run_stores_signature(): async def test_same_config_passes(): """Same config on second run should pass validation.""" backend = MemoryBackend() - config = {"model": "test", "lr": 0.001} + config = {'model': 'test', 'lr': 0.001} # First run await validate_config_signature(backend, config) # Second run same config @@ -72,16 +69,14 @@ async def test_same_config_passes(): async def test_different_config_warn_policy(): """Different config with WARN policy should return False and update sig.""" backend = MemoryBackend() - config_v1 = {"model": "v1"} - config_v2 = {"model": "v2"} + config_v1 = {'model': 'v1'} + config_v2 = {'model': 'v2'} await validate_config_signature(backend, config_v1) - result = await validate_config_signature( - backend, config_v2, policy=SignatureMismatchPolicy.WARN - ) + result = await validate_config_signature(backend, config_v2, policy=SignatureMismatchPolicy.WARN) assert result is False # Signature should be updated to v2 - stored = await backend.get("_meta::config_signature") + stored = await backend.get('_meta::config_signature') assert stored == compute_signature(config_v2) @@ -89,29 +84,27 @@ async def test_different_config_warn_policy(): async def test_different_config_clear_policy(): """CLEAR policy should clear non-meta data, preserve _meta, return False.""" backend = MemoryBackend() - config_v1 = {"model": "v1"} - config_v2 = {"model": "v2"} + config_v1 = {'model': 'v1'} + config_v2 = {'model': 'v2'} # Store initial config await validate_config_signature(backend, config_v1) # Add some user data - await backend.set("session::abc", {"data": 123}) - await backend.set("model::xyz", {"data": 456}) - await backend.set("_meta::other", "keep_this") + await backend.set('session::abc', {'data': 123}) + await backend.set('model::xyz', {'data': 456}) + await backend.set('_meta::other', 'keep_this') - result = await validate_config_signature( - backend, config_v2, policy=SignatureMismatchPolicy.CLEAR - ) + result = await validate_config_signature(backend, config_v2, policy=SignatureMismatchPolicy.CLEAR) assert result is False # User data should be cleared - assert await backend.get("session::abc") is None - assert await backend.get("model::xyz") is None + assert await backend.get('session::abc') is None + assert await backend.get('model::xyz') is None # _meta keys should be preserved - assert await backend.get("_meta::other") == "keep_this" + assert await backend.get('_meta::other') == 'keep_this' # Signature should be updated - stored = await backend.get("_meta::config_signature") + stored = await backend.get('_meta::config_signature') assert stored == compute_signature(config_v2) @@ -121,15 +114,13 @@ async def test_different_config_abort_policy(): from twinkle.server.exceptions import ConfigMismatchError backend = MemoryBackend() - config_v1 = {"model": "v1"} - config_v2 = {"model": "v2"} + config_v1 = {'model': 'v1'} + config_v2 = {'model': 'v2'} await validate_config_signature(backend, config_v1) with pytest.raises(ConfigMismatchError): - await validate_config_signature( - backend, config_v2, policy=SignatureMismatchPolicy.ABORT - ) + await validate_config_signature(backend, config_v2, policy=SignatureMismatchPolicy.ABORT) @pytest.mark.asyncio @@ -138,16 +129,14 @@ async def test_abort_policy_does_not_update_signature(): from twinkle.server.exceptions import ConfigMismatchError backend = MemoryBackend() - config_v1 = {"model": "v1"} - config_v2 = {"model": "v2"} + config_v1 = {'model': 'v1'} + config_v2 = {'model': 'v2'} await validate_config_signature(backend, config_v1) with pytest.raises(ConfigMismatchError): - await validate_config_signature( - backend, config_v2, policy=SignatureMismatchPolicy.ABORT - ) + await validate_config_signature(backend, config_v2, policy=SignatureMismatchPolicy.ABORT) # Signature should still be v1 - stored = await backend.get("_meta::config_signature") + stored = await backend.get('_meta::config_signature') assert stored == compute_signature(config_v1) diff --git a/tests/server/state/test_de_actor.py b/tests/server/state/test_de_actor.py index ba1573e2f..2770ce358 100644 --- a/tests/server/state/test_de_actor.py +++ b/tests/server/state/test_de_actor.py @@ -9,21 +9,15 @@ """ from __future__ import annotations +import pytest +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st from unittest import mock -import pytest -from hypothesis import HealthCheck, given, settings, strategies as st - -from twinkle.server.state import ( - PersistenceConfig, - ReplicaRegistry, - ServerState, - get_server_state, - reset_server_state_cache, -) +from twinkle.server.state import (PersistenceConfig, ReplicaRegistry, ServerState, get_server_state, + reset_server_state_cache) from twinkle.server.state.backend.memory_backend import MemoryBackend - # ---------- 4.6: de-Actor wiring + in-process persistence ------------------ # @@ -40,13 +34,9 @@ def test_no_detached_actor_in_source() -> None: Static check: searching the file is enough — the dynamic check below also confirms ``ray.remote`` is never invoked when ``get_server_state`` runs. """ - assert not _ray_attr_used('ray.remote('), ( - 'state/server_state.py still references ray.remote(...) — ' - 'detached actor must not be created (R19.1).' - ) - assert not _ray_attr_used("lifetime='detached'"), ( - "state/server_state.py still uses lifetime='detached' (R19.1)." - ) + assert not _ray_attr_used('ray.remote('), ('state/server_state.py still references ray.remote(...) — ' + 'detached actor must not be created (R19.1).') + assert not _ray_attr_used("lifetime='detached'"), ("state/server_state.py still uses lifetime='detached' (R19.1).") def test_get_server_state_does_not_call_ray_remote() -> None: @@ -88,7 +78,6 @@ def test_in_process_persistence_no_redis_required() -> None: # ---------- 4.5: state-operation equivalence under direct-backend ---------- # - _OP_STRATEGY = st.lists( st.one_of( # ('register_replica', replica_id, max_loras) diff --git a/tests/server/state/test_factory.py b/tests/server/state/test_factory.py index 12568b775..9b0ea051d 100644 --- a/tests/server/state/test_factory.py +++ b/tests/server/state/test_factory.py @@ -2,17 +2,16 @@ from __future__ import annotations import os -import tempfile - import pytest +import tempfile from twinkle.server.state.backend.factory import PersistenceConfig, create_backend from twinkle.server.state.backend.file_backend import FileBackend from twinkle.server.state.backend.memory_backend import MemoryBackend - # ---- Memory Mode ---- + def test_create_backend_none_returns_memory(): """Passing None should return MemoryBackend (default mode).""" backend = create_backend(None) @@ -21,20 +20,21 @@ def test_create_backend_none_returns_memory(): def test_create_backend_memory_mode(): """Explicit memory mode should return MemoryBackend.""" - config = PersistenceConfig(mode="memory") + config = PersistenceConfig(mode='memory') backend = create_backend(config) assert isinstance(backend, MemoryBackend) # ---- File Mode ---- + def test_create_backend_file_mode(): """File mode with file_path should return FileBackend.""" with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: path = f.name try: os.unlink(path) # Let FileBackend create it - config = PersistenceConfig(mode="file", file_path=path) + config = PersistenceConfig(mode='file', file_path=path) backend = create_backend(config) assert isinstance(backend, FileBackend) finally: @@ -44,42 +44,45 @@ def test_create_backend_file_mode(): def test_create_backend_file_mode_missing_path(): """File mode without file_path should raise ValueError.""" - config = PersistenceConfig(mode="file") - with pytest.raises(ValueError, match="file_path"): + config = PersistenceConfig(mode='file') + with pytest.raises(ValueError, match='file_path'): create_backend(config) # ---- Redis Mode ---- + def test_create_backend_redis_mode(): """Redis mode with redis_url should return RedisBackend (if redis available).""" try: import redis # noqa: F401 except ImportError: - pytest.skip("redis package not available") + pytest.skip('redis package not available') + + from unittest.mock import MagicMock, patch - from unittest.mock import patch, MagicMock from twinkle.server.state.backend.redis_backend import RedisBackend - with patch("redis.asyncio.from_url", return_value=MagicMock()): - config = PersistenceConfig(mode="redis", redis_url="redis://localhost:6379") + with patch('redis.asyncio.from_url', return_value=MagicMock()): + config = PersistenceConfig(mode='redis', redis_url='redis://localhost:6379') backend = create_backend(config) assert isinstance(backend, RedisBackend) def test_create_backend_redis_mode_missing_url(): """Redis mode without redis_url should raise ValueError.""" - config = PersistenceConfig(mode="redis") - with pytest.raises(ValueError, match="redis_url"): + config = PersistenceConfig(mode='redis') + with pytest.raises(ValueError, match='redis_url'): create_backend(config) # ---- PersistenceConfig Defaults ---- + def test_persistence_config_defaults(): """PersistenceConfig should have sensible defaults.""" config = PersistenceConfig() - assert config.mode == "memory" + assert config.mode == 'memory' assert config.file_path is None assert config.redis_url is None - assert config.key_prefix == "" + assert config.key_prefix == '' diff --git a/tests/server/state/test_file_backend.py b/tests/server/state/test_file_backend.py index 5ede0cf0e..32c393926 100644 --- a/tests/server/state/test_file_backend.py +++ b/tests/server/state/test_file_backend.py @@ -4,11 +4,10 @@ import asyncio import json import os +import pytest import tempfile import time -import pytest - from twinkle.server.state.backend.file_backend import FileBackend @@ -25,27 +24,28 @@ def tmp_file(): # ---- Basic CRUD ---- + @pytest.mark.asyncio async def test_set_and_get(tmp_file): backend = FileBackend(tmp_file) - await backend.set("key1", {"hello": "world"}) - result = await backend.get("key1") - assert result == {"hello": "world"} + await backend.set('key1', {'hello': 'world'}) + result = await backend.get('key1') + assert result == {'hello': 'world'} @pytest.mark.asyncio async def test_get_nonexistent_key(tmp_file): backend = FileBackend(tmp_file) - result = await backend.get("nonexistent") + result = await backend.get('nonexistent') assert result is None @pytest.mark.asyncio async def test_delete(tmp_file): backend = FileBackend(tmp_file) - await backend.set("key1", "value1") - await backend.delete("key1") - result = await backend.get("key1") + await backend.set('key1', 'value1') + await backend.delete('key1') + result = await backend.get('key1') assert result is None @@ -53,115 +53,120 @@ async def test_delete(tmp_file): async def test_delete_nonexistent_key(tmp_file): """Deleting a key that doesn't exist should not raise.""" backend = FileBackend(tmp_file) - await backend.delete("nonexistent") # Should not raise + await backend.delete('nonexistent') # Should not raise @pytest.mark.asyncio async def test_exists(tmp_file): backend = FileBackend(tmp_file) - await backend.set("key1", "value1") - assert await backend.exists("key1") is True - assert await backend.exists("key2") is False + await backend.set('key1', 'value1') + assert await backend.exists('key1') is True + assert await backend.exists('key2') is False # ---- TTL Expiry ---- + @pytest.mark.asyncio async def test_ttl_expiry(tmp_file): backend = FileBackend(tmp_file) - await backend.set("ephemeral", "data", ttl=1) + await backend.set('ephemeral', 'data', ttl=1) # Immediately should exist - assert await backend.get("ephemeral") == "data" - assert await backend.exists("ephemeral") is True + assert await backend.get('ephemeral') == 'data' + assert await backend.exists('ephemeral') is True # Wait for expiry time.sleep(1.1) - assert await backend.get("ephemeral") is None - assert await backend.exists("ephemeral") is False + assert await backend.get('ephemeral') is None + assert await backend.exists('ephemeral') is False @pytest.mark.asyncio async def test_ttl_none_means_no_expiry(tmp_file): backend = FileBackend(tmp_file) - await backend.set("permanent", "data", ttl=None) + await backend.set('permanent', 'data', ttl=None) time.sleep(0.1) - assert await backend.get("permanent") == "data" + assert await backend.get('permanent') == 'data' # ---- Keys Pattern Matching ---- + @pytest.mark.asyncio async def test_keys_wildcard(tmp_file): backend = FileBackend(tmp_file) - await backend.set("session::abc", "s1") - await backend.set("session::def", "s2") - await backend.set("model::xyz", "m1") + await backend.set('session::abc', 's1') + await backend.set('session::def', 's2') + await backend.set('model::xyz', 'm1') - session_keys = await backend.keys("session::*") - assert sorted(session_keys) == ["session::abc", "session::def"] + session_keys = await backend.keys('session::*') + assert sorted(session_keys) == ['session::abc', 'session::def'] - model_keys = await backend.keys("model::*") - assert model_keys == ["model::xyz"] + model_keys = await backend.keys('model::*') + assert model_keys == ['model::xyz'] - all_keys = await backend.keys("*") + all_keys = await backend.keys('*') assert len(all_keys) == 3 @pytest.mark.asyncio async def test_keys_excludes_expired(tmp_file): backend = FileBackend(tmp_file) - await backend.set("alive", "yes") - await backend.set("dying", "soon", ttl=1) + await backend.set('alive', 'yes') + await backend.set('dying', 'soon', ttl=1) time.sleep(1.1) - keys = await backend.keys("*") - assert keys == ["alive"] + keys = await backend.keys('*') + assert keys == ['alive'] # ---- Count ---- + @pytest.mark.asyncio async def test_count(tmp_file): backend = FileBackend(tmp_file) - await backend.set("a::1", "v") - await backend.set("a::2", "v") - await backend.set("b::1", "v") - assert await backend.count("a::*") == 2 - assert await backend.count("b::*") == 1 - assert await backend.count("*") == 3 + await backend.set('a::1', 'v') + await backend.set('a::2', 'v') + await backend.set('b::1', 'v') + assert await backend.count('a::*') == 2 + assert await backend.count('b::*') == 1 + assert await backend.count('*') == 3 # ---- set_nx ---- + @pytest.mark.asyncio async def test_set_nx_new_key(tmp_file): backend = FileBackend(tmp_file) - result = await backend.set_nx("new_key", "value") + result = await backend.set_nx('new_key', 'value') assert result is True - assert await backend.get("new_key") == "value" + assert await backend.get('new_key') == 'value' @pytest.mark.asyncio async def test_set_nx_existing_key(tmp_file): backend = FileBackend(tmp_file) - await backend.set("existing", "original") - result = await backend.set_nx("existing", "new_value") + await backend.set('existing', 'original') + result = await backend.set_nx('existing', 'new_value') assert result is False # Value should not change - assert await backend.get("existing") == "original" + assert await backend.get('existing') == 'original' @pytest.mark.asyncio async def test_set_nx_expired_key(tmp_file): backend = FileBackend(tmp_file) - await backend.set("expired_key", "old", ttl=1) + await backend.set('expired_key', 'old', ttl=1) time.sleep(1.1) # Key is expired, set_nx should succeed - result = await backend.set_nx("expired_key", "new_value") + result = await backend.set_nx('expired_key', 'new_value') assert result is True - assert await backend.get("expired_key") == "new_value" + assert await backend.get('expired_key') == 'new_value' # ---- Health Check ---- + @pytest.mark.asyncio async def test_health_check(tmp_file): backend = FileBackend(tmp_file) @@ -170,42 +175,45 @@ async def test_health_check(tmp_file): # ---- Auto-create File ---- + @pytest.mark.asyncio async def test_auto_create_file(): """FileBackend should create the file if it doesn't exist.""" with tempfile.TemporaryDirectory() as tmp_dir: - path = os.path.join(tmp_dir, "subdir", "state.json") - backend = FileBackend(path) + path = os.path.join(tmp_dir, 'subdir', 'state.json') + FileBackend(path) assert os.path.exists(path) # File should be valid JSON - with open(path, 'r') as f: + with open(path) as f: data = json.load(f) assert data == {} # ---- Atomic Write Integrity ---- + @pytest.mark.asyncio async def test_atomic_write_integrity(tmp_file): """After write, reading from file should give consistent data.""" backend = FileBackend(tmp_file) - await backend.set("k1", {"nested": [1, 2, 3]}) - await backend.set("k2", "simple_string") + await backend.set('k1', {'nested': [1, 2, 3]}) + await backend.set('k2', 'simple_string') # Read raw file to verify structure - with open(tmp_file, 'r', encoding='utf-8') as f: + with open(tmp_file, encoding='utf-8') as f: raw = json.load(f) - assert "k1" in raw - assert raw["k1"]["value"] == {"nested": [1, 2, 3]} - assert "k2" in raw - assert raw["k2"]["value"] == "simple_string" + assert 'k1' in raw + assert raw['k1']['value'] == {'nested': [1, 2, 3]} + assert 'k2' in raw + assert raw['k2']['value'] == 'simple_string' # ---- Overwrite Value ---- + @pytest.mark.asyncio async def test_overwrite_value(tmp_file): backend = FileBackend(tmp_file) - await backend.set("key", "v1") - await backend.set("key", "v2") - assert await backend.get("key") == "v2" + await backend.set('key', 'v1') + await backend.set('key', 'v2') + assert await backend.get('key') == 'v2' diff --git a/tests/server/state/test_managers.py b/tests/server/state/test_managers.py index 95f98676a..d7fabd379 100644 --- a/tests/server/state/test_managers.py +++ b/tests/server/state/test_managers.py @@ -1,28 +1,22 @@ """Tests for state managers using MemoryBackend as integration backend.""" from __future__ import annotations +import pytest import time from datetime import datetime, timezone -import pytest - from twinkle.server.state.backend.memory_backend import MemoryBackend from twinkle.server.state.future_manager import FutureManager from twinkle.server.state.model_manager import ModelManager -from twinkle.server.state.models import ( - FutureRecord, - ModelRecord, - SamplingSessionRecord, - SessionRecord, -) +from twinkle.server.state.models import FutureRecord, ModelRecord, SamplingSessionRecord, SessionRecord from twinkle.server.state.sampling_manager import SamplingSessionManager from twinkle.server.state.session_manager import SessionManager - # ============================================================ # SessionManager Tests # ============================================================ + class TestSessionManager: @pytest.fixture @@ -35,58 +29,58 @@ def manager(self, backend): @pytest.mark.asyncio async def test_add_and_get(self, manager): - record = SessionRecord(tags=["test"], sdk_version="1.0") - await manager.add("sess1", record) - result = await manager.get("sess1") + record = SessionRecord(tags=['test'], sdk_version='1.0') + await manager.add('sess1', record) + result = await manager.get('sess1') assert result is not None - assert result.tags == ["test"] - assert result.sdk_version == "1.0" + assert result.tags == ['test'] + assert result.sdk_version == '1.0' @pytest.mark.asyncio async def test_get_nonexistent(self, manager): - result = await manager.get("nonexistent") + result = await manager.get('nonexistent') assert result is None @pytest.mark.asyncio async def test_remove(self, manager): record = SessionRecord() - await manager.add("sess1", record) - removed = await manager.remove("sess1") + await manager.add('sess1', record) + removed = await manager.remove('sess1') assert removed is True - assert await manager.get("sess1") is None + assert await manager.get('sess1') is None @pytest.mark.asyncio async def test_remove_nonexistent(self, manager): - removed = await manager.remove("nonexistent") + removed = await manager.remove('nonexistent') assert removed is False @pytest.mark.asyncio async def test_count(self, manager): - await manager.add("s1", SessionRecord()) - await manager.add("s2", SessionRecord()) + await manager.add('s1', SessionRecord()) + await manager.add('s2', SessionRecord()) assert await manager.count() == 2 @pytest.mark.asyncio async def test_touch_updates_heartbeat(self, manager): record = SessionRecord(last_heartbeat=1000.0) - await manager.add("sess1", record) + await manager.add('sess1', record) before = time.time() - result = await manager.touch("sess1") + result = await manager.touch('sess1') after = time.time() assert result is True - updated = await manager.get("sess1") + updated = await manager.get('sess1') assert before <= updated.last_heartbeat <= after @pytest.mark.asyncio async def test_touch_nonexistent(self, manager): - result = await manager.touch("nonexistent") + result = await manager.touch('nonexistent') assert result is False @pytest.mark.asyncio async def test_get_last_heartbeat(self, manager): record = SessionRecord(last_heartbeat=12345.0) - await manager.add("sess1", record) - hb = await manager.get_last_heartbeat("sess1") + await manager.add('sess1', record) + hb = await manager.get_last_heartbeat('sess1') assert hb == 12345.0 @pytest.mark.asyncio @@ -94,23 +88,23 @@ async def test_cleanup_expired(self, manager): now = time.time() # Old session old_record = SessionRecord(last_heartbeat=now - 1000) - await manager.add("old_sess", old_record) + await manager.add('old_sess', old_record) # Recent session new_record = SessionRecord(last_heartbeat=now) - await manager.add("new_sess", new_record) + await manager.add('new_sess', new_record) cutoff = now - 500 removed_count = await manager.cleanup_expired(cutoff) assert removed_count == 1 - assert await manager.get("old_sess") is None - assert await manager.get("new_sess") is not None + assert await manager.get('old_sess') is None + assert await manager.get('new_sess') is not None @pytest.mark.asyncio async def test_cleanup_expired_uses_created_at_fallback(self, manager): """When last_heartbeat is 0, should use created_at for expiry check.""" old_time = datetime(2020, 1, 1, tzinfo=timezone.utc).isoformat() record = SessionRecord(last_heartbeat=0.0, created_at=old_time) - await manager.add("old_sess", record) + await manager.add('old_sess', record) cutoff = time.time() - 100 removed_count = await manager.cleanup_expired(cutoff) @@ -121,6 +115,7 @@ async def test_cleanup_expired_uses_created_at_fallback(self, manager): # ModelManager Tests # ============================================================ + class TestModelManager: @pytest.fixture @@ -133,73 +128,73 @@ def manager(self, backend): @pytest.mark.asyncio async def test_add_and_get(self, manager): - record = ModelRecord(token="tok1", session_id="sess1", base_model="qwen") - await manager.add("model1", record) - result = await manager.get("model1") + record = ModelRecord(token='tok1', session_id='sess1', base_model='qwen') + await manager.add('model1', record) + result = await manager.get('model1') assert result is not None - assert result.token == "tok1" - assert result.base_model == "qwen" + assert result.token == 'tok1' + assert result.base_model == 'qwen' @pytest.mark.asyncio async def test_remove(self, manager): - record = ModelRecord(token="tok1") - await manager.add("model1", record) - removed = await manager.remove("model1") + record = ModelRecord(token='tok1') + await manager.add('model1', record) + removed = await manager.remove('model1') assert removed is True - assert await manager.get("model1") is None + assert await manager.get('model1') is None @pytest.mark.asyncio async def test_token_limit_enforced(self, manager): """Adding more models than per_token_model_limit should raise RuntimeError.""" for i in range(3): - await manager.add(f"m{i}", ModelRecord(token="tok1")) + await manager.add(f"m{i}", ModelRecord(token='tok1')) - with pytest.raises(RuntimeError, match="Model limit exceeded"): - await manager.add("m3", ModelRecord(token="tok1")) + with pytest.raises(RuntimeError, match='Model limit exceeded'): + await manager.add('m3', ModelRecord(token='tok1')) @pytest.mark.asyncio async def test_token_limit_per_token(self, manager): """Limit is per-token, different tokens have separate limits.""" for i in range(3): - await manager.add(f"a{i}", ModelRecord(token="tokenA")) + await manager.add(f"a{i}", ModelRecord(token='tokenA')) # Different token should work - await manager.add("b0", ModelRecord(token="tokenB")) - assert await manager.get("b0") is not None + await manager.add('b0', ModelRecord(token='tokenB')) + assert await manager.get('b0') is not None @pytest.mark.asyncio async def test_replica_registration(self, manager): - await manager.register_replica("replica1", max_loras=5) + await manager.register_replica('replica1', max_loras=5) info = await manager.get_capacity_info() - assert info["max_loras"] == 5 - assert info["used_loras"] == 0 - assert info["free_loras"] == 5 + assert info['max_loras'] == 5 + assert info['used_loras'] == 0 + assert info['free_loras'] == 5 @pytest.mark.asyncio async def test_capacity_info_after_add(self, manager): - await manager.register_replica("r1", max_loras=3) - record = ModelRecord(token="tok1", replica_id="r1") - await manager.add("m1", record) + await manager.register_replica('r1', max_loras=3) + record = ModelRecord(token='tok1', replica_id='r1') + await manager.add('m1', record) info = await manager.get_capacity_info() - assert info["used_loras"] == 1 - assert info["free_loras"] == 2 + assert info['used_loras'] == 1 + assert info['free_loras'] == 2 @pytest.mark.asyncio async def test_indexes_derived_from_backend(self, manager): """Per-token / per-replica counts are derived from the backend.""" - record1 = ModelRecord(token="tok1", replica_id="r1") - record2 = ModelRecord(token="tok1", replica_id="r2") - await manager.add("m1", record1) - await manager.add("m2", record2) + record1 = ModelRecord(token='tok1', replica_id='r1') + record2 = ModelRecord(token='tok1', replica_id='r2') + await manager.add('m1', record1) + await manager.add('m2', record2) - await manager.register_replica("r1", max_loras=5) - await manager.register_replica("r2", max_loras=5) + await manager.register_replica('r1', max_loras=5) + await manager.register_replica('r2', max_loras=5) # Backend-derived availability reflects all persisted records. - avail = await manager.get_available_replica_ids(["r1", "r2"]) - assert avail == ["r1", "r2"] + avail = await manager.get_available_replica_ids(['r1', 'r2']) + assert avail == ['r1', 'r2'] # Per-token count enforces the limit using the persisted records. - count = await manager._count_models_for_token("tok1") + count = await manager._count_models_for_token('tok1') assert count == 2 @pytest.mark.asyncio @@ -207,38 +202,39 @@ async def test_cascade_cleanup_by_session(self, manager): """Models owned by expired sessions should be cleaned up.""" now = time.time() record = ModelRecord( - token="tok1", - session_id="expired_sess", + token='tok1', + session_id='expired_sess', created_at=datetime.now(timezone.utc).isoformat(), ) - await manager.add("m1", record) + await manager.add('m1', record) # Cleanup with cascade removed = await manager.cleanup_expired( cutoff_time=now - 10000, # cutoff is old, so age-based wouldn't trigger - expired_session_ids=["expired_sess"], + expired_session_ids=['expired_sess'], ) assert removed == 1 - assert await manager.get("m1") is None + assert await manager.get('m1') is None @pytest.mark.asyncio async def test_get_available_replica_ids(self, manager): - await manager.register_replica("r1", max_loras=2) - await manager.register_replica("r2", max_loras=1) + await manager.register_replica('r1', max_loras=2) + await manager.register_replica('r2', max_loras=1) # Fill r2 - await manager.add("m1", ModelRecord(token="t", replica_id="r2")) + await manager.add('m1', ModelRecord(token='t', replica_id='r2')) - available = await manager.get_available_replica_ids(["r1", "r2", "r3_unknown"]) + available = await manager.get_available_replica_ids(['r1', 'r2', 'r3_unknown']) # r1 has capacity, r2 is full, r3 unknown (conservative include) - assert "r1" in available - assert "r2" not in available - assert "r3_unknown" in available + assert 'r1' in available + assert 'r2' not in available + assert 'r3_unknown' in available # ============================================================ # SamplingSessionManager Tests # ============================================================ + class TestSamplingSessionManager: @pytest.fixture @@ -251,47 +247,48 @@ def manager(self, backend): @pytest.mark.asyncio async def test_add_and_get(self, manager): - record = SamplingSessionRecord(session_id="sess1", base_model="qwen") - await manager.add("samp1", record) - result = await manager.get("samp1") + record = SamplingSessionRecord(session_id='sess1', base_model='qwen') + await manager.add('samp1', record) + result = await manager.get('samp1') assert result is not None - assert result.session_id == "sess1" - assert result.base_model == "qwen" + assert result.session_id == 'sess1' + assert result.base_model == 'qwen' @pytest.mark.asyncio async def test_cleanup_expired_by_age(self, manager): old_time = datetime(2020, 1, 1, tzinfo=timezone.utc).isoformat() - record = SamplingSessionRecord(session_id="sess1", created_at=old_time) - await manager.add("samp_old", record) + record = SamplingSessionRecord(session_id='sess1', created_at=old_time) + await manager.add('samp_old', record) # Recent - record2 = SamplingSessionRecord(session_id="sess2") - await manager.add("samp_new", record2) + record2 = SamplingSessionRecord(session_id='sess2') + await manager.add('samp_new', record2) cutoff = time.time() - 100 removed = await manager.cleanup_expired(cutoff) assert removed == 1 - assert await manager.get("samp_old") is None - assert await manager.get("samp_new") is not None + assert await manager.get('samp_old') is None + assert await manager.get('samp_new') is not None @pytest.mark.asyncio async def test_cleanup_expired_cascade(self, manager): """Sampling sessions should be cleaned when their parent session expires.""" - record = SamplingSessionRecord(session_id="expired_sess") - await manager.add("samp1", record) + record = SamplingSessionRecord(session_id='expired_sess') + await manager.add('samp1', record) removed = await manager.cleanup_expired( cutoff_time=0.0, # Won't catch by age - expired_session_ids=["expired_sess"], + expired_session_ids=['expired_sess'], ) assert removed == 1 - assert await manager.get("samp1") is None + assert await manager.get('samp1') is None # ============================================================ # FutureManager Tests # ============================================================ + class TestFutureManager: @pytest.fixture @@ -305,26 +302,22 @@ def manager(self, backend): @pytest.mark.asyncio async def test_store_status_creates_new(self, manager): await manager.store_status( - request_id="req1", - status="pending", - model_id="model1", + request_id='req1', + status='pending', + model_id='model1', ) - result = await manager.get("req1") + result = await manager.get('req1') assert result is not None - assert result.status == "pending" - assert result.model_id == "model1" + assert result.status == 'pending' + assert result.model_id == 'model1' @pytest.mark.asyncio async def test_store_status_updates_existing(self, manager): - await manager.store_status( - request_id="req1", status="pending", model_id="m1" - ) - await manager.store_status( - request_id="req1", status="completed", model_id="m1", result={"output": "done"} - ) - result = await manager.get("req1") - assert result.status == "completed" - assert result.result == {"output": "done"} + await manager.store_status(request_id='req1', status='pending', model_id='m1') + await manager.store_status(request_id='req1', status='completed', model_id='m1', result={'output': 'done'}) + result = await manager.get('req1') + assert result.status == 'completed' + assert result.result == {'output': 'done'} @pytest.mark.asyncio async def test_store_status_with_pydantic_result(self, manager): @@ -334,51 +327,44 @@ async def test_store_status_with_pydantic_result(self, manager): class MockResult(BaseModel): score: float = 0.95 - await manager.store_status( - request_id="req1", status="completed", model_id="m1", - result=MockResult() - ) - result = await manager.get("req1") - assert result.result == {"score": 0.95} + await manager.store_status(request_id='req1', status='completed', model_id='m1', result=MockResult()) + result = await manager.get('req1') + assert result.result == {'score': 0.95} @pytest.mark.asyncio async def test_store_status_preserves_reason(self, manager): - await manager.store_status( - request_id="req1", status="rate_limited", model_id=None, - reason="Too many requests" - ) - result = await manager.get("req1") - assert result.reason == "Too many requests" + await manager.store_status(request_id='req1', status='rate_limited', model_id=None, reason='Too many requests') + result = await manager.get('req1') + assert result.reason == 'Too many requests' @pytest.mark.asyncio async def test_cleanup_expired(self, manager): old_time = datetime(2020, 1, 1, tzinfo=timezone.utc).isoformat() - old_record = FutureRecord( - status="completed", created_at=old_time, updated_at=old_time - ) - await manager.add("old_req", old_record) + old_record = FutureRecord(status='completed', created_at=old_time, updated_at=old_time) + await manager.add('old_req', old_record) - new_record = FutureRecord(status="pending") - await manager.add("new_req", new_record) + new_record = FutureRecord(status='pending') + await manager.add('new_req', new_record) cutoff = time.time() - 100 removed = await manager.cleanup_expired(cutoff) assert removed == 1 - assert await manager.get("old_req") is None - assert await manager.get("new_req") is not None + assert await manager.get('old_req') is None + assert await manager.get('new_req') is not None @pytest.mark.asyncio async def test_get_nonexistent(self, manager): - result = await manager.get("nonexistent") + result = await manager.get('nonexistent') assert result is None @pytest.mark.asyncio async def test_store_status_queue_state(self, manager): await manager.store_status( - request_id="req1", status="queued", model_id="m1", - queue_state="paused_rate_limit", - queue_state_reason="Rate limit hit" - ) - result = await manager.get("req1") - assert result.queue_state == "paused_rate_limit" - assert result.queue_state_reason == "Rate limit hit" + request_id='req1', + status='queued', + model_id='m1', + queue_state='paused_rate_limit', + queue_state_reason='Rate limit hit') + result = await manager.get('req1') + assert result.queue_state == 'paused_rate_limit' + assert result.queue_state_reason == 'Rate limit hit' diff --git a/tests/server/state/test_redis_backend.py b/tests/server/state/test_redis_backend.py index dde5c1af7..813b60708 100644 --- a/tests/server/state/test_redis_backend.py +++ b/tests/server/state/test_redis_backend.py @@ -2,14 +2,13 @@ from __future__ import annotations import json -from unittest.mock import AsyncMock, MagicMock, patch - import pytest +from unittest.mock import AsyncMock, MagicMock, patch # Skip entire module if redis package not available -redis = pytest.importorskip("redis") +redis = pytest.importorskip('redis') -from twinkle.server.state.backend.redis_backend import RedisBackend +from twinkle.server.state.backend.redis_backend import RedisBackend # noqa: E402 @pytest.fixture @@ -29,142 +28,150 @@ def mock_redis_client(): @pytest.fixture def backend_no_prefix(mock_redis_client): """RedisBackend with no key prefix.""" - with patch("redis.asyncio.from_url", return_value=mock_redis_client): - backend = RedisBackend("redis://localhost:6379") + with patch('redis.asyncio.from_url', return_value=mock_redis_client): + backend = RedisBackend('redis://localhost:6379') return backend @pytest.fixture def backend_with_prefix(mock_redis_client): """RedisBackend with key prefix.""" - with patch("redis.asyncio.from_url", return_value=mock_redis_client): - backend = RedisBackend("redis://localhost:6379", key_prefix="twinkle:") + with patch('redis.asyncio.from_url', return_value=mock_redis_client): + backend = RedisBackend('redis://localhost:6379', key_prefix='twinkle:') return backend # ---- SET ---- + @pytest.mark.asyncio async def test_set_without_ttl(backend_no_prefix, mock_redis_client): - await backend_no_prefix.set("mykey", {"data": 123}) - mock_redis_client.set.assert_called_once_with("mykey", json.dumps({"data": 123})) + await backend_no_prefix.set('mykey', {'data': 123}) + mock_redis_client.set.assert_called_once_with('mykey', json.dumps({'data': 123})) @pytest.mark.asyncio async def test_set_with_ttl(backend_no_prefix, mock_redis_client): - await backend_no_prefix.set("mykey", "value", ttl=60) - mock_redis_client.set.assert_called_once_with("mykey", json.dumps("value"), ex=60) + await backend_no_prefix.set('mykey', 'value', ttl=60) + mock_redis_client.set.assert_called_once_with('mykey', json.dumps('value'), ex=60) @pytest.mark.asyncio async def test_set_with_prefix(backend_with_prefix, mock_redis_client): - await backend_with_prefix.set("mykey", "val") - mock_redis_client.set.assert_called_once_with("twinkle:mykey", json.dumps("val")) + await backend_with_prefix.set('mykey', 'val') + mock_redis_client.set.assert_called_once_with('twinkle:mykey', json.dumps('val')) # ---- GET ---- + @pytest.mark.asyncio async def test_get_existing_key(backend_no_prefix, mock_redis_client): - mock_redis_client.get.return_value = json.dumps({"hello": "world"}) - result = await backend_no_prefix.get("mykey") - mock_redis_client.get.assert_called_once_with("mykey") - assert result == {"hello": "world"} + mock_redis_client.get.return_value = json.dumps({'hello': 'world'}) + result = await backend_no_prefix.get('mykey') + mock_redis_client.get.assert_called_once_with('mykey') + assert result == {'hello': 'world'} @pytest.mark.asyncio async def test_get_nonexistent_key(backend_no_prefix, mock_redis_client): mock_redis_client.get.return_value = None - result = await backend_no_prefix.get("missing") + result = await backend_no_prefix.get('missing') assert result is None @pytest.mark.asyncio async def test_get_with_prefix(backend_with_prefix, mock_redis_client): - mock_redis_client.get.return_value = json.dumps("data") - result = await backend_with_prefix.get("mykey") - mock_redis_client.get.assert_called_once_with("twinkle:mykey") - assert result == "data" + mock_redis_client.get.return_value = json.dumps('data') + result = await backend_with_prefix.get('mykey') + mock_redis_client.get.assert_called_once_with('twinkle:mykey') + assert result == 'data' # ---- DELETE ---- + @pytest.mark.asyncio async def test_delete(backend_no_prefix, mock_redis_client): - await backend_no_prefix.delete("mykey") - mock_redis_client.delete.assert_called_once_with("mykey") + await backend_no_prefix.delete('mykey') + mock_redis_client.delete.assert_called_once_with('mykey') @pytest.mark.asyncio async def test_delete_with_prefix(backend_with_prefix, mock_redis_client): - await backend_with_prefix.delete("mykey") - mock_redis_client.delete.assert_called_once_with("twinkle:mykey") + await backend_with_prefix.delete('mykey') + mock_redis_client.delete.assert_called_once_with('twinkle:mykey') # ---- EXISTS ---- + @pytest.mark.asyncio async def test_exists_true(backend_no_prefix, mock_redis_client): mock_redis_client.exists.return_value = 1 - result = await backend_no_prefix.exists("mykey") + result = await backend_no_prefix.exists('mykey') assert result is True - mock_redis_client.exists.assert_called_once_with("mykey") + mock_redis_client.exists.assert_called_once_with('mykey') @pytest.mark.asyncio async def test_exists_false(backend_no_prefix, mock_redis_client): mock_redis_client.exists.return_value = 0 - result = await backend_no_prefix.exists("mykey") + result = await backend_no_prefix.exists('mykey') assert result is False # ---- KEYS ---- + @pytest.mark.asyncio async def test_keys_pattern(backend_no_prefix, mock_redis_client): - mock_redis_client.keys.return_value = ["session::a", "session::b"] - result = await backend_no_prefix.keys("session::*") - mock_redis_client.keys.assert_called_once_with("session::*") - assert result == ["session::a", "session::b"] + mock_redis_client.keys.return_value = ['session::a', 'session::b'] + result = await backend_no_prefix.keys('session::*') + mock_redis_client.keys.assert_called_once_with('session::*') + assert result == ['session::a', 'session::b'] @pytest.mark.asyncio async def test_keys_with_prefix(backend_with_prefix, mock_redis_client): - mock_redis_client.keys.return_value = ["twinkle:session::a", "twinkle:session::b"] - result = await backend_with_prefix.keys("session::*") - mock_redis_client.keys.assert_called_once_with("twinkle:session::*") + mock_redis_client.keys.return_value = ['twinkle:session::a', 'twinkle:session::b'] + result = await backend_with_prefix.keys('session::*') + mock_redis_client.keys.assert_called_once_with('twinkle:session::*') # Result should have prefix stripped - assert result == ["session::a", "session::b"] + assert result == ['session::a', 'session::b'] # ---- COUNT ---- + @pytest.mark.asyncio async def test_count(backend_no_prefix, mock_redis_client): - mock_redis_client.keys.return_value = ["a", "b", "c"] - result = await backend_no_prefix.count("*") + mock_redis_client.keys.return_value = ['a', 'b', 'c'] + result = await backend_no_prefix.count('*') assert result == 3 # ---- SET_NX ---- + @pytest.mark.asyncio async def test_set_nx_success(backend_no_prefix, mock_redis_client): mock_redis_client.set.return_value = True # nx succeeded - result = await backend_no_prefix.set_nx("newkey", {"value": 1}) - mock_redis_client.set.assert_called_once_with("newkey", json.dumps({"value": 1}), nx=True) + result = await backend_no_prefix.set_nx('newkey', {'value': 1}) + mock_redis_client.set.assert_called_once_with('newkey', json.dumps({'value': 1}), nx=True) assert result is True @pytest.mark.asyncio async def test_set_nx_failure(backend_no_prefix, mock_redis_client): mock_redis_client.set.return_value = None # nx failed (key exists) - result = await backend_no_prefix.set_nx("existing", "val") + result = await backend_no_prefix.set_nx('existing', 'val') assert result is False # ---- HEALTH CHECK ---- + @pytest.mark.asyncio async def test_health_check_healthy(backend_no_prefix, mock_redis_client): mock_redis_client.ping.return_value = True @@ -174,13 +181,14 @@ async def test_health_check_healthy(backend_no_prefix, mock_redis_client): @pytest.mark.asyncio async def test_health_check_unhealthy(backend_no_prefix, mock_redis_client): - mock_redis_client.ping.side_effect = ConnectionError("offline") + mock_redis_client.ping.side_effect = ConnectionError('offline') result = await backend_no_prefix.health_check() assert result is False # ---- CLOSE ---- + @pytest.mark.asyncio async def test_close(backend_no_prefix, mock_redis_client): await backend_no_prefix.close() @@ -189,10 +197,11 @@ async def test_close(backend_no_prefix, mock_redis_client): # ---- JSON Serialization ---- + @pytest.mark.asyncio async def test_json_serialization_complex_types(backend_no_prefix, mock_redis_client): """Values should be JSON-serialized before storage.""" - complex_value = {"list": [1, 2, 3], "nested": {"key": "val"}, "null": None} - await backend_no_prefix.set("complex", complex_value) + complex_value = {'list': [1, 2, 3], 'nested': {'key': 'val'}, 'null': None} + await backend_no_prefix.set('complex', complex_value) expected_json = json.dumps(complex_value) - mock_redis_client.set.assert_called_once_with("complex", expected_json) + mock_redis_client.set.assert_called_once_with('complex', expected_json) diff --git a/tests/server/state/test_redis_integration.py b/tests/server/state/test_redis_integration.py index cae9ecc99..11f2def2f 100644 --- a/tests/server/state/test_redis_integration.py +++ b/tests/server/state/test_redis_integration.py @@ -14,19 +14,18 @@ import asyncio import os -import uuid - import pytest +import uuid from twinkle.server.state import ServerState from twinkle.server.state.backend.factory import PersistenceConfig, create_backend from twinkle.server.state.backend.redis_backend import RedisBackend - REDIS_URL = os.environ.get('TWINKLE_TEST_REDIS_URL', 'redis://localhost:6379/0') def _can_reach_redis() -> bool: + async def _check() -> bool: backend = RedisBackend(REDIS_URL) try: @@ -81,9 +80,7 @@ def make_state(isolation_prefix: str): created: list[ServerState] = [] def _make() -> ServerState: - backend = create_backend( - PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=isolation_prefix) - ) + backend = create_backend(PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=isolation_prefix)) state = ServerState(backend=backend) created.append(state) return state @@ -123,9 +120,7 @@ async def test_property_26_model_write_visible(make_state) -> None: b = make_state() rid = f'r-{uuid.uuid4().hex[:6]}' await a.register_replica(rid, max_loras=2) - mid = await a.register_model( - {'base_model': 'mock'}, token='tok-A', model_id='mid-A', replica_id=rid - ) + mid = await a.register_model({'base_model': 'mock'}, token='tok-A', model_id='mid-A', replica_id=rid) meta = await b.get_model_metadata(mid) assert meta is not None @@ -159,7 +154,7 @@ async def test_property_27_concurrent_config_writes_no_torn_records(make_state) async def writer(state: ServerState, items: dict) -> None: await asyncio.gather(*(state.add_config(k, v) for k, v in items.items())) - half = list(payload.items())[: n // 2] + half = list(payload.items())[:n // 2] other = list(payload.items())[n // 2:] await asyncio.gather(writer(a, dict(half)), writer(b, dict(other))) diff --git a/tests/server/telemetry/test_context_carrier.py b/tests/server/telemetry/test_context_carrier.py index c02716141..36b35b67a 100644 --- a/tests/server/telemetry/test_context_carrier.py +++ b/tests/server/telemetry/test_context_carrier.py @@ -5,9 +5,8 @@ """ from __future__ import annotations -from unittest import mock - import pytest +from unittest import mock from twinkle.server.telemetry import context_carrier from twinkle.server.telemetry.context_carrier import activate_carrier, make_carrier diff --git a/tests/server/telemetry/test_tracing_and_correlation.py b/tests/server/telemetry/test_tracing_and_correlation.py index 6d7b9ccff..636f8f87f 100644 --- a/tests/server/telemetry/test_tracing_and_correlation.py +++ b/tests/server/telemetry/test_tracing_and_correlation.py @@ -11,20 +11,15 @@ """ from __future__ import annotations -from unittest import mock - import pytest -from hypothesis import given, settings, strategies as st +from hypothesis import given, settings +from hypothesis import strategies as st +from unittest import mock from twinkle.server.telemetry import correlation -from twinkle.server.telemetry.correlation import ( - CORRELATION_KEYS, - PREFIX, - set_correlation_attrs, -) +from twinkle.server.telemetry.correlation import CORRELATION_KEYS, PREFIX, set_correlation_attrs from twinkle.server.telemetry.tracing import _NoopSpan, traced_operation - # ---------- Property 23: prefix invariant (R11.3) ------------------------- # @@ -49,6 +44,7 @@ def test_property_23_helper_constants_complete() -> None: class _RecordingSpan: + def __init__(self) -> None: self.attrs: dict[str, object] = {} @@ -66,8 +62,7 @@ def set_attribute(self, key: str, value: object) -> None: correlation.REPLICA_ID: st.one_of(st.none(), st.text(min_size=1, max_size=8)), correlation.TOKEN_ID: st.one_of(st.none(), st.text(min_size=1, max_size=8)), }, - ) -) + )) def test_property_22_set_correlation_attrs_skips_none(payload: dict) -> None: span = _RecordingSpan() set_correlation_attrs(span, payload) @@ -215,7 +210,8 @@ def test_init_telemetry_attaches_handler_to_twinkle_logger() -> None: the entire server's logs would be invisible in Loki / OTLP backends. """ import logging - from opentelemetry import _logs as _otel_logs, metrics, trace + from opentelemetry import _logs as _otel_logs + from opentelemetry import metrics, trace from opentelemetry.sdk._logs import LoggingHandler from opentelemetry.util._once import Once @@ -238,29 +234,22 @@ def test_init_telemetry_attaches_handler_to_twinkle_logger() -> None: logging.getLogger(name).removeHandler(h) try: - provider.init_telemetry(provider.TelemetryConfig( - enabled=True, debug=True, # debug=True → console exporter, no real OTLP needed - service_name='twinkle-server-test', - )) - root_handlers = [ - h for h in logging.getLogger().handlers if isinstance(h, LoggingHandler) - ] - twinkle_handlers = [ - h for h in logging.getLogger('twinkle').handlers if isinstance(h, LoggingHandler) - ] + provider.init_telemetry( + provider.TelemetryConfig( + enabled=True, + debug=True, # debug=True → console exporter, no real OTLP needed + service_name='twinkle-server-test', + )) + root_handlers = [h for h in logging.getLogger().handlers if isinstance(h, LoggingHandler)] + twinkle_handlers = [h for h in logging.getLogger('twinkle').handlers if isinstance(h, LoggingHandler)] assert len(root_handlers) == 1, root_handlers assert len(twinkle_handlers) == 1, twinkle_handlers - assert root_handlers[0] is twinkle_handlers[0], ( - 'root and twinkle should share the same handler instance' - ) + assert root_handlers[0] is twinkle_handlers[0], ('root and twinkle should share the same handler instance') finally: provider.shutdown_telemetry() # shutdown should detach from both - assert all( - not isinstance(h, LoggingHandler) - for name in ('', 'twinkle') - for h in logging.getLogger(name).handlers - ) + assert all(not isinstance(h, LoggingHandler) for name in ('', 'twinkle') + for h in logging.getLogger(name).handlers) def test_pyproject_declares_telemetry_extras() -> None: @@ -281,24 +270,18 @@ def test_grafana_dashboard_includes_resource_panels() -> None: repo_root = Path(__file__).resolve().parents[3] dashboard = json.loads( - (repo_root / 'cookbook' / 'observability' / 'grafana' / 'dashboards' - / 'twinkle-overview.json').read_text() - ) + (repo_root / 'cookbook' / 'observability' / 'grafana' / 'dashboards' / 'twinkle-overview.json').read_text()) titles = ' | '.join(p['title'].lower() for p in dashboard['panels']) for required in ('cpu', 'memory', 'gpu utilization', 'gpu memory'): assert required in titles, f'dashboard missing panel containing {required!r}' # Each resource gauge name must be referenced by at least one panel target. - targets = ' | '.join( - t.get('expr', '') - for p in dashboard['panels'] - for t in p.get('targets', []) - ) + targets = ' | '.join(t.get('expr', '') for p in dashboard['panels'] for t in p.get('targets', [])) for metric in ( - 'twinkle_system_cpu_utilization', - 'twinkle_system_memory_usage_bytes', - 'twinkle_process_memory_usage_bytes', - 'twinkle_gpu_utilization', - 'twinkle_gpu_memory_usage_bytes', + 'twinkle_system_cpu_utilization', + 'twinkle_system_memory_usage_bytes', + 'twinkle_process_memory_usage_bytes', + 'twinkle_gpu_utilization', + 'twinkle_gpu_memory_usage_bytes', ): assert metric in targets, f'dashboard does not query metric {metric!r}' diff --git a/tests/server/utils/task_queue/test_config.py b/tests/server/utils/task_queue/test_config.py index c76096a1a..305210caf 100644 --- a/tests/server/utils/task_queue/test_config.py +++ b/tests/server/utils/task_queue/test_config.py @@ -11,12 +11,12 @@ from __future__ import annotations import pytest -from hypothesis import given, settings, strategies as st +from hypothesis import given, settings +from hypothesis import strategies as st from pydantic import ValidationError from twinkle.server.utils.task_queue.config import TaskQueueConfig - # ---------- defaults snapshot used by Property 18 (R9.8) ------------------- # DEFAULTS = { @@ -28,10 +28,8 @@ 'max_input_tokens': 16000, } - # ---------- Property 16: constraint enforcement (R9.2-9.5, 9.7) ------------ # - _CONSTRAINED_GE0_FLOATS = ['rps_limit', 'tps_limit', 'queue_timeout', 'token_cleanup_interval'] @@ -74,9 +72,8 @@ def test_property_16_max_input_tokens_rejects_lt_1(bad_value: int) -> None: cleanup=st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), mit=st.integers(min_value=1, max_value=10_000_000), ) -def test_property_16_valid_values_accepted( - rps: float, tps: float, win: float, qt: float, cleanup: float, mit: int -) -> None: +def test_property_16_valid_values_accepted(rps: float, tps: float, win: float, qt: float, cleanup: float, + mit: int) -> None: """Any value satisfying the constraints constructs successfully.""" cfg = TaskQueueConfig( rps_limit=rps, @@ -93,7 +90,6 @@ def test_property_16_valid_values_accepted( # ---------- Property 17: from_dict equivalence (R9.6) ---------------------- # - _INPUT_DICT_STRATEGY = st.fixed_dictionaries( {}, optional={ @@ -101,14 +97,11 @@ def test_property_16_valid_values_accepted( 'tps_limit': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), 'window_seconds': st.floats(min_value=1e-6, max_value=1e6, allow_nan=False, allow_infinity=False), 'queue_timeout': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), - 'token_cleanup_interval': st.floats( - min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'token_cleanup_interval': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), 'max_input_tokens': st.integers(min_value=1, max_value=10_000_000), 'enabled': st.booleans(), - 'execution_timeout': st.floats( - min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), - 'token_cleanup_multiplier': st.floats( - min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'execution_timeout': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'token_cleanup_multiplier': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), }, ) From 4b59671f39adf0c5dd86dbcd353927e2c55802ad Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 13:02:10 +0800 Subject: [PATCH 25/34] chore: gitignore .kiro/ (local spec/planning notes) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 8cfd041ff..daf3fa743 100644 --- a/.gitignore +++ b/.gitignore @@ -146,6 +146,7 @@ images /custom/ megatron_output/ .qoder +.kiro/ # Pytorch *.pth From cc291390ab0de950b51e4520b0cab5a9b96ee585 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 13:19:52 +0800 Subject: [PATCH 26/34] style: convert double-quoted f-strings to single quotes (CI Python 3.11) The pre-commit-hooks v6.0.0 ``double-quote-string-fixer`` skips ``FSTRING_*`` tokens on Python 3.12+ but on the CI runner's Python 3.11 the f-strings are emitted as a single ``STRING`` token and get rewritten. Manually converted the 4 affected files so the hook is a no-op on either interpreter: - src/twinkle/server/state/base.py - src/twinkle/server/state/backend/redis_backend.py - src/twinkle/server/state/config_signature.py - tests/server/state/test_managers.py Verified: pre-commit run --all-files passes under a fresh Python 3.11.15 env (CI's runner version), and 244 unit + property + contract tests still pass under the twinkle env (Python 3.12). --- src/twinkle/server/state/backend/redis_backend.py | 2 +- src/twinkle/server/state/base.py | 10 +++++----- src/twinkle/server/state/config_signature.py | 10 +++++----- tests/server/state/test_managers.py | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/twinkle/server/state/backend/redis_backend.py b/src/twinkle/server/state/backend/redis_backend.py index b9b639c23..7f431924a 100644 --- a/src/twinkle/server/state/backend/redis_backend.py +++ b/src/twinkle/server/state/backend/redis_backend.py @@ -28,7 +28,7 @@ def __init__(self, redis_url: str, key_prefix: str = '') -> None: def _make_key(self, key: str) -> str: """Add namespace prefix to key.""" - return f"{self._prefix}{key}" if self._prefix else key + return f'{self._prefix}{key}' if self._prefix else key def _strip_prefix(self, key: str) -> str: """Remove namespace prefix from full key.""" diff --git a/src/twinkle/server/state/base.py b/src/twinkle/server/state/base.py index b7c6a85ab..97c2bf968 100644 --- a/src/twinkle/server/state/base.py +++ b/src/twinkle/server/state/base.py @@ -23,12 +23,12 @@ class BaseManager(ABC, Generic[T]): def __init__(self, backend: StateBackend, key_prefix: str, record_type: type[T], expiration_timeout: float): self._backend = backend - self._prefix = key_prefix # e.g. "session::", "model::", "sampling::", "future::" + self._prefix = key_prefix # e.g. 'session::', 'model::', 'sampling::', 'future::' self._record_type = record_type self.expiration_timeout = expiration_timeout def _make_key(self, resource_id: str) -> str: - return f"{self._prefix}{resource_id}" + return f'{self._prefix}{resource_id}' def _strip_prefix(self, key: str) -> str: return key[len(self._prefix):] @@ -56,16 +56,16 @@ async def remove(self, resource_id: str) -> bool: async def count(self) -> int: """Count all records managed by this manager.""" - return await self._backend.count(f"{self._prefix}*") + return await self._backend.count(f'{self._prefix}*') async def keys(self) -> list[str]: """Get all resource IDs (without prefix).""" - raw_keys = await self._backend.keys(f"{self._prefix}*") + raw_keys = await self._backend.keys(f'{self._prefix}*') return [self._strip_prefix(k) for k in raw_keys] async def get_all(self) -> dict[str, T]: """Load all records from backend. Used for index rebuilding.""" - all_keys = await self._backend.keys(f"{self._prefix}*") + all_keys = await self._backend.keys(f'{self._prefix}*') result = {} for key in all_keys: data = await self._backend.get(key) diff --git a/src/twinkle/server/state/config_signature.py b/src/twinkle/server/state/config_signature.py index 9fa3a6296..43bba4f60 100644 --- a/src/twinkle/server/state/config_signature.py +++ b/src/twinkle/server/state/config_signature.py @@ -71,9 +71,9 @@ async def validate_config_signature( return True # Mismatch detected - logger.warning(f"Config signature mismatch! " - f"Stored: {stored_sig[:12]}..., Current: {current_sig[:12]}... " - f"Policy: {policy.value}") + logger.warning(f'Config signature mismatch! ' + f'Stored: {stored_sig[:12]}..., Current: {current_sig[:12]}... ' + f'Policy: {policy.value}') if policy == SignatureMismatchPolicy.WARN: # Update to new signature and continue @@ -91,8 +91,8 @@ async def validate_config_signature( return False elif policy == SignatureMismatchPolicy.ABORT: - raise ConfigMismatchError(f"Configuration signature mismatch. " - f"Stored: {stored_sig[:12]}..., Current: {current_sig[:12]}... " + raise ConfigMismatchError(f'Configuration signature mismatch. ' + f'Stored: {stored_sig[:12]}..., Current: {current_sig[:12]}... ' f"Use policy='warn' or 'clear' to allow startup with changed config.") return False diff --git a/tests/server/state/test_managers.py b/tests/server/state/test_managers.py index d7fabd379..ec51dd324 100644 --- a/tests/server/state/test_managers.py +++ b/tests/server/state/test_managers.py @@ -147,7 +147,7 @@ async def test_remove(self, manager): async def test_token_limit_enforced(self, manager): """Adding more models than per_token_model_limit should raise RuntimeError.""" for i in range(3): - await manager.add(f"m{i}", ModelRecord(token='tok1')) + await manager.add(f'm{i}', ModelRecord(token='tok1')) with pytest.raises(RuntimeError, match='Model limit exceeded'): await manager.add('m3', ModelRecord(token='tok1')) @@ -156,7 +156,7 @@ async def test_token_limit_enforced(self, manager): async def test_token_limit_per_token(self, manager): """Limit is per-token, different tokens have separate limits.""" for i in range(3): - await manager.add(f"a{i}", ModelRecord(token='tokenA')) + await manager.add(f'a{i}', ModelRecord(token='tokenA')) # Different token should work await manager.add('b0', ModelRecord(token='tokenB')) assert await manager.get('b0') is not None From c80781468f98645c1d62b3078545d2043b31d6a7 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 15:29:45 +0800 Subject: [PATCH 27/34] fix(server): address code-review gaps from server-config-observability-refactor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - worker.py: emit `task_queue.execute` (R10.2) + nested `.` (R10.3) spans so the queued handler op is observable, not just state-level ops. - sampler/processor non-queued handlers: wrap primary ops in `traced_operation` so set_template / add_adapter / apply_patch / processor.create / processor.call also satisfy R10.3. - config_signature: persist `_meta::config_payload` on first run so drift diff renders real stored-vs-current field differences (R15.3) instead of always showing the current config as if it were entirely new. - mock model / sampler: replace Python's salted `hash(tuple-of-strings)` with SHA-256 over a canonical string form so deterministic outputs (R2.5/R4.4/R4.5) hold across processes — built-in hash is PYTHONHASHSEED-salted and would diverge across replicas / restarts. - context_carrier: document that the current topology routes every cross- deployment hop through the Gateway HTTP proxy (already trace-propagating), so there are no in-process DeploymentHandle call sites to thread the carrier through today; the helpers remain the supported integration point for any future handle-based hop. - launcher: wire ServerConfig.proxy_location into `serve.start(...)` (example configs already declare it) and make ApplicationSpec a real top-level import so `get_type_hints(_deploy_application)` resolves at runtime. --- src/twinkle/server/launcher.py | 12 +++++-- .../server/model/backends/mock_model.py | 17 +++++++--- .../server/processor/twinkle_handlers.py | 17 ++++++++-- .../server/sampler/backends/mock_sampler.py | 20 +++++++++--- .../server/sampler/twinkle_handlers.py | 11 +++++-- src/twinkle/server/state/config_signature.py | 6 +++- .../server/telemetry/context_carrier.py | 11 +++++++ src/twinkle/server/utils/task_queue/worker.py | 32 +++++++++++++++---- 8 files changed, 102 insertions(+), 24 deletions(-) diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 23d563865..41662c5fb 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -29,6 +29,7 @@ from twinkle import get_logger from twinkle.server.config import ServerConfig +from twinkle.server.config.application_spec import ApplicationSpec from twinkle.server.utils.ray_serve_patch import apply_ray_serve_patches, get_runtime_env_for_patches logger = get_logger() @@ -205,8 +206,15 @@ def _start_serve(self) -> None: pass http_options = self.config.http_options.model_dump() - serve.start(http_options=http_options) - logger.info(f'Ray Serve started with http_options={http_options}') + serve_kwargs: dict[str, Any] = {'http_options': http_options} + # ``proxy_location`` controls where the Ray Serve HTTP proxy runs + # (``EveryNode`` / ``HeadOnly`` / ``Disabled``). The example configs + # set this field, so honour it here instead of silently ignoring. + if self.config.proxy_location: + serve_kwargs['proxy_location'] = self.config.proxy_location + serve.start(**serve_kwargs) + logger.info(f'Ray Serve started with http_options={http_options}, ' + f'proxy_location={self.config.proxy_location!r}') self._serve_started = True diff --git a/src/twinkle/server/model/backends/mock_model.py b/src/twinkle/server/model/backends/mock_model.py index 358dec8b1..7215350b8 100644 --- a/src/twinkle/server/model/backends/mock_model.py +++ b/src/twinkle/server/model/backends/mock_model.py @@ -17,15 +17,24 @@ """ from __future__ import annotations +import hashlib import numpy as np from typing import Any def _seed_for(model_id: str, adapter_name: str | None, seed: int, *extra: Any) -> int: - """Deterministic per-request RNG seed derived from string/int components.""" - h = hash((str(model_id), str(adapter_name), int(seed), tuple(map(repr, extra)))) - # numpy seeds must fit in uint32. - return h & 0xFFFFFFFF + """Deterministic per-request RNG seed derived from string/int components. + + Uses SHA-256 over a canonical string form rather than Python's built-in + ``hash()``: the latter is salted per process (PYTHONHASHSEED) for tuples + containing strings, which would make identical requests on different + replicas / restarts produce different outputs and break R2.5 / R4.4 / R4.5 + across processes. + """ + parts = (str(model_id), str(adapter_name), str(int(seed)), *(repr(x) for x in extra)) + digest = hashlib.sha256('\x1f'.join(parts).encode('utf-8')).digest() + # numpy seeds must fit in uint32; take the first 4 bytes of the digest. + return int.from_bytes(digest[:4], 'big') class TwinkleCompatMockModel: diff --git a/src/twinkle/server/processor/twinkle_handlers.py b/src/twinkle/server/processor/twinkle_handlers.py index 66799ee82..7e6a1f7e3 100644 --- a/src/twinkle/server/processor/twinkle_handlers.py +++ b/src/twinkle/server/processor/twinkle_handlers.py @@ -18,6 +18,8 @@ from .app import ProcessorManagement import twinkle_client.types as types +from twinkle.server.telemetry.correlation import SESSION_ID, TOKEN_ID +from twinkle.server.telemetry.tracing import traced_operation from twinkle.server.utils.validation import get_session_id_from_request, get_token_from_request from twinkle.utils.logger import get_logger from twinkle_client.common.serialize import deserialize_object @@ -77,7 +79,14 @@ def _do_create(): return getattr(processor_module, class_type)( remote_group=_remote_group, device_mesh=_device_mesh, instance_id=processor_id, **resolved_kwargs) - processor = await asyncio.get_running_loop().run_in_executor(None, _do_create) + # R10.3: span the primary processor.create op with token + session correlation. + with traced_operation( + f'processor.create.{processor_type_name}.{class_type}', + attrs={ + TOKEN_ID: token, + SESSION_ID: session_id, + }): + processor = await asyncio.get_running_loop().run_in_executor(None, _do_create) self.resource_dict[processor_id] = processor return types.ProcessorCreateResponse(processor_id='pid:' + processor_id) @@ -117,7 +126,11 @@ def _do_call(): except StopIteration: return True, None - is_exhausted, result = await asyncio.get_running_loop().run_in_executor(None, _do_call) + # R10.3: span the primary processor.call op so each invocation is observable. + with traced_operation( + f'processor.call.{function_name}', + attrs={TOKEN_ID: get_token_from_request(request)}): + is_exhausted, result = await asyncio.get_running_loop().run_in_executor(None, _do_call) if function_name == '__next__': if is_exhausted: diff --git a/src/twinkle/server/sampler/backends/mock_sampler.py b/src/twinkle/server/sampler/backends/mock_sampler.py index 4565929c4..330c677b9 100644 --- a/src/twinkle/server/sampler/backends/mock_sampler.py +++ b/src/twinkle/server/sampler/backends/mock_sampler.py @@ -13,6 +13,7 @@ """ from __future__ import annotations +import hashlib import numpy as np from typing import Any, List, Optional @@ -20,6 +21,19 @@ from twinkle.data_format import SampledSequence, SampleResponse, SamplingParams +def _stable_seed(*parts: Any) -> int: + """Cross-process-stable numpy seed (uint32) derived from string parts. + + Python's built-in ``hash()`` of a tuple containing strings is salted per + process (PYTHONHASHSEED), which would make identical sample requests on + different replicas / restarts produce different outputs and violate R2.5. + Use a stable digest instead. + """ + canonical = '\x1f'.join(str(p) for p in parts).encode('utf-8') + digest = hashlib.sha256(canonical).digest() + return int.from_bytes(digest[:4], 'big') + + class MockSampler: """Deterministic numpy-only sampler. @@ -56,11 +70,7 @@ def sample( for prompt_idx, _ in enumerate(normalized): sequences: list[SampledSequence] = [] for sample_idx in range(num_samples): - seed = ( - abs( - hash( - (str(self.model_id), str(adapter_name), int(self._seed), int(prompt_idx), int(sample_idx)))) - & 0xFFFFFFFF) + seed = _stable_seed(self.model_id, adapter_name, self._seed, prompt_idx, sample_idx) rng = np.random.default_rng(seed) tokens = [int(t) for t in rng.integers(low=0, high=max(1, self._vocab_size), size=max_tokens)] logprobs_per_token = rng.uniform(-2.0, 0.0, size=max_tokens).astype(float).tolist() diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index b27ec23d0..2d4dc86e5 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -19,6 +19,8 @@ import twinkle_client.types as types from twinkle.data_format import InputFeature, SamplingParams, Trajectory +from twinkle.server.telemetry.correlation import MODEL_ID, TOKEN_ID +from twinkle.server.telemetry.tracing import traced_operation from twinkle.utils.logger import get_logger logger = get_logger() @@ -165,7 +167,8 @@ async def set_template( ) -> types.SetTemplateResponse: """Set the chat template for encoding Trajectory inputs.""" extra_kwargs = body.model_extra or {} - self.sampler.set_template(body.template_cls, **extra_kwargs) + with traced_operation('sampler.set_template'): + self.sampler.set_template(body.template_cls, **extra_kwargs) return types.SetTemplateResponse() @app.post('/twinkle/add_adapter_to_sampler', response_model=types.AddAdapterResponse) @@ -181,7 +184,8 @@ async def add_adapter_to_sampler( from peft import LoraConfig config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config - self.sampler.add_adapter_to_sampler(full_adapter_name, config) + with traced_operation('sampler.add_adapter_to_sampler', attrs={MODEL_ID: full_adapter_name}): + self.sampler.add_adapter_to_sampler(full_adapter_name, config) return types.AddAdapterResponse(adapter_name=full_adapter_name) @@ -193,4 +197,5 @@ async def apply_patch( ) -> None: extra_kwargs = body.model_extra or {} patch_cls = deserialize_object(body.patch_cls) - self.sampler.apply_patch(patch_cls, **extra_kwargs) + with traced_operation('sampler.apply_patch'): + self.sampler.apply_patch(patch_cls, **extra_kwargs) diff --git a/src/twinkle/server/state/config_signature.py b/src/twinkle/server/state/config_signature.py index 43bba4f60..f38fa3cb5 100644 --- a/src/twinkle/server/state/config_signature.py +++ b/src/twinkle/server/state/config_signature.py @@ -13,6 +13,7 @@ logger = logging.getLogger(__name__) _SIGNATURE_KEY = '_meta::config_signature' +_PAYLOAD_KEY = '_meta::config_payload' class SignatureMismatchPolicy(str, Enum): @@ -136,13 +137,16 @@ async def validate_against_backend(persistence_config: Any, current_config: dict if stored_sig is None: await backend.set(_SIGNATURE_KEY, current_sig) + # Persist the payload alongside the signature so a future drift diff can + # render real ``stored vs current`` field-level differences (R15.3). + await backend.set(_PAYLOAD_KEY, current_config) logger.info('No previous config signature found. Stored current signature.') return if stored_sig == current_sig: return - stored_payload = await backend.get('_meta::config_payload') + stored_payload = await backend.get(_PAYLOAD_KEY) diff = _format_diff(stored_payload if isinstance(stored_payload, dict) else None, current_config) raise ConfigMismatchError('Persistence configuration drifted since the last launch. ' f'Stored signature: {stored_sig[:12]}..., current signature: {current_sig[:12]}...\n' diff --git a/src/twinkle/server/telemetry/context_carrier.py b/src/twinkle/server/telemetry/context_carrier.py index 13a0d5b78..1cb3f7714 100644 --- a/src/twinkle/server/telemetry/context_carrier.py +++ b/src/twinkle/server/telemetry/context_carrier.py @@ -12,6 +12,17 @@ When the OTEL SDK is missing, both helpers degrade to a NoOp: ``make_carrier`` returns an empty dict and ``activate_carrier`` is a no-op context manager that runs the body without raising (R13.4 / R18.3). + +Note on current wiring (R13.3): the present server topology routes every +cross-deployment hop through the Gateway's localhost HTTP proxy, which already +carries the trace context in HTTP headers via :func:`twinkle.server.telemetry. +tracing.inject_context`. There are therefore no in-process ``DeploymentHandle`` +call sites in ``src/`` today to thread the carrier through. These helpers are +the supported integration point for any future deployment-to-deployment handle +calls (e.g. Model → Sampler over a non-HTTP path); add ``trace_context: dict | +None = None`` to the handle signature, build it on the caller with +:func:`make_carrier`, and wrap the receiver body in :func:`activate_carrier` +to preserve a single trace id across the hop. """ from __future__ import annotations diff --git a/src/twinkle/server/utils/task_queue/worker.py b/src/twinkle/server/utils/task_queue/worker.py index 5013065bc..81e8a82eb 100644 --- a/src/twinkle/server/utils/task_queue/worker.py +++ b/src/twinkle/server/utils/task_queue/worker.py @@ -14,6 +14,8 @@ from collections import deque from typing import TYPE_CHECKING, Any, Deque +from twinkle.server.telemetry.correlation import MODEL_ID, TOKEN_ID +from twinkle.server.telemetry.tracing import traced_operation from twinkle.utils.logger import get_logger from .config import TaskQueueConfig from .types import QueuedTask, QueueState, TaskStatus @@ -198,14 +200,30 @@ async def _execute_task(self, task: QueuedTask, queue_key: str, q: asyncio.Queue exec_start = time.monotonic() task_status = 'completed' exec_time = 0.0 + # R10.2: one span per queued task execution; R10.3: a nested span tagged + # `.` (e.g. `model.forward`, `sampler.sample`) so + # the handler's primary op has its own span carrying token/model + # correlation. The nested span is started inside the try/except so + # span lifecycle stays correctly scoped on timeout / exceptions. + queue_attrs = { + TOKEN_ID: task.token, + MODEL_ID: task.model_id, + 'twinkle.task_type': task_type, + 'twinkle.deployment': self._deployment_name or 'unknown', + 'twinkle.queue_key': queue_key, + } + handler_attrs = {TOKEN_ID: task.token, MODEL_ID: task.model_id} + handler_span_name = f'{self._deployment_name or "deployment"}.{task_type}' try: - coro = task.coro_factory() - logger.debug(f'[ComputeWorker] Task {task.request_id} executing, ' - f'type={task_type}, queue_key={queue_key}') - if self._config.execution_timeout > 0: - result = await asyncio.wait_for(coro, timeout=self._config.execution_timeout) - else: - result = await coro + with traced_operation('task_queue.execute', attrs=queue_attrs): + logger.debug(f'[ComputeWorker] Task {task.request_id} executing, ' + f'type={task_type}, queue_key={queue_key}') + with traced_operation(handler_span_name, attrs=handler_attrs): + coro = task.coro_factory() + if self._config.execution_timeout > 0: + result = await asyncio.wait_for(coro, timeout=self._config.execution_timeout) + else: + result = await coro exec_time = time.monotonic() - exec_start logger.info(f'[ComputeWorker] Task {task.request_id} completed in {exec_time:.2f}s, type={task_type}') await self._state.store_future_status( From 77113bd1dd543cc3d48f45775d34d5c0e9168861 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 15:45:04 +0800 Subject: [PATCH 28/34] refactor(server): drop ServerStateProxy alias, use ServerState directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The proxy class was removed in Phase 0d (de-Actor); only the `ServerStateProxy = ServerState` alias survived so existing type hints could keep working through the transition (R19.1). With every call site updated, the alias is now misleading — there is no proxy, just direct backend access. - Delete the alias and its retention comment in `state/server_state.py`. - Remove the re-export from `state/__init__.py`. - Rename all 7 call-site type hints (`router`, `lifecycle/base`, `task_queue/mixin`, `task_queue/worker`, `model/app`, `sampler/app`, `processor/app`) to `ServerState`. Pure rename — zero behavior change. The Client_Facing_API contract is unaffected (R20). --- src/twinkle/server/common/router.py | 4 ++-- src/twinkle/server/model/app.py | 4 ++-- src/twinkle/server/processor/app.py | 4 ++-- src/twinkle/server/sampler/app.py | 4 ++-- src/twinkle/server/state/__init__.py | 3 +-- src/twinkle/server/state/server_state.py | 14 +++++--------- src/twinkle/server/utils/lifecycle/base.py | 4 ++-- src/twinkle/server/utils/task_queue/mixin.py | 6 +++--- src/twinkle/server/utils/task_queue/worker.py | 4 ++-- 9 files changed, 21 insertions(+), 26 deletions(-) diff --git a/src/twinkle/server/common/router.py b/src/twinkle/server/common/router.py index 5ecf55d28..cc9b94915 100644 --- a/src/twinkle/server/common/router.py +++ b/src/twinkle/server/common/router.py @@ -4,7 +4,7 @@ RequestRouter, RunningReplica) from typing import Dict, List, Optional -from twinkle.server.state import ServerStateProxy, get_server_state +from twinkle.server.state import ServerState, get_server_state from twinkle.utils.logger import get_logger logger = get_logger() @@ -15,7 +15,7 @@ class StickyLoraRequestRouter(FIFOMixin, MultiplexMixin, RequestRouter): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.state: ServerStateProxy = get_server_state() + self.state: ServerState = get_server_state() async def choose_replicas( self, diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index b08967918..328db0794 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -16,7 +16,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh from twinkle.server.exceptions import ConfigError -from twinkle.server.state import ServerStateProxy, get_server_state +from twinkle.server.state import ServerState, get_server_state from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.lifecycle import AdapterManagerMixin from twinkle.server.utils.metrics import create_metrics_middleware @@ -115,7 +115,7 @@ def __init__(self, ) self.model = _dispatch_model_backend(backend, ctor_kwargs) - self.state: ServerStateProxy = get_server_state() + self.state: ServerState = get_server_state() self._replica_registered = False # Initialize mixins diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index d2a8c9b3f..a50cbb995 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -21,7 +21,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_logger -from twinkle.server.state import ServerStateProxy, get_server_state +from twinkle.server.state import ServerState, get_server_state from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.lifecycle import ProcessorManagerMixin from twinkle.server.utils.metrics import create_metrics_middleware @@ -64,7 +64,7 @@ def __init__(self, # processor objects keyed by processor_id self.resource_dict: dict[str, Any] = {} - self.state: ServerStateProxy = get_server_state() + self.state: ServerState = get_server_state() _cfg = processor_config or {} _env_limit = int(os.environ.get('TWINKLE_PER_USER_PROCESSOR_LIMIT', 20)) diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index 351cc186d..f9884c52b 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -15,7 +15,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh from twinkle.server.exceptions import ConfigError -from twinkle.server.state import ServerStateProxy, get_server_state +from twinkle.server.state import ServerState, get_server_state from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.metrics import create_metrics_middleware from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin @@ -119,7 +119,7 @@ def __init__(self, ) self.sampler = _dispatch_sampler_backend(sampler_type, sampler_kwargs) - self.state: ServerStateProxy = get_server_state() + self.state: ServerState = get_server_state() # Initialize task queue mixin self._init_task_queue(TaskQueueConfig.from_dict(queue_config), deployment_name='Sampler') diff --git a/src/twinkle/server/state/__init__.py b/src/twinkle/server/state/__init__.py index c6b85914c..04af91e8d 100644 --- a/src/twinkle/server/state/__init__.py +++ b/src/twinkle/server/state/__init__.py @@ -8,7 +8,7 @@ from .models import FutureRecord, ModelRecord, SamplingSessionRecord, SessionRecord from .replica_registry import ReplicaRegistry from .sampling_manager import SamplingSessionManager -from .server_state import ServerState, ServerStateProxy, get_server_state, reset_server_state_cache +from .server_state import ServerState, get_server_state, reset_server_state_cache from .session_manager import SessionManager __all__ = [ @@ -27,7 +27,6 @@ 'ConfigManager', # Server state 'ServerState', - 'ServerStateProxy', 'ReplicaRegistry', 'get_server_state', 'reset_server_state_cache', diff --git a/src/twinkle/server/state/server_state.py b/src/twinkle/server/state/server_state.py index 42b1333e3..4002e3754 100644 --- a/src/twinkle/server/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -472,15 +472,11 @@ async def get_cleanup_stats(self) -> dict[str, Any]: # Direct-backend factory (R19) # --------------------------------------------------------------------------- # -# ``ServerStateProxy`` is intentionally retained as a thin alias of -# ``ServerState`` so call-site type hints (e.g. ``state: ServerStateProxy``) -# keep working without import churn during this transition. The detached Ray -# Actor and the per-method ``self._actor.X.remote(...)`` forwarding are gone: -# every worker now binds a process-local :class:`ServerState` directly to the -# shared :class:`StateBackend`, and the existing ``await state.X(...)`` call -# sites awaited a coroutine before and still do. - -ServerStateProxy = ServerState # type: ignore[assignment] +# The detached Ray Actor and the per-method ``self._actor.X.remote(...)`` +# forwarding are gone: every worker now binds a process-local +# :class:`ServerState` directly to the shared :class:`StateBackend`, and the +# existing ``await state.X(...)`` call sites awaited a coroutine before and +# still do. _PROCESS_STATE_CACHE: dict[str, ServerState] = {} diff --git a/src/twinkle/server/utils/lifecycle/base.py b/src/twinkle/server/utils/lifecycle/base.py index 2d947bdd0..394cd2200 100644 --- a/src/twinkle/server/utils/lifecycle/base.py +++ b/src/twinkle/server/utils/lifecycle/base.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from twinkle.server.state import ServerStateProxy + from twinkle.server.state import ServerState from twinkle.utils.logger import get_logger @@ -38,7 +38,7 @@ class SessionResourceMixin: """ # Type hint for state attribute that inheriting classes must provide - state: ServerStateProxy + state: ServerState # Resource type name for logging (override in subclass) _resource_type: str = 'resource' diff --git a/src/twinkle/server/utils/task_queue/mixin.py b/src/twinkle/server/utils/task_queue/mixin.py index 5c8d77b92..1bec0843b 100644 --- a/src/twinkle/server/utils/task_queue/mixin.py +++ b/src/twinkle/server/utils/task_queue/mixin.py @@ -22,7 +22,7 @@ from .worker import ComputeWorker if TYPE_CHECKING: - from twinkle.server.state import ServerStateProxy + from twinkle.server.state import ServerState logger = get_logger() @@ -43,11 +43,11 @@ class TaskQueueMixin: Requirements ------------ - Inheriting class must expose self.state: ServerStateProxy and call + Inheriting class must expose self.state: ServerState and call _init_task_queue() during __init__. """ - state: ServerStateProxy + state: ServerState def _init_task_queue(self, config: TaskQueueConfig | None = None, deployment_name: str = '') -> None: """Initialise the task queue, rate limiter, and compute worker.""" diff --git a/src/twinkle/server/utils/task_queue/worker.py b/src/twinkle/server/utils/task_queue/worker.py index 81e8a82eb..a546172c1 100644 --- a/src/twinkle/server/utils/task_queue/worker.py +++ b/src/twinkle/server/utils/task_queue/worker.py @@ -21,7 +21,7 @@ from .types import QueuedTask, QueueState, TaskStatus if TYPE_CHECKING: - from twinkle.server.state import ServerStateProxy + from twinkle.server.state import ServerState from twinkle.server.utils.metrics import TaskMetrics logger = get_logger() @@ -41,7 +41,7 @@ class ComputeWorker: def __init__( self, - state: ServerStateProxy, + state: ServerState, config: TaskQueueConfig, task_metrics: TaskMetrics | None, deployment_name: str, From fad76f34836a0db64752fc3ef20e8b122aec7f7b Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 16:00:11 +0800 Subject: [PATCH 29/34] docs(observability): add load.py to populate every Grafana overview panel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mock-mode servers leave most dashboard panels at "No data" because: - `histogram_quantile(..., rate(_bucket[5m]))` returns NaN under zero recent traffic — sparse requests render counters but blank histograms; - `up_down_counter` gauges (`active_sessions`, `active_models`, `queue_depth`) emit on delta only and stay invisible until the underlying count moves; - mock backends execute in microseconds so P95 hugs the bottom bucket. `load.py` drives a running mock server with N concurrent users that each: POST /api/v1/twinkle/create_session -> active_sessions++ POST /api/v1/model/mock/twinkle/add_adapter_to_model -> active_models++ POST /api/v1/sampler/mock/twinkle/sample (loop, 80%) -> http rate + latency + queue_depth + task_execution + task_wait POST /api/v1/model/mock/twinkle/forward_only (~20%) -> sticky-LoRA path Sticky `X-Ray-Serve-Request-Id` is pinned per (user, adapter) so the `request_id + '-' + adapter_name` lookup in `assert_resource_exists` resolves on subsequent /forward_only calls. Tunable: `--concurrency`, `--duration`, `--interval`, `--max-tokens`. Bumping `--max-tokens` lifts mock execution time off the bottom histogram bucket so P95 panels show meaningful values. --- cookbook/observability/load.py | 258 +++++++++++++++++++++++++++++++++ 1 file changed, 258 insertions(+) create mode 100644 cookbook/observability/load.py diff --git a/cookbook/observability/load.py b/cookbook/observability/load.py new file mode 100644 index 000000000..1aa6d453f --- /dev/null +++ b/cookbook/observability/load.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python +"""Drive a running mock Twinkle server with enough traffic that every panel on +the ``Twinkle Server Overview`` Grafana dashboard shows data. + +Why this script exists +---------------------- +Most empty panels on the overview dashboard are NOT a wiring problem — they +are caused by: + +1. ``histogram_quantile(0.95, rate(..._bucket[5m]))`` returns NaN when the + 5-minute look-back has zero traffic, so latency / wait-time / execution + panels read "No data". +2. ``up_down_counter`` gauges emit on *delta* only. ``active_sessions`` / + ``active_models`` / ``queue_depth`` stay invisible until something + actually changes their underlying count at least once. +3. Mock backends execute in microseconds — even when traffic exists, + histogram P95 hugs the bottom bucket. Bump ``--max-tokens`` so the + sampler's per-request runtime lifts off the floor. + +Pre-reqs +-------- +1. Observability stack running:: + + docker compose -f cookbook/observability/docker-compose.yaml up -d + +2. Mock server running with telemetry **enabled** (the shipped + ``cookbook/client/server/mock/server_config.yaml`` has + ``telemetry.enabled: false`` — flip it to ``true`` or override via env + before launching):: + + TWINKLE_TELEMETRY_ENABLED=true \\ + python -m twinkle.server launch \\ + --config cookbook/client/server/mock/server_config.yaml + +Usage +----- +:: + + # Defaults: 4 concurrent users, 120s, ~2 req/s each + python cookbook/observability/load.py + + # Heavier: 8 users, 5 minutes, longer sampler runtime (lifts P95) + python cookbook/observability/load.py \\ + --concurrency 8 --duration 300 --max-tokens 128 + +In Grafana set the time window to ``Last 15 minutes`` for the rate[5m] +queries to be meaningful. +""" +from __future__ import annotations + +import argparse +import asyncio +import json +import random +import time +import uuid + +import httpx + +# Routes from cookbook/client/server/mock/server_config.yaml — keep in sync. +GATEWAY_ROUTE = '/api/v1' +MODEL_ROUTE = '/api/v1/model/mock' +SAMPLER_ROUTE = '/api/v1/sampler/mock' + +# Any non-empty token is accepted (``is_token_valid`` is permissive by default). +TOKEN = 'load-test-token' + + +def _headers(session_id: str, *, request_id: str) -> dict[str, str]: + """Build the per-request header set the server middleware expects. + + The server's ``verify_request_token`` middleware requires: + - ``Twinkle-Authorization: Bearer `` + - ``X-Ray-Serve-Request-Id`` for sticky routing (any unique string ok) + - ``X-Twinkle-Session-Id`` for session correlation (optional) + + Pass the SAME ``request_id`` for every call against the same registered + adapter so the sticky-LoRA key (``request_id + '-' + adapter_name``) + resolves to the registered resource on subsequent ``/forward_only`` calls. + """ + return { + 'Twinkle-Authorization': f'Bearer {TOKEN}', + 'X-Ray-Serve-Request-Id': request_id, + 'X-Twinkle-Session-Id': session_id, + 'Content-Type': 'application/json', + } + + +def _lora_config_payload(rank: int = 8) -> str: + """JSON payload the server's ``deserialize_object`` will rehydrate into a + ``peft.LoraConfig``. Matches ``twinkle_client.common.serialize.serialize_object``. + """ + return json.dumps({ + '_TWINKLE_TYPE_': 'LoraConfig', + 'r': rank, + 'lora_alpha': rank * 2, + 'lora_dropout': 0.0, + 'bias': 'none', + 'task_type': 'CAUSAL_LM', + 'target_modules': ['q_proj', 'v_proj'], + }) + + +async def create_session(client: httpx.AsyncClient, session_id: str) -> bool: + """POST /api/v1/twinkle/create_session — moves ``active_sessions`` gauge.""" + r = await client.post( + f'{GATEWAY_ROUTE}/twinkle/create_session', + headers=_headers(session_id, request_id=uuid.uuid4().hex), + json={'metadata': {'source': 'load.py'}}, + timeout=10.0, + ) + if r.status_code != 200: + print(f' create_session -> {r.status_code} {r.text[:160]}') + return False + return True + + +async def add_adapter(client: httpx.AsyncClient, adapter_name: str, session_id: str, request_id: str) -> bool: + """POST /api/v1/model/mock/twinkle/add_adapter_to_model — moves + ``active_models`` gauge and goes through the task queue (queue_depth + + task_execution histograms). + """ + body = {'adapter_name': adapter_name, 'config': _lora_config_payload()} + r = await client.post( + f'{MODEL_ROUTE}/twinkle/add_adapter_to_model', + headers=_headers(session_id, request_id=request_id), + json=body, + timeout=30.0, + ) + if r.status_code != 200: + print(f' add_adapter {adapter_name} -> {r.status_code} {r.text[:200]}') + return False + return True + + +async def sample(client: httpx.AsyncClient, session_id: str, *, max_tokens: int) -> int: + """POST /api/v1/sampler/mock/twinkle/sample — primary load. No adapter + registration needed (``adapter_name=''`` skips the resource check).""" + body = { + 'inputs': [{'input_ids': [random.randint(0, 100) for _ in range(8)]}], + 'sampling_params': {'max_tokens': max_tokens}, + 'adapter_name': '', + } + r = await client.post( + f'{SAMPLER_ROUTE}/twinkle/sample', + headers=_headers(session_id, request_id=uuid.uuid4().hex), + json=body, + timeout=60.0, + ) + return r.status_code + + +async def forward_only(client: httpx.AsyncClient, adapter_name: str, session_id: str, request_id: str) -> int: + """POST /api/v1/model/mock/twinkle/forward_only against a registered adapter. + + ``request_id`` MUST be the same one used by the original ``add_adapter`` + call — the server prefixes ``request_id`` onto the adapter key for + sticky-LoRA routing, so reusing it lets ``assert_resource_exists`` find + the adapter we registered. + """ + body = { + 'inputs': [{'input_ids': [random.randint(0, 100) for _ in range(16)]}], + 'adapter_name': adapter_name, + } + r = await client.post( + f'{MODEL_ROUTE}/twinkle/forward_only', + headers=_headers(session_id, request_id=request_id), + json=body, + timeout=30.0, + ) + return r.status_code + + +async def user_loop( + user_id: int, + base_url: str, + deadline: float, + interval: float, + max_tokens: int, +) -> None: + """Per-user driver: create_session + add_adapter once, then loop sample + (and occasional forward_only) until the deadline.""" + session_id = f'load-user-{user_id}-{uuid.uuid4().hex[:6]}' + adapter_name = f'adapter-u{user_id}-{uuid.uuid4().hex[:6]}' + # Pinned per-user request id for the sticky-LoRA path (forward_only against + # the adapter we register in this loop iteration). + sticky_request_id = uuid.uuid4().hex + + async with httpx.AsyncClient(base_url=base_url) as client: + sess_ok = await create_session(client, session_id) + adapter_ok = False + if sess_ok: + adapter_ok = await add_adapter(client, adapter_name, session_id, sticky_request_id) + + ok_n = err_n = 0 + while time.monotonic() < deadline: + # 80% sample (no adapter), 20% forward_only (uses registered adapter). + use_forward = adapter_ok and random.random() < 0.2 + try: + if use_forward: + status = await forward_only(client, adapter_name, session_id, sticky_request_id) + else: + status = await sample(client, session_id, max_tokens=max_tokens) + except Exception as exc: + err_n += 1 + print(f' user {user_id} request error: {exc!r}') + await asyncio.sleep(1.0) + continue + if 200 <= status < 300: + ok_n += 1 + else: + err_n += 1 + # Jittered interval keeps requests from clumping into the same scrape window. + await asyncio.sleep(max(0.01, interval + random.uniform(-interval / 4, interval / 4))) + + print(f' user {user_id:>2} ok={ok_n:>4} err={err_n} ' + f'session={session_id} adapter={adapter_name if adapter_ok else ""}') + + +async def main_async(args: argparse.Namespace) -> None: + deadline = time.monotonic() + args.duration + print(f'Load: base={args.base_url} concurrency={args.concurrency} ' + f'duration={args.duration}s interval={args.interval}s max_tokens={args.max_tokens}') + print(f'Hits: POST {GATEWAY_ROUTE}/twinkle/create_session') + print(f' POST {MODEL_ROUTE}/twinkle/add_adapter_to_model') + print(f' POST {MODEL_ROUTE}/twinkle/forward_only (~20%)') + print(f' POST {SAMPLER_ROUTE}/twinkle/sample (~80%)') + await asyncio.gather(*[ + user_loop(i, args.base_url, deadline, args.interval, args.max_tokens) for i in range(args.concurrency) + ]) + print('Done. Allow ~30s for the next OTLP export tick, then refresh Grafana ') + print('with the time window set to "Last 15 minutes".') + + +def main() -> None: + p = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + p.add_argument('--base-url', default='http://localhost:8000', help='Server base URL (default: %(default)s)') + p.add_argument('--concurrency', type=int, default=4, help='Parallel users (default: %(default)s)') + p.add_argument('--duration', type=int, default=120, help='Total seconds to run (default: %(default)s)') + p.add_argument( + '--interval', + type=float, + default=0.5, + help='Mean seconds between requests per worker; ±25%% jitter applied. ' + 'Lower → higher RPS. Default: %(default)s') + p.add_argument( + '--max-tokens', + type=int, + default=64, + help='Mock sampler runtime scales with max_tokens. Bump to >= 64 so ' + 'task_execution P95 lifts off the bottom histogram bucket. ' + 'Default: %(default)s') + args = p.parse_args() + asyncio.run(main_async(args)) + + +if __name__ == '__main__': + main() From 0142f8e8920382997505d91e41a715f583da1616 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 16:36:01 +0800 Subject: [PATCH 30/34] fix(server): MockSampler accepts handler kwargs; load.py uses TELEMETRY_ENABLED=1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two issues exposed when actually running cookbook/observability/load.py against a live mock server: 1. ``MockSampler.sample(...)`` raised ``TypeError`` because the Tinker / Twinkle handlers forward ``adapter_path`` (matching the ``vLLMSampler`` signature) but the mock didn't accept extra kwargs. Added ``**kwargs`` so the mock stays callable through the same handler call sites — 100% of ``/sample`` requests now succeed. 2. load.py docstring told users ``TWINKLE_TELEMETRY_ENABLED=true`` but ``worker_init.ensure_telemetry_initialized`` reads the literal ``"1"``, so telemetry was never actually initialised — every panel showed "No data". Corrected to ``=1`` and added the ``ray start --head`` prereq (the launcher does ``ray.init(address='auto')`` and won't bootstrap one). --- cookbook/observability/load.py | 36 ++++++++++++++----- .../server/sampler/backends/mock_sampler.py | 6 ++++ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/cookbook/observability/load.py b/cookbook/observability/load.py index 1aa6d453f..99b11d948 100644 --- a/cookbook/observability/load.py +++ b/cookbook/observability/load.py @@ -23,15 +23,23 @@ docker compose -f cookbook/observability/docker-compose.yaml up -d -2. Mock server running with telemetry **enabled** (the shipped +2. Mock server running with telemetry **enabled**. The shipped ``cookbook/client/server/mock/server_config.yaml`` has - ``telemetry.enabled: false`` — flip it to ``true`` or override via env - before launching):: + ``telemetry.enabled: false``; override via env vars before launching. + NB: the worker-side flag is read as the literal string ``"1"`` (see + ``twinkle.server.telemetry.worker_init.ensure_telemetry_initialized``), + NOT ``"true"`` / ``"yes"``:: - TWINKLE_TELEMETRY_ENABLED=true \\ + TWINKLE_TELEMETRY_ENABLED=1 \\ + TWINKLE_TELEMETRY_ENDPOINT=http://localhost:4317 \\ python -m twinkle.server launch \\ --config cookbook/client/server/mock/server_config.yaml + Also start Ray first (the launcher does ``ray.init(address='auto')`` and + will refuse to spin one up locally):: + + ray start --head --num-cpus=4 --disable-usage-stats + Usage ----- :: @@ -66,7 +74,7 @@ TOKEN = 'load-test-token' -def _headers(session_id: str, *, request_id: str) -> dict[str, str]: +def _headers(session_id: str, *, request_id: str, multiplex_key: str | None = None) -> dict[str, str]: """Build the per-request header set the server middleware expects. The server's ``verify_request_token`` middleware requires: @@ -74,16 +82,26 @@ def _headers(session_id: str, *, request_id: str) -> dict[str, str]: - ``X-Ray-Serve-Request-Id`` for sticky routing (any unique string ok) - ``X-Twinkle-Session-Id`` for session correlation (optional) + Model + sampler deployments additionally call + ``serve.get_multiplexed_model_id()`` for sticky-LoRA replica routing — + Ray Serve raises ``ValueError("The model ID cannot be empty.")`` if the + ``serve_multiplexed_model_id`` header is absent. Always set + ``multiplex_key`` for model / sampler calls; the Gateway endpoint + (``/api/v1/twinkle/create_session``) does not need it. + Pass the SAME ``request_id`` for every call against the same registered adapter so the sticky-LoRA key (``request_id + '-' + adapter_name``) resolves to the registered resource on subsequent ``/forward_only`` calls. """ - return { + headers = { 'Twinkle-Authorization': f'Bearer {TOKEN}', 'X-Ray-Serve-Request-Id': request_id, 'X-Twinkle-Session-Id': session_id, 'Content-Type': 'application/json', } + if multiplex_key is not None: + headers['serve_multiplexed_model_id'] = multiplex_key + return headers def _lora_config_payload(rank: int = 8) -> str: @@ -123,7 +141,7 @@ async def add_adapter(client: httpx.AsyncClient, adapter_name: str, session_id: body = {'adapter_name': adapter_name, 'config': _lora_config_payload()} r = await client.post( f'{MODEL_ROUTE}/twinkle/add_adapter_to_model', - headers=_headers(session_id, request_id=request_id), + headers=_headers(session_id, request_id=request_id, multiplex_key=adapter_name), json=body, timeout=30.0, ) @@ -143,7 +161,7 @@ async def sample(client: httpx.AsyncClient, session_id: str, *, max_tokens: int) } r = await client.post( f'{SAMPLER_ROUTE}/twinkle/sample', - headers=_headers(session_id, request_id=uuid.uuid4().hex), + headers=_headers(session_id, request_id=uuid.uuid4().hex, multiplex_key=session_id), json=body, timeout=60.0, ) @@ -164,7 +182,7 @@ async def forward_only(client: httpx.AsyncClient, adapter_name: str, session_id: } r = await client.post( f'{MODEL_ROUTE}/twinkle/forward_only', - headers=_headers(session_id, request_id=request_id), + headers=_headers(session_id, request_id=request_id, multiplex_key=adapter_name), json=body, timeout=30.0, ) diff --git a/src/twinkle/server/sampler/backends/mock_sampler.py b/src/twinkle/server/sampler/backends/mock_sampler.py index 330c677b9..786591041 100644 --- a/src/twinkle/server/sampler/backends/mock_sampler.py +++ b/src/twinkle/server/sampler/backends/mock_sampler.py @@ -59,7 +59,13 @@ def sample( adapter_name: str = '', *, num_samples: int = 1, + **kwargs: Any, ) -> list[SampleResponse]: + # The real ``vLLMSampler.sample`` accepts extra keyword arguments + # (``adapter_path``, ``adapter_uri``, etc.) that the Tinker / Twinkle + # handlers always forward. Swallow them here so the mock backend + # stays callable through the same handler call sites without a + # TypeError. max_tokens = self._resolve_max_tokens(sampling_params) if max_tokens is None or max_tokens < 1: raise ValueError(f'max_tokens must be >= 1, got {max_tokens!r} ' From 90fb3e1d0427ddc38dab402461e4ffe0875f7302 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 16:36:24 +0800 Subject: [PATCH 31/34] fix(server): lazy-start ServerState cleanup loop on first request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ray Serve binds ``serve.get_replica_context().servable_object`` AFTER FastAPI ``lifespan`` startup completes, so the existing lifespan call to ``get_self().state.start_cleanup_task()`` crashed with ``'NoneType' object has no attribute 'state'`` in every worker and was silently swallowed. The cleanup loop drives ``_metrics_loop``, which emits the four ``twinkle_*_active`` resource gauges — so those gauges never produced a single sample and the "Active resources" Grafana panels always read "No data". Move the call from lifespan to first-request lazy-init: - Model / Sampler: ``_on_request_start`` -> ``_ensure_state_cleanup_started`` - Processor: ``_ensure_sticky`` (which every routed call goes through) - Gateway: a tiny ``ensure_state_cleanup_started`` HTTP middleware (no per-handler hook exists) ``state.start_cleanup_task`` is itself idempotent via ``_cleanup_running``; the per-instance flag avoids the await call on every subsequent request. Verified end-to-end against a live mock server with the LGTM stack: ``twinkle_sessions_active=4`` and ``twinkle_futures_active`` now emit correctly. (``twinkle_models_active`` still empty under load — separate diagnostic for a follow-up; ``models.add`` reaches the backend and ``futures.add`` works through the same metrics loop, so likely an instrument-binding issue at lazy-init time worth its own investigation.) --- src/twinkle/server/gateway/server.py | 37 ++++++++++++++++++++++++---- src/twinkle/server/model/app.py | 35 ++++++++++++++++++-------- src/twinkle/server/processor/app.py | 24 ++++++++++++++---- src/twinkle/server/sampler/app.py | 24 ++++++++++++++---- 4 files changed, 95 insertions(+), 25 deletions(-) diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index f3f26ffba..b0644f3c7 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -43,6 +43,24 @@ def __init__(self, types.SupportedModel(model_name='Qwen/Qwen3.6-27B'), ] self._modelscope_config_lock = asyncio.Lock() + self._state_cleanup_started = False + + async def _ensure_state_cleanup_started(self) -> None: + """Start ServerState cleanup + metrics loops on the first request. + + Ray Serve binds ``serve.get_replica_context().servable_object`` AFTER + FastAPI ``lifespan`` startup, so the cleanup task cannot run there + (``get_self()`` returns ``None`` during lifespan). Lazy-init here on + the first request instead. ``start_cleanup_task`` is idempotent via + its internal ``_cleanup_running`` guard. + """ + if self._state_cleanup_started: + return + try: + await self.state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') + self._state_cleanup_started = True def _normalize_models(self, supported_models): if not supported_models: @@ -100,11 +118,9 @@ async def lifespan(app: FastAPI): # Initialize telemetry in worker process (after deserialization) from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized ensure_telemetry_initialized() - # Start the ServerState cleanup loop now that we have a running loop. - try: - await get_self().state.start_cleanup_task() - except Exception as e: - logger.warning(f'Failed to start ServerState cleanup task: {e}') + # NOTE: ``state.start_cleanup_task()`` cannot run here — Ray Serve binds + # ``servable_object`` AFTER lifespan startup. Lazy-started from the + # first request via the ``ensure_state_cleanup_started`` middleware. yield try: await get_self().proxy.close() @@ -113,6 +129,17 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) + @app.middleware('http') + async def ensure_state_cleanup_started(request: Request, call_next): + # Lazy-init the state cleanup + metrics loops on first request — see + # GatewayServer._ensure_state_cleanup_started. Gateway has no per- + # handler hook, so a tiny middleware covers every route. + try: + await get_self()._ensure_state_cleanup_started() + except Exception as e: + logger.debug(f'state cleanup lazy-init skipped: {e}') + return await call_next(request) + @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 328db0794..fca378c0e 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -129,6 +129,24 @@ async def _ensure_replica_registered(self): await self.state.register_replica(self.replica_id, self.max_loras) self._replica_registered = True + async def _ensure_state_cleanup_started(self) -> None: + """Start ServerState cleanup + metrics loops on the first request. + + Cannot run in FastAPI ``lifespan``: Ray Serve binds + ``serve.get_replica_context().servable_object`` AFTER the lifespan + startup phase, so a lifespan call has no ``self`` to reach. By the + time a request arrives, the binding exists. ``state.start_cleanup_task`` + is itself idempotent via ``_cleanup_running``, but the per-instance + flag avoids the await on every subsequent request. + """ + if getattr(self, '_state_cleanup_started', False): + return + try: + await self.state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') + self._state_cleanup_started = True + @serve.multiplexed(max_num_models_per_replica=5) async def _sticky_entry(self, sticky_key: str): return sticky_key @@ -142,6 +160,7 @@ async def _ensure_sticky(self): async def _on_request_start(self, request: Request) -> str: await self._ensure_sticky() await self._ensure_replica_registered() + await self._ensure_state_cleanup_started() token = get_token_from_request(request) return token @@ -208,16 +227,12 @@ async def lifespan(app: FastAPI): # Initialize telemetry in worker process (after deserialization) from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized ensure_telemetry_initialized() - # Start the ServerState cleanup loop now that we have a running loop; - # idempotent across replicas in the same process. - try: - await get_self().state.start_cleanup_task() - except Exception as e: - logger.warning(f'Failed to start ServerState cleanup task: {e}') - try: - await get_self()._ensure_replica_registered() - except Exception as e: - logger.warning(f'Failed to register replica at startup: {e}') + # NOTE: ``state.start_cleanup_task()`` and ``_ensure_replica_registered()`` + # cannot run here — Ray Serve binds ``servable_object`` AFTER lifespan + # startup, so ``get_self()`` returns ``None`` and the call would crash. + # They are lazy-started from the first request via + # ``_on_request_start`` → ``_ensure_state_cleanup_started`` and + # ``_ensure_replica_registered`` respectively. yield try: await get_self().shutdown() diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index a50cbb995..9b6e90b06 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -83,6 +83,22 @@ async def _ensure_sticky(self): await self._sticky_entry(sticky_key) # Lazy-start countdown task on first request (requires running event loop) self._ensure_countdown_started() + await self._ensure_state_cleanup_started() + + async def _ensure_state_cleanup_started(self) -> None: + """Start ServerState cleanup + metrics loops on the first request. + + See ``model/app.py``: ``servable_object`` is unbound during lifespan, + so we lazy-init here. Processor has no ``_on_request_start`` hook; + every routed request flows through ``_ensure_sticky`` first. + """ + if getattr(self, '_state_cleanup_started', False): + return + try: + await self.state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') + self._state_cleanup_started = True def _on_processor_expired(self, processor_id: str) -> None: """Called by the countdown thread when a processor's session expires.""" @@ -131,11 +147,9 @@ async def lifespan(app: FastAPI): # Initialize telemetry in worker process (after deserialization) from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized ensure_telemetry_initialized() - # Start the ServerState cleanup loop now that we have a running loop. - try: - await get_self().state.start_cleanup_task() - except Exception as e: - logger.warning(f'Failed to start ServerState cleanup task: {e}') + # NOTE: ``state.start_cleanup_task()`` cannot run here — Ray Serve binds + # ``servable_object`` AFTER lifespan startup. Lazy-started from the + # first request via ``_ensure_sticky`` → ``_ensure_state_cleanup_started``. yield app = FastAPI(lifespan=lifespan) diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index f9884c52b..2cebf7166 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -132,8 +132,24 @@ async def _ensure_sticky(self): sticky_key = serve.get_multiplexed_model_id() await self._sticky_entry(sticky_key) + async def _ensure_state_cleanup_started(self) -> None: + """Start ServerState cleanup + metrics loops on the first request. + + See the matching helper in ``model/app.py`` for the lifespan-timing + rationale: ``servable_object`` is unbound during lifespan, so we lazy- + init here. + """ + if getattr(self, '_state_cleanup_started', False): + return + try: + await self.state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') + self._state_cleanup_started = True + async def _on_request_start(self, request: Request) -> str: await self._ensure_sticky() + await self._ensure_state_cleanup_started() token = get_token_from_request(request) return token @@ -182,11 +198,9 @@ async def lifespan(app: FastAPI): # Initialize telemetry in worker process (after deserialization) from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized ensure_telemetry_initialized() - # Start the ServerState cleanup loop now that we have a running loop. - try: - await get_self().state.start_cleanup_task() - except Exception as e: - logger.warning(f'Failed to start ServerState cleanup task: {e}') + # NOTE: ``state.start_cleanup_task()`` cannot run here — Ray Serve binds + # ``servable_object`` AFTER lifespan startup. Lazy-started from the + # first request via ``_on_request_start`` → ``_ensure_state_cleanup_started``. yield app = FastAPI( From c9b77c1741e136befcbb3f958f19e08e260debb6 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 17:03:24 +0800 Subject: [PATCH 32/34] docs(observability): load.py uses server-issued session_id, requires redis backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three bugs uncovered while making active-resource panels light up: 1. ``X-Twinkle-Session-Id`` used a client-side string that the server never persisted, so the adapter countdown loop in ``utils/lifecycle/base.py`` saw "session not found" and expired every registered adapter within ~10s. Now call ``/twinkle/create_session``, take the server-issued id from the response body, and pass that id to every subsequent header. Also heartbeat the session every 5s so it stays alive. 2. ``persistence: memory`` is per-process — Gateway-worker sessions are invisible to the Model worker. Even with the correct session_id the liveness check still fails because Model's MemoryBackend has zero session records. Docstring now states the script requires a shared backend (Redis) and explains the trap; reasonable, since R19.4 cross-worker visibility specifically requires shared persistence. 3. Sampling-session calls hit ``/api/v1/twinkle/create_sampling_session`` and 404'd because the route lives at the gateway root (it is a Tinker route — only ``create_session`` is mounted under ``/twinkle/``). Fixed to call ``/api/v1/create_sampling_session``. Result against a redis-backed mock server: 13/14 dashboard panels populate (``rate_limit_rejections`` stays empty under gentle load — by design; ``gpu_*`` stays empty on a CPU-only mock — by design). --- cookbook/observability/load.py | 101 +++++++++++++++++++++++++-------- 1 file changed, 78 insertions(+), 23 deletions(-) diff --git a/cookbook/observability/load.py b/cookbook/observability/load.py index 99b11d948..e41338024 100644 --- a/cookbook/observability/load.py +++ b/cookbook/observability/load.py @@ -19,23 +19,31 @@ Pre-reqs -------- -1. Observability stack running:: +1. Observability stack + Redis running:: docker compose -f cookbook/observability/docker-compose.yaml up -d + docker run -d --name twinkle-redis -p 6379:6379 redis:7 -2. Mock server running with telemetry **enabled**. The shipped - ``cookbook/client/server/mock/server_config.yaml`` has - ``telemetry.enabled: false``; override via env vars before launching. - NB: the worker-side flag is read as the literal string ``"1"`` (see - ``twinkle.server.telemetry.worker_init.ensure_telemetry_initialized``), - NOT ``"true"`` / ``"yes"``:: +2. Mock server running with telemetry **enabled** and a SHARED persistence + backend. The shipped ``cookbook/client/server/mock/server_config.yaml`` + ships ``mode: memory`` which is per-process — Gateway-worker sessions + are invisible to Model worker, the adapter countdown loop sees "session + not found" and expires registered adapters within ~10s (which empties + the ``twinkle_models_active`` gauge). For this load script to populate + every panel, switch persistence to redis:: + + persistence: { mode: redis, redis_url: redis://localhost:6379 } + + Telemetry: the worker reads the env flag as the literal string ``"1"`` + (see ``telemetry.worker_init.ensure_telemetry_initialized``), NOT + ``"true"``:: TWINKLE_TELEMETRY_ENABLED=1 \\ TWINKLE_TELEMETRY_ENDPOINT=http://localhost:4317 \\ python -m twinkle.server launch \\ --config cookbook/client/server/mock/server_config.yaml - Also start Ray first (the launcher does ``ray.init(address='auto')`` and + Start Ray first (the launcher does ``ray.init(address='auto')`` and will refuse to spin one up locally):: ray start --head --num-cpus=4 --disable-usage-stats @@ -119,18 +127,57 @@ def _lora_config_payload(rank: int = 8) -> str: }) -async def create_session(client: httpx.AsyncClient, session_id: str) -> bool: - """POST /api/v1/twinkle/create_session — moves ``active_sessions`` gauge.""" +async def create_session(client: httpx.AsyncClient) -> str | None: + """POST /api/v1/twinkle/create_session — returns the SERVER-issued + ``session_id`` so subsequent ``X-Twinkle-Session-Id`` headers reference + a session the server actually persisted. Using a client-side string + would silently fail liveness checks (the adapter countdown loop in + ``utils/lifecycle/base.py`` calls ``state.get_session_last_heartbeat`` + and expires adapters within ~10s when the ID isn't found).""" r = await client.post( f'{GATEWAY_ROUTE}/twinkle/create_session', - headers=_headers(session_id, request_id=uuid.uuid4().hex), + headers=_headers('', request_id=uuid.uuid4().hex), json={'metadata': {'source': 'load.py'}}, timeout=10.0, ) if r.status_code != 200: print(f' create_session -> {r.status_code} {r.text[:160]}') - return False - return True + return None + return r.json().get('session_id') + + +async def create_sampling_session(client: httpx.AsyncClient, session_id: str, model_path: str) -> None: + """POST /api/v1/create_sampling_session — bumps ``active_sampling_sessions``. + This is a Tinker route mounted at the gateway root (NOT under + ``/twinkle/``); the Twinkle gateway handlers only expose ``create_session``.""" + try: + await client.post( + f'{GATEWAY_ROUTE}/create_sampling_session', + headers=_headers(session_id, request_id=uuid.uuid4().hex), + json={ + 'session_id': session_id, + 'sampling_session_seq_id': 0, + 'model_path': model_path, + 'base_model': 'mock-model', + }, + timeout=10.0, + ) + except Exception: + pass + + +async def session_heartbeat(client: httpx.AsyncClient, session_id: str) -> None: + """POST /api/v1/twinkle/session_heartbeat — refreshes the session so + the adapter countdown loop doesn't expire registered adapters mid-load.""" + try: + await client.post( + f'{GATEWAY_ROUTE}/twinkle/session_heartbeat', + headers=_headers(session_id, request_id=uuid.uuid4().hex), + json={'session_id': session_id}, + timeout=5.0, + ) + except Exception: + pass async def add_adapter(client: httpx.AsyncClient, adapter_name: str, session_id: str, request_id: str) -> bool: @@ -197,28 +244,37 @@ async def user_loop( max_tokens: int, ) -> None: """Per-user driver: create_session + add_adapter once, then loop sample - (and occasional forward_only) until the deadline.""" - session_id = f'load-user-{user_id}-{uuid.uuid4().hex[:6]}' + (and occasional forward_only) until the deadline. Periodically heartbeats + the session so the adapter countdown loop doesn't expire registered + adapters mid-load (default adapter_timeout is 1800s, but a missing + heartbeat trips ``_is_session_alive`` long before that).""" adapter_name = f'adapter-u{user_id}-{uuid.uuid4().hex[:6]}' - # Pinned per-user request id for the sticky-LoRA path (forward_only against - # the adapter we register in this loop iteration). sticky_request_id = uuid.uuid4().hex + heartbeat_interval = 5.0 async with httpx.AsyncClient(base_url=base_url) as client: - sess_ok = await create_session(client, session_id) + # IMPORTANT: use the SERVER-issued session_id; sending our own client- + # side string would never match a stored session and registered + # adapters would expire within ~10s. + session_id = await create_session(client) adapter_ok = False - if sess_ok: + if session_id: adapter_ok = await add_adapter(client, adapter_name, session_id, sticky_request_id) + # Best-effort: bump the active_sampling_sessions gauge. + await create_sampling_session(client, session_id, model_path=f'mock://{adapter_name}') ok_n = err_n = 0 + last_hb = time.monotonic() while time.monotonic() < deadline: - # 80% sample (no adapter), 20% forward_only (uses registered adapter). + if session_id and time.monotonic() - last_hb >= heartbeat_interval: + await session_heartbeat(client, session_id) + last_hb = time.monotonic() use_forward = adapter_ok and random.random() < 0.2 try: if use_forward: status = await forward_only(client, adapter_name, session_id, sticky_request_id) else: - status = await sample(client, session_id, max_tokens=max_tokens) + status = await sample(client, session_id or '', max_tokens=max_tokens) except Exception as exc: err_n += 1 print(f' user {user_id} request error: {exc!r}') @@ -228,11 +284,10 @@ async def user_loop( ok_n += 1 else: err_n += 1 - # Jittered interval keeps requests from clumping into the same scrape window. await asyncio.sleep(max(0.01, interval + random.uniform(-interval / 4, interval / 4))) print(f' user {user_id:>2} ok={ok_n:>4} err={err_n} ' - f'session={session_id} adapter={adapter_name if adapter_ok else ""}') + f'session={session_id or ""} adapter={adapter_name if adapter_ok else ""}') async def main_async(args: argparse.Namespace) -> None: From 74b21f40a3717216da4f9fbb039a843663579672 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 17:25:49 +0800 Subject: [PATCH 33/34] chore: drop unfinalized mock/observability surface from docs + cookbook MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mock backends and OpenTelemetry pipeline still live in src/ but their public contract isn't settled. Pull them out of user-facing docs and cookbook examples so iteration doesn't churn published surface. - delete Observability.md (en/zh), strip telemetry rows + mock mentions from 服务配置.md, drop entries from both index.rst toctrees - delete cookbook/observability/, cookbook/client/server/mock/, and cookbook/client/server/server_config.example.yaml - strip telemetry: blocks from transformer/megatron server_config.yaml - migrate the mock CPU-only YAML the e2e test needed into tests/server/fixtures/server_config_mock.yaml; CLI + mock-mode-startup tests import the shared path constant - drop now-dead tests: tests/docs/test_docs_smoke.py (asserted removed files), tests/integration/test_lgtm_telemetry.py (gated on removed docker-compose; in-process equivalents already covered), grafana dashboard panel test, and the two mock cookbook README/config asserts in test_mock_sampler.py --- .../client/server/megatron/server_config.yaml | 13 - cookbook/client/server/mock/README.md | 43 --- .../client/server/server_config.example.yaml | 162 -------- .../server/transformer/server_config.yaml | 15 +- cookbook/observability/README.md | 102 ----- cookbook/observability/demo_sft_users.py | 227 ----------- cookbook/observability/docker-compose.yaml | 35 -- .../grafana/dashboards/twinkle-overview.json | 362 ------------------ cookbook/observability/load.py | 331 ---------------- docs/source_en/Usage Guide/Observability.md | 101 ----- docs/source_en/index.rst | 1 - docs/source_zh/index.rst | 1 - ...57\350\247\202\346\265\213\345\214\226.md" | 94 ----- ...15\345\212\241\351\205\215\347\275\256.md" | 30 +- tests/docs/test_docs_smoke.py | 105 ----- tests/integration/test_lgtm_telemetry.py | 274 ------------- tests/integration/test_mock_mode_startup.py | 3 +- tests/server/cli/test_cli.py | 6 +- tests/server/fixtures/__init__.py | 12 + .../server/fixtures/server_config_mock.yaml | 30 +- tests/server/sampler/test_mock_sampler.py | 23 -- .../telemetry/test_tracing_and_correlation.py | 24 -- 22 files changed, 37 insertions(+), 1957 deletions(-) delete mode 100644 cookbook/client/server/mock/README.md delete mode 100644 cookbook/client/server/server_config.example.yaml delete mode 100644 cookbook/observability/README.md delete mode 100644 cookbook/observability/demo_sft_users.py delete mode 100644 cookbook/observability/docker-compose.yaml delete mode 100644 cookbook/observability/grafana/dashboards/twinkle-overview.json delete mode 100644 cookbook/observability/load.py delete mode 100644 docs/source_en/Usage Guide/Observability.md delete mode 100644 "docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\217\257\350\247\202\346\265\213\345\214\226.md" delete mode 100644 tests/docs/test_docs_smoke.py delete mode 100644 tests/integration/test_lgtm_telemetry.py create mode 100644 tests/server/fixtures/__init__.py rename cookbook/client/server/mock/server_config.yaml => tests/server/fixtures/server_config_mock.yaml (59%) diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index ae6e9236c..7b18ed68c 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -9,19 +9,6 @@ http_options: host: 0.0.0.0 # Listen on all network interfaces port: 9000 # Port number for the server -# Telemetry configuration for observability (OpenTelemetry-based). -# Disabled by default — opentelemetry-* packages are optional dependencies. -# To enable: -# 1. pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp -# 2. set enabled: true (and debug: true to dump exporters to console for local dev, -# or leave debug: false and point otlp_endpoint at an OTLP collector — see -# cookbook/observability/ for a docker-compose example). -# telemetry: -# enabled: false -# debug: false -# service_name: twinkle-server -# otlp_endpoint: http://localhost:4317 - # Persistence configuration for ServerState (sessions, models, futures, ...). # Top-level placement makes the launcher propagate this to every Ray worker # via env vars, so the configured backend is used regardless of which diff --git a/cookbook/client/server/mock/README.md b/cookbook/client/server/mock/README.md deleted file mode 100644 index 9aefe586e..000000000 --- a/cookbook/client/server/mock/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# Mock backend — CPU-only quick start - -This directory ships an all-mock Twinkle Server configuration so you can -launch the HTTP surface in seconds on a CPU-only laptop, no GPU and no -torch/transformers/vllm/megatron required. Use it for local development, -CI smoke tests, and contract-level HTTP debugging. - -> **Not for production.** Mock backends return fixed numpy-derived results -> without performing real model computation or sampling. The training and -> sampling endpoints respond with deterministic synthetic outputs derived -> only from the request shape and a seed. - -## Launch - -```bash -python -m twinkle.server --config cookbook/client/server/mock/server_config.yaml -``` - -The launcher should reach the ready state within **30 seconds** on a CPU-only -host (R4.1) — `ModelManagement` and `SamplerManagement` skip the -`twinkle.initialize(mode='ray', ...)` step that the GPU backends would run -(R3.7, R3.8). - -## What the mock backends do - -- **Model (`backend: mock`)** — numpy-only. Forward / forward-only / - forward-backward calls return deterministic logprobs and elementwise - losses keyed by `(model_id, adapter_name, seed, input_shape)`. Step / - backward / optimizer-update calls are no-ops. Adapter add / remove / - has are tracked in an in-memory record. -- **Sampler (`sampler_type: mock`)** — numpy-only. `sample` returns one - `SampleResponse` per input prompt with `num_samples` sequences of length - `max_tokens`, exactly one logprob entry per emitted token. Repeated calls - with the same parameters return identical token sequences and logprobs. - `max_tokens < 1` raises a validation error. - -## Verifying determinism - -```bash -curl -s -X POST http://localhost:8000/api/v1/model/mock/twinkle/forward_only \ - -H 'Content-Type: application/json' -d @some_payload.json -# Repeat the same request — the response body is byte-for-byte identical. -``` diff --git a/cookbook/client/server/server_config.example.yaml b/cookbook/client/server/server_config.example.yaml deleted file mode 100644 index 468f217c6..000000000 --- a/cookbook/client/server/server_config.example.yaml +++ /dev/null @@ -1,162 +0,0 @@ -# ============================================================================= -# Twinkle Server — fully documented example configuration -# ============================================================================= -# -# This file is a reference: every field carries its type, default value, and -# the available options. Loadable as-is via: -# -# python -m twinkle.server check-config --config server_config.example.yaml -# python -m twinkle.server launch --config server_config.example.yaml -# -# Field naming after the Phase 0c refactor is strict: legacy aliases -# `telemetry_config` / `persistence_config` are no longer accepted (R8). Use -# `telemetry` / `persistence` instead. - -# Optional. Ray cluster namespace. -# Type: string | null. Default: null (resolves to "twinkle_cluster"). -# Env override: TWINKLE_RAY_NAMESPACE. -ray_namespace: twinkle_cluster - -# Optional. Ray Serve proxy placement. -# Type: string | null. Options: "EveryNode" (default for multi-node), "HeadOnly". -proxy_location: EveryNode - -# HTTP listener. -# host: str — bind address. Default "localhost". Use "0.0.0.0" to listen on all. -# port: int — TCP port. Default 8000. -http_options: - host: 0.0.0.0 - port: 8000 - -# Telemetry (OpenTelemetry pipeline). -# enabled: bool — when false, init/shutdown is a NoOp. Default false. -# debug: bool — true: console exporters; false: OTLP exporter. Default false. -# service_name: str — OTEL resource service.name. Default "twinkle-server". -# otlp_endpoint: str — gRPC OTLP endpoint. Default "http://localhost:4317". -# export_interval_ms: int — metric export interval in ms. Default 30000. -# resource_attributes: map[str, any] — extra OTEL Resource attributes. Default {}. -telemetry: - enabled: false - debug: false - service_name: twinkle-server - otlp_endpoint: http://localhost:4317 - export_interval_ms: 30000 - -# Persistence backend for ServerState (sessions, models, futures, ...). -# mode: str — "memory" | "file" | "redis". Default "memory". -# file_path: str — required when mode == "file". -# redis_url: str — required when mode == "redis", e.g. "redis://localhost:6379/0". -# key_prefix: str — optional global key prefix. Default "". -persistence: - mode: memory - # file_path: /tmp/twinkle_state.json - # redis_url: redis://localhost:6379/0 - # key_prefix: "" - -# Task queue / rate-limit defaults (overridable per application under args.queue_config). -# rps_limit: float >= 0 — requests/sec. 0 disables. Default 100.0. -# tps_limit: float >= 0 — input tokens/sec. 0 disables. Default 16000.0. -# window_seconds: float > 0 — rate-limit sliding window. Default 1.0. -# queue_timeout: float >= 0 — max queue wait (s). Default 300.0. -# execution_timeout: float >= 0 — task execution timeout (s). 0 disables. Default 120.0. -# enabled: bool — rate limiting on/off. Default true. -# token_cleanup_multiplier:float >= 0 — token retention multiplier. Default 10.0. -# token_cleanup_interval: float >= 0 — cleanup task interval (s). Default 60.0. -# max_input_tokens: int >= 1 — per-request input token cap. Default 16000. -task_queue: - rps_limit: 100.0 - tps_limit: 16000.0 - window_seconds: 1.0 - queue_timeout: 300.0 - execution_timeout: 120.0 - enabled: true - max_input_tokens: 16000 - -# Applications: each entry deploys one component (server | model | sampler | processor). -# Required fields per entry: -# name: str — Ray Serve app name. -# route_prefix:str — HTTP route prefix. Default "/". -# import_path: str — one of {server, model, sampler, processor}. -# args: map — typed args, schema selected by import_path. -# Optional: -# deployments: list — Ray Serve deployment options (only the first is used). -applications: - - # 1. Tinker-compatible gateway (server) - - name: server - route_prefix: /api/v1 - import_path: server - args: - # ServerArgs schema — fields are optional unless noted. - server_config: - per_token_model_limit: 3 - supported_models: - - mock-model - deployments: - - name: GatewayServer - max_ongoing_requests: 50 - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 128 - ray_actor_options: - num_cpus: 0.1 - - # 2. Model deployment. - # backend: str — required. Options: "mock" | "transformers" | "megatron". - # model_id: str — required. Model identifier (e.g. "Qwen/Qwen3.5-4B"). - # nproc_per_node: int — distributed processes per node. Default 1. - # device_group / device_mesh: dict — required parallelism config. - # max_loras: int — per-replica LoRA capacity. Default 5. - # queue_config: map — overrides task_queue defaults for this app. - - name: models - route_prefix: /api/v1/model - import_path: model - args: - backend: mock - model_id: mock-model - nproc_per_node: 1 - device_group: - name: model - ranks: 1 - device_type: cpu - device_mesh: - device_type: cpu - dp_size: 1 - max_loras: 5 - queue_config: - rps_limit: 100 - tps_limit: 100000 - deployments: - - name: ModelManagement - autoscaling_config: { min_replicas: 1, max_replicas: 1, target_ongoing_requests: 16 } - ray_actor_options: { num_cpus: 0.1 } - - # 3. Sampler deployment. - # sampler_type: str — required. Options: "mock" | "vllm" | "torch". - - name: sampler - route_prefix: /api/v1/sampler - import_path: sampler - args: - sampler_type: mock - model_id: mock-model - nproc_per_node: 1 - device_group: { name: sampler, ranks: 1, device_type: cpu } - device_mesh: { device_type: cpu, dp_size: 1 } - deployments: - - name: SamplerManagement - autoscaling_config: { min_replicas: 1, max_replicas: 1, target_ongoing_requests: 16 } - ray_actor_options: { num_cpus: 0.1 } - - # 4. Processor deployment (CPU-only feature engineering). - - name: processor - route_prefix: /api/v1/processor - import_path: processor - args: - ncpu_proc_per_node: 2 - device_group: { name: processor, ranks: 2, device_type: cpu } - device_mesh: { device_type: cpu, dp_size: 2 } - deployments: - - name: ProcessorManagement - autoscaling_config: { min_replicas: 1, max_replicas: 1, target_ongoing_requests: 128 } - ray_actor_options: { num_cpus: 0.1 } diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 390871165..ee23cc33a 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -9,19 +9,6 @@ http_options: host: 0.0.0.0 # Listen on all network interfaces port: 8000 # Port number for the server -# Telemetry configuration for observability (OpenTelemetry-based). -# Disabled by default — opentelemetry-* packages are optional dependencies. -# To enable: -# 1. pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp -# 2. set enabled: true (and debug: true to dump exporters to console for local dev, -# or leave debug: false and point otlp_endpoint at an OTLP collector — see -# cookbook/observability/ for a docker-compose example). -telemetry: - enabled: false - debug: false - service_name: twinkle-server - otlp_endpoint: http://localhost:4317 - # Persistence configuration for ServerState (sessions, models, futures, ...). # Top-level placement makes the launcher propagate this to every Ray worker # via env vars, so the configured backend is used regardless of which @@ -62,7 +49,7 @@ applications: route_prefix: /api/v1/model/Qwen/Qwen3.5-4B import_path: model args: - backend: transformers # Model backend: mock | transformers | megatron + backend: transformers # Model backend: transformers | megatron model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier max_length: 10240 nproc_per_node: 1 # Number of GPU processes per node diff --git a/cookbook/observability/README.md b/cookbook/observability/README.md deleted file mode 100644 index a17fd009c..000000000 --- a/cookbook/observability/README.md +++ /dev/null @@ -1,102 +0,0 @@ -# Twinkle Observability Stack - -A one-container OTLP receiver + dashboard for Twinkle, built on the -[`grafana/otel-lgtm`](https://github.com/grafana/docker-otel-lgtm) image. -That image bundles OTel Collector, Mimir (Prometheus-compatible), Tempo, -Loki, and Grafana with everything pre-wired — no extra config files needed. - -## What you get - -| Surface | URL | Purpose | -|---|---|---| -| Grafana | `http://localhost:3000` | Dashboards + Explore (metrics / traces / logs) | -| OTLP gRPC | `localhost:4317` | Point Twinkle's `otlp_endpoint` here | -| OTLP HTTP | `localhost:4318` | Same, HTTP alternative | - -## Quick start - -```bash -# 1. Start the stack -cd cookbook/observability -docker compose up -d - -# 2. Make sure Twinkle has the OTLP exporter -pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp - -# 3. In your server_config.yaml: -# -# telemetry_config: -# enabled: true -# debug: false # debug=true dumps to console instead of OTLP -# service_name: twinkle-server -# otlp_endpoint: http://localhost:4317 - -# 4. Launch Twinkle as usual -python -m twinkle.server --config server_config.yaml - -# 5. Open Grafana -open http://localhost:3000 -``` - -Anonymous viewer access is on by default; full access is `admin` / `admin`. - -The provisioned **Twinkle / Twinkle Server Overview** dashboard shows: - -- HTTP request rate and P95 latency per deployment (Gateway / Model / Sampler / Processor) -- Active resources (sessions, models, sampling sessions, futures) -- Task queue depth, execution P95, wait-time P95 -- Rate-limit rejections and task completions by status - -For traces, switch the datasource picker in **Explore** to Tempo and search by -service or span name. Twinkle spans are namespaced under -`twinkle.server.` (Gateway / Model / Sampler / Processor). - -## Metric naming reference - -Twinkle emits OpenTelemetry metric names with dot notation. Prometheus's OTLP -ingestion converts dots to underscores and appends `_total` to monotonic -counters where missing: - -| OpenTelemetry name | Prometheus name | -|---|---| -| `twinkle.http.requests.total` | `twinkle_http_requests_total` | -| `twinkle.http.request.duration_seconds` | `twinkle_http_request_duration_seconds_bucket` (and `_sum`, `_count`) | -| `twinkle.queue.depth` | `twinkle_queue_depth` | -| `twinkle.task.execution_seconds` | `twinkle_task_execution_seconds_bucket` | -| `twinkle.task.wait_seconds` | `twinkle_task_wait_seconds_bucket` | -| `twinkle.rate_limit.rejections.total` | `twinkle_rate_limit_rejections_total` | -| `twinkle.tasks.total` | `twinkle_tasks_total` | -| `twinkle.rate_limiter.active_tokens` | `twinkle_rate_limiter_active_tokens` | -| `twinkle.sessions.active` | `twinkle_sessions_active` | -| `twinkle.models.active` | `twinkle_models_active` | -| `twinkle.sampling_sessions.active` | `twinkle_sampling_sessions_active` | -| `twinkle.futures.active` | `twinkle_futures_active` | - -## Tear down - -```bash -docker compose down -v # -v also removes the named volume -``` - -## Production note - -The LGTM all-in-one image is **for local development and demos**. Each backend -runs single-instance and shares one volume. For production, deploy each -component (Mimir / Tempo / Loki / Grafana) separately with proper persistent -storage, replicas, and an OTel Collector tier in front. The OTLP endpoint and -metric names stay the same, so your `server_config.yaml` and dashboards -transfer without changes. - -## Troubleshooting - -- **Grafana shows "No data"** — confirm `telemetry_config.enabled: true` in - your server config and that Twinkle's worker logs show - `Worker telemetry initialized`. With `debug: true` Twinkle dumps spans / - metrics to logs instead of OTLP, so set `debug: false` once verified. -- **Twinkle can't reach the collector** — `otlp_endpoint` must be reachable - from the Twinkle process. If Twinkle runs in another container on the same - Docker network, use `http://twinkle-lgtm:4317` instead of `localhost`. -- **Dashboard panel shows "Datasource not found"** — open the panel, switch - the datasource dropdown to the LGTM-provisioned Prometheus / Tempo and save. - This happens when LGTM versions change the default datasource UID; the - dashboard JSON pins `uid: prometheus`. diff --git a/cookbook/observability/demo_sft_users.py b/cookbook/observability/demo_sft_users.py deleted file mode 100644 index 41d916173..000000000 --- a/cookbook/observability/demo_sft_users.py +++ /dev/null @@ -1,227 +0,0 @@ -#!/usr/bin/env python -"""End-to-end demo: 5 users running parallel SFT, full trace + log + metric. - -Generates traffic that exercises every layer the spec instruments: -- Gateway / Model spans (HTTP edge) -- ServerState business spans (create_session, register_model, register_replica, - store_future_status, unload_model) -- Task-queue execution spans -- Per-user logs at INFO/WARN/ERROR with trace_id auto-attached -- HTTP request counters + per-deployment task duration histograms -- Resource gauges (CPU / memory / process RSS) - -Run: - PYTHONPATH=src python cookbook/observability/demo_sft_users.py - -Then in Grafana (http://localhost:3000): -- Tempo Search → Service=twinkle-server, Tags: twinkle.session_id= - → all spans for that user's whole session -- Loki Explore → {service_name="twinkle-server"} | trace_id = `` - → every log for that trace -- Prometheus Explore → twinkle_http_requests_total / twinkle_task_execution_seconds - → request rate + task latencies -""" -from __future__ import annotations - -import logging -import random -import time -import uuid -from concurrent.futures import ThreadPoolExecutor -from contextlib import contextmanager - -from opentelemetry import _logs as _otel_logs, metrics, trace -from opentelemetry.util._once import Once - - -def _reset_otel_globals() -> None: - """Clear OTel one-shot guards so init_telemetry runs from a clean slate.""" - trace._TRACER_PROVIDER_SET_ONCE = Once() - trace._TRACER_PROVIDER = None - metrics._METER_PROVIDER_SET_ONCE = Once() - metrics._METER_PROVIDER = None - if hasattr(_otel_logs, '_LOGGER_PROVIDER_SET_ONCE'): - _otel_logs._LOGGER_PROVIDER_SET_ONCE = Once() - _otel_logs._LOGGER_PROVIDER = None - - -def setup_telemetry(otlp_endpoint: str = 'http://localhost:4317') -> None: - """Initialize the real production telemetry pipeline.""" - _reset_otel_globals() - - from twinkle.server.telemetry import provider - from twinkle.server.telemetry.provider import TelemetryConfig, init_telemetry - - provider._initialized = False - init_telemetry(TelemetryConfig( - enabled=True, - debug=False, - service_name='twinkle-server', - otlp_endpoint=otlp_endpoint, - export_interval_ms=1000, - )) - - # Resource collector (CPU / Mem / GPU) - from twinkle.server.telemetry.resource_metrics import ( - get_collector, reset_collector_for_tests, - ) - reset_collector_for_tests() - get_collector().maybe_start() - - # Trigger MetricsRegistry once so its observable counters / histograms - # land in the meter provider before the workers start emitting. - from twinkle.server.telemetry.metrics import MetricsRegistry - MetricsRegistry.get() - - -@contextmanager -def _gateway_span(tracer, name: str, attrs: dict): - """Emulate the Gateway's HTTP-edge span: kind=server, route attrs.""" - with tracer.start_as_current_span( - name, attributes={'http.method': 'POST', 'http.route': name, **attrs} - ) as span: - yield span - - -def run_sft_for_user(user_idx: int, num_steps: int = 8) -> dict: - """Run one full SFT session for a single user — exercises every layer.""" - from twinkle.server.telemetry.correlation import ( - BASE_MODEL, MODEL_ID, REPLICA_ID, SESSION_ID, TOKEN_ID, - ) - from twinkle.server.telemetry.metrics import MetricsRegistry - from twinkle.server.telemetry.tracing import traced_operation - - log = logging.getLogger(f'twinkle.demo.user{user_idx}') - log.setLevel(logging.INFO) - metrics_reg = MetricsRegistry.get() - tracer = trace.get_tracer('twinkle.gateway') - - sid = f'session_{uuid.uuid4().hex[:8]}' - token = f'tok_user_{user_idx}' - base_model = ['Qwen/Qwen3.5-4B', 'Qwen/Qwen3.5-7B', 'Qwen/Qwen3.5-1.8B'][user_idx % 3] - replica_id = f'replica_{user_idx % 3}' - - # ---- 1. /create_session --------------------------------------------- - with _gateway_span(tracer, 'POST /tinker/create_session', - {SESSION_ID: sid, TOKEN_ID: token}): - log.info(f'creating session for user{user_idx}', - extra={'twinkle.session_id': sid, 'twinkle.token_id': token}) - with traced_operation('server_state.create_session', attrs={SESSION_ID: sid}): - time.sleep(random.uniform(0.005, 0.02)) - metrics_reg.requests_total.add(1, {'route': '/tinker/create_session', 'status': '200'}) - - # ---- 2. /create_model (registers a base + LoRA, picks a replica) ----- - mid = f'mid_{uuid.uuid4().hex[:8]}' - with _gateway_span(tracer, 'POST /tinker/create_model', - {SESSION_ID: sid, MODEL_ID: mid, TOKEN_ID: token, BASE_MODEL: base_model}): - log.info(f'register_model base={base_model} replica={replica_id}', - extra={'twinkle.session_id': sid, 'twinkle.model_id': mid, - 'twinkle.token_id': token, 'twinkle.base_model': base_model}) - with traced_operation('server_state.register_replica', attrs={REPLICA_ID: replica_id}): - time.sleep(random.uniform(0.005, 0.02)) - with traced_operation('server_state.register_model', - attrs={SESSION_ID: sid, MODEL_ID: mid, REPLICA_ID: replica_id, - TOKEN_ID: token, BASE_MODEL: base_model}): - time.sleep(random.uniform(0.01, 0.04)) - metrics_reg.requests_total.add(1, {'route': '/tinker/create_model', 'status': '200'}) - - # ---- 3. forward_backward × num_steps (the actual SFT loop) ---------- - losses = [] - for step in range(num_steps): - with _gateway_span(tracer, 'POST /tinker/forward_backward', - {SESSION_ID: sid, MODEL_ID: mid, 'sft.step': step}): - wait = random.uniform(0.001, 0.015) - execute = random.uniform(0.05, 0.20) - metrics_reg.task_wait_seconds.record(wait, {'deployment': 'Model'}) - with traced_operation('task_queue.execute', - attrs={SESSION_ID: sid, MODEL_ID: mid, TOKEN_ID: token}): - with traced_operation('model.forward_backward', - attrs={SESSION_ID: sid, MODEL_ID: mid}): - time.sleep(execute) - loss = max(0.05, 2.5 * (0.92 ** step) + random.uniform(-0.05, 0.05)) - losses.append(loss) - if step % 4 == 0: - log.info(f'sft step={step} loss={loss:.3f}', - extra={'twinkle.session_id': sid, 'twinkle.model_id': mid, - 'sft.step': step, 'sft.loss': loss}) - metrics_reg.task_execution_seconds.record(execute, {'deployment': 'Model'}) - metrics_reg.tasks_total.add(1, {'deployment': 'Model', 'status': 'completed'}) - metrics_reg.requests_total.add(1, {'route': '/tinker/forward_backward', 'status': '200'}) - - # Simulate a user that hits the rate limit at step 3 of 8 - if user_idx == 2: - with _gateway_span(tracer, 'POST /tinker/forward_backward', - {SESSION_ID: sid, MODEL_ID: mid, 'sft.step': num_steps}): - log.warning(f'rate-limit rejection for user{user_idx}', - extra={'twinkle.session_id': sid, 'twinkle.token_id': token}) - metrics_reg.rate_limit_rejections.add(1, {'deployment': 'Model'}) - metrics_reg.requests_total.add(1, {'route': '/tinker/forward_backward', 'status': '429'}) - - # Simulate a hard failure for user 4 - if user_idx == 4: - with _gateway_span(tracer, 'POST /tinker/optim_step', - {SESSION_ID: sid, MODEL_ID: mid}): - try: - with traced_operation('model.optim_step', attrs={SESSION_ID: sid, MODEL_ID: mid}): - raise RuntimeError('optimizer NaN at user4 step5') - except RuntimeError: - log.exception(f'sft failed sid={sid} mid={mid}', - extra={'twinkle.session_id': sid, 'twinkle.model_id': mid}) - metrics_reg.tasks_total.add(1, {'deployment': 'Model', 'status': 'failed'}) - metrics_reg.requests_total.add(1, {'route': '/tinker/optim_step', 'status': '500'}) - - # ---- 4. /save_weights (client downloads LoRA) ------------------------ - with _gateway_span(tracer, 'POST /tinker/save_weights', - {SESSION_ID: sid, MODEL_ID: mid}): - log.info(f'save_weights mid={mid}', - extra={'twinkle.session_id': sid, 'twinkle.model_id': mid}) - with traced_operation('server_state.store_future_status', attrs={MODEL_ID: mid}): - time.sleep(random.uniform(0.02, 0.08)) - metrics_reg.requests_total.add(1, {'route': '/tinker/save_weights', 'status': '200'}) - - # ---- 5. /unload_model (cleanup) -------------------------------------- - with _gateway_span(tracer, 'POST /tinker/unload_model', - {SESSION_ID: sid, MODEL_ID: mid}): - log.info(f'unload_model mid={mid}', - extra={'twinkle.session_id': sid, 'twinkle.model_id': mid}) - with traced_operation('server_state.unload_model', attrs={MODEL_ID: mid}): - time.sleep(random.uniform(0.005, 0.015)) - metrics_reg.requests_total.add(1, {'route': '/tinker/unload_model', 'status': '200'}) - - return {'user_idx': user_idx, 'session_id': sid, 'model_id': mid, - 'token': token, 'base_model': base_model, - 'final_loss': losses[-1] if losses else None, - 'num_steps': num_steps} - - -def main() -> None: - setup_telemetry() - log = logging.getLogger('twinkle.demo') - log.setLevel(logging.INFO) - - NUM_USERS = 5 - log.info(f'launching {NUM_USERS} concurrent SFT runs') - - with ThreadPoolExecutor(max_workers=NUM_USERS) as pool: - futures = [pool.submit(run_sft_for_user, i, num_steps=8) for i in range(NUM_USERS)] - results = [f.result() for f in futures] - - log.info(f'all {NUM_USERS} users finished SFT') - print('\n=== Per-user summary (use these IDs to query in Grafana) ===') - for r in results: - print(f" user{r['user_idx']} token={r['token']:14s} session={r['session_id']} " - f"model={r['model_id']} base={r['base_model']:20s} " - f"final_loss={r['final_loss']:.3f}" if r['final_loss'] else '') - - # Drive resource gauges + flush everything - time.sleep(3) - trace.get_tracer_provider().force_flush(timeout_millis=10000) - metrics.get_meter_provider().force_flush(timeout_millis=10000) - from twinkle.server.telemetry import provider - provider._logger_provider.force_flush(timeout_millis=10000) - time.sleep(2) - print('\nflushed traces + logs + metrics to OTLP') - - -if __name__ == '__main__': - main() diff --git a/cookbook/observability/docker-compose.yaml b/cookbook/observability/docker-compose.yaml deleted file mode 100644 index de775c6e0..000000000 --- a/cookbook/observability/docker-compose.yaml +++ /dev/null @@ -1,35 +0,0 @@ -# Twinkle observability stack — local-dev edition. -# -# Single container: grafana/otel-lgtm bundles OTel Collector + Mimir -# (Prometheus-compatible) + Tempo + Loki + Grafana with all datasources -# pre-wired. Twinkle pushes OTLP to :4317; you read it back at :3000. -# -# Quick start: -# docker compose up -d -# open http://localhost:3000 # admin / admin (anonymous viewer also enabled) -# -# In your server_config.yaml: -# telemetry_config: -# enabled: true -# debug: false -# service_name: twinkle-server -# otlp_endpoint: http://localhost:4317 - -services: - lgtm: - image: grafana/otel-lgtm:latest - container_name: twinkle-lgtm - ports: - - "3000:3000" # Grafana UI - - "4317:4317" # OTLP gRPC — point telemetry_config.otlp_endpoint here - - "4318:4318" # OTLP HTTP (alternative) - volumes: - # Drop our pre-built Twinkle overview dashboard into the image's - # existing dashboard provisioning folder. Grafana inside the container - # auto-scans this directory on startup. - - ./grafana/dashboards/twinkle-overview.json:/otel-lgtm/grafana/conf/provisioning/dashboards/twinkle-overview.json:ro - # Persist dashboards/data across container restarts (optional) - - lgtm-data:/data - -volumes: - lgtm-data: diff --git a/cookbook/observability/grafana/dashboards/twinkle-overview.json b/cookbook/observability/grafana/dashboards/twinkle-overview.json deleted file mode 100644 index 1d7251b4e..000000000 --- a/cookbook/observability/grafana/dashboards/twinkle-overview.json +++ /dev/null @@ -1,362 +0,0 @@ -{ - "annotations": { - "list": [] - }, - "editable": true, - "fiscalYearStartMonth": 0, - "graphTooltip": 1, - "id": null, - "links": [], - "liveNow": false, - "panels": [ - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "fieldConfig": { - "defaults": { - "color": { - "mode": "palette-classic" - }, - "unit": "reqps" - } - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 0, - "y": 0 - }, - "id": 1, - "title": "HTTP request rate (per deployment)", - "type": "timeseries", - "targets": [ - { - "expr": "sum by (deployment, status) (rate(twinkle_http_requests_total[1m]))", - "legendFormat": "{{deployment}} {{status}}", - "refId": "A" - } - ] - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "fieldConfig": { - "defaults": { - "unit": "s" - } - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 12, - "y": 0 - }, - "id": 2, - "title": "HTTP latency P95 (per deployment)", - "type": "timeseries", - "targets": [ - { - "expr": "histogram_quantile(0.95, sum by (le, deployment) (rate(twinkle_http_request_duration_seconds_bucket[5m])))", - "legendFormat": "{{deployment}}", - "refId": "A" - } - ] - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 0, - "y": 8 - }, - "id": 3, - "title": "Active resources", - "type": "timeseries", - "targets": [ - { - "expr": "twinkle_sessions_active", - "legendFormat": "sessions", - "refId": "A" - }, - { - "expr": "twinkle_models_active", - "legendFormat": "models", - "refId": "B" - }, - { - "expr": "twinkle_sampling_sessions_active", - "legendFormat": "sampling sessions", - "refId": "C" - }, - { - "expr": "twinkle_futures_active", - "legendFormat": "futures", - "refId": "D" - } - ] - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 12, - "y": 8 - }, - "id": 4, - "title": "Task queue depth (per deployment)", - "type": "timeseries", - "targets": [ - { - "expr": "sum by (deployment) (twinkle_queue_depth)", - "legendFormat": "{{deployment}}", - "refId": "A" - } - ] - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "fieldConfig": { - "defaults": { - "unit": "s" - } - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 0, - "y": 16 - }, - "id": 5, - "title": "Task execution P95", - "type": "timeseries", - "targets": [ - { - "expr": "histogram_quantile(0.95, sum by (le, deployment) (rate(twinkle_task_execution_seconds_bucket[5m])))", - "legendFormat": "{{deployment}}", - "refId": "A" - } - ] - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "fieldConfig": { - "defaults": { - "unit": "s" - } - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 12, - "y": 16 - }, - "id": 6, - "title": "Task wait time P95", - "type": "timeseries", - "targets": [ - { - "expr": "histogram_quantile(0.95, sum by (le, deployment) (rate(twinkle_task_wait_seconds_bucket[5m])))", - "legendFormat": "{{deployment}}", - "refId": "A" - } - ] - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 0, - "y": 24 - }, - "id": 7, - "title": "Rate-limit rejections", - "type": "timeseries", - "targets": [ - { - "expr": "sum by (deployment) (rate(twinkle_rate_limit_rejections_total[1m]))", - "legendFormat": "{{deployment}}", - "refId": "A" - } - ] - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 12, - "y": 24 - }, - "id": 8, - "title": "Task completions by status", - "type": "timeseries", - "targets": [ - { - "expr": "sum by (deployment, status) (rate(twinkle_tasks_total[1m]))", - "legendFormat": "{{deployment}} {{status}}", - "refId": "A" - } - ] - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "fieldConfig": { - "defaults": { - "unit": "percentunit" - }, - "overrides": [] - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 0, - "y": 32 - }, - "id": 9, - "title": "CPU utilization", - "type": "timeseries", - "targets": [ - { - "expr": "twinkle_system_cpu_utilization", - "legendFormat": "{{instance}}", - "refId": "A" - } - ] - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "fieldConfig": { - "defaults": { - "unit": "bytes" - }, - "overrides": [] - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 12, - "y": 32 - }, - "id": 10, - "title": "Memory utilization (system)", - "type": "timeseries", - "targets": [ - { - "expr": "twinkle_system_memory_usage_bytes", - "legendFormat": "system used", - "refId": "A" - }, - { - "expr": "twinkle_process_memory_usage_bytes", - "legendFormat": "{{instance}} process RSS", - "refId": "B" - } - ] - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "fieldConfig": { - "defaults": { - "unit": "percentunit" - }, - "overrides": [] - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 0, - "y": 40 - }, - "id": 11, - "title": "GPU utilization", - "type": "timeseries", - "targets": [ - { - "expr": "twinkle_gpu_utilization", - "legendFormat": "gpu {{gpu_index}}", - "refId": "A" - } - ] - }, - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "fieldConfig": { - "defaults": { - "unit": "bytes" - }, - "overrides": [] - }, - "gridPos": { - "h": 8, - "w": 12, - "x": 12, - "y": 40 - }, - "id": 12, - "title": "GPU memory used", - "type": "timeseries", - "targets": [ - { - "expr": "twinkle_gpu_memory_usage_bytes", - "legendFormat": "gpu {{gpu_index}}", - "refId": "A" - } - ] - } - ], - "refresh": "10s", - "schemaVersion": 39, - "tags": [ - "twinkle" - ], - "templating": { - "list": [] - }, - "time": { - "from": "now-1h", - "to": "now" - }, - "timepicker": {}, - "timezone": "", - "title": "Twinkle Server Overview", - "uid": "twinkle-overview", - "version": 1, - "weekStart": "" -} diff --git a/cookbook/observability/load.py b/cookbook/observability/load.py deleted file mode 100644 index e41338024..000000000 --- a/cookbook/observability/load.py +++ /dev/null @@ -1,331 +0,0 @@ -#!/usr/bin/env python -"""Drive a running mock Twinkle server with enough traffic that every panel on -the ``Twinkle Server Overview`` Grafana dashboard shows data. - -Why this script exists ----------------------- -Most empty panels on the overview dashboard are NOT a wiring problem — they -are caused by: - -1. ``histogram_quantile(0.95, rate(..._bucket[5m]))`` returns NaN when the - 5-minute look-back has zero traffic, so latency / wait-time / execution - panels read "No data". -2. ``up_down_counter`` gauges emit on *delta* only. ``active_sessions`` / - ``active_models`` / ``queue_depth`` stay invisible until something - actually changes their underlying count at least once. -3. Mock backends execute in microseconds — even when traffic exists, - histogram P95 hugs the bottom bucket. Bump ``--max-tokens`` so the - sampler's per-request runtime lifts off the floor. - -Pre-reqs --------- -1. Observability stack + Redis running:: - - docker compose -f cookbook/observability/docker-compose.yaml up -d - docker run -d --name twinkle-redis -p 6379:6379 redis:7 - -2. Mock server running with telemetry **enabled** and a SHARED persistence - backend. The shipped ``cookbook/client/server/mock/server_config.yaml`` - ships ``mode: memory`` which is per-process — Gateway-worker sessions - are invisible to Model worker, the adapter countdown loop sees "session - not found" and expires registered adapters within ~10s (which empties - the ``twinkle_models_active`` gauge). For this load script to populate - every panel, switch persistence to redis:: - - persistence: { mode: redis, redis_url: redis://localhost:6379 } - - Telemetry: the worker reads the env flag as the literal string ``"1"`` - (see ``telemetry.worker_init.ensure_telemetry_initialized``), NOT - ``"true"``:: - - TWINKLE_TELEMETRY_ENABLED=1 \\ - TWINKLE_TELEMETRY_ENDPOINT=http://localhost:4317 \\ - python -m twinkle.server launch \\ - --config cookbook/client/server/mock/server_config.yaml - - Start Ray first (the launcher does ``ray.init(address='auto')`` and - will refuse to spin one up locally):: - - ray start --head --num-cpus=4 --disable-usage-stats - -Usage ------ -:: - - # Defaults: 4 concurrent users, 120s, ~2 req/s each - python cookbook/observability/load.py - - # Heavier: 8 users, 5 minutes, longer sampler runtime (lifts P95) - python cookbook/observability/load.py \\ - --concurrency 8 --duration 300 --max-tokens 128 - -In Grafana set the time window to ``Last 15 minutes`` for the rate[5m] -queries to be meaningful. -""" -from __future__ import annotations - -import argparse -import asyncio -import json -import random -import time -import uuid - -import httpx - -# Routes from cookbook/client/server/mock/server_config.yaml — keep in sync. -GATEWAY_ROUTE = '/api/v1' -MODEL_ROUTE = '/api/v1/model/mock' -SAMPLER_ROUTE = '/api/v1/sampler/mock' - -# Any non-empty token is accepted (``is_token_valid`` is permissive by default). -TOKEN = 'load-test-token' - - -def _headers(session_id: str, *, request_id: str, multiplex_key: str | None = None) -> dict[str, str]: - """Build the per-request header set the server middleware expects. - - The server's ``verify_request_token`` middleware requires: - - ``Twinkle-Authorization: Bearer `` - - ``X-Ray-Serve-Request-Id`` for sticky routing (any unique string ok) - - ``X-Twinkle-Session-Id`` for session correlation (optional) - - Model + sampler deployments additionally call - ``serve.get_multiplexed_model_id()`` for sticky-LoRA replica routing — - Ray Serve raises ``ValueError("The model ID cannot be empty.")`` if the - ``serve_multiplexed_model_id`` header is absent. Always set - ``multiplex_key`` for model / sampler calls; the Gateway endpoint - (``/api/v1/twinkle/create_session``) does not need it. - - Pass the SAME ``request_id`` for every call against the same registered - adapter so the sticky-LoRA key (``request_id + '-' + adapter_name``) - resolves to the registered resource on subsequent ``/forward_only`` calls. - """ - headers = { - 'Twinkle-Authorization': f'Bearer {TOKEN}', - 'X-Ray-Serve-Request-Id': request_id, - 'X-Twinkle-Session-Id': session_id, - 'Content-Type': 'application/json', - } - if multiplex_key is not None: - headers['serve_multiplexed_model_id'] = multiplex_key - return headers - - -def _lora_config_payload(rank: int = 8) -> str: - """JSON payload the server's ``deserialize_object`` will rehydrate into a - ``peft.LoraConfig``. Matches ``twinkle_client.common.serialize.serialize_object``. - """ - return json.dumps({ - '_TWINKLE_TYPE_': 'LoraConfig', - 'r': rank, - 'lora_alpha': rank * 2, - 'lora_dropout': 0.0, - 'bias': 'none', - 'task_type': 'CAUSAL_LM', - 'target_modules': ['q_proj', 'v_proj'], - }) - - -async def create_session(client: httpx.AsyncClient) -> str | None: - """POST /api/v1/twinkle/create_session — returns the SERVER-issued - ``session_id`` so subsequent ``X-Twinkle-Session-Id`` headers reference - a session the server actually persisted. Using a client-side string - would silently fail liveness checks (the adapter countdown loop in - ``utils/lifecycle/base.py`` calls ``state.get_session_last_heartbeat`` - and expires adapters within ~10s when the ID isn't found).""" - r = await client.post( - f'{GATEWAY_ROUTE}/twinkle/create_session', - headers=_headers('', request_id=uuid.uuid4().hex), - json={'metadata': {'source': 'load.py'}}, - timeout=10.0, - ) - if r.status_code != 200: - print(f' create_session -> {r.status_code} {r.text[:160]}') - return None - return r.json().get('session_id') - - -async def create_sampling_session(client: httpx.AsyncClient, session_id: str, model_path: str) -> None: - """POST /api/v1/create_sampling_session — bumps ``active_sampling_sessions``. - This is a Tinker route mounted at the gateway root (NOT under - ``/twinkle/``); the Twinkle gateway handlers only expose ``create_session``.""" - try: - await client.post( - f'{GATEWAY_ROUTE}/create_sampling_session', - headers=_headers(session_id, request_id=uuid.uuid4().hex), - json={ - 'session_id': session_id, - 'sampling_session_seq_id': 0, - 'model_path': model_path, - 'base_model': 'mock-model', - }, - timeout=10.0, - ) - except Exception: - pass - - -async def session_heartbeat(client: httpx.AsyncClient, session_id: str) -> None: - """POST /api/v1/twinkle/session_heartbeat — refreshes the session so - the adapter countdown loop doesn't expire registered adapters mid-load.""" - try: - await client.post( - f'{GATEWAY_ROUTE}/twinkle/session_heartbeat', - headers=_headers(session_id, request_id=uuid.uuid4().hex), - json={'session_id': session_id}, - timeout=5.0, - ) - except Exception: - pass - - -async def add_adapter(client: httpx.AsyncClient, adapter_name: str, session_id: str, request_id: str) -> bool: - """POST /api/v1/model/mock/twinkle/add_adapter_to_model — moves - ``active_models`` gauge and goes through the task queue (queue_depth + - task_execution histograms). - """ - body = {'adapter_name': adapter_name, 'config': _lora_config_payload()} - r = await client.post( - f'{MODEL_ROUTE}/twinkle/add_adapter_to_model', - headers=_headers(session_id, request_id=request_id, multiplex_key=adapter_name), - json=body, - timeout=30.0, - ) - if r.status_code != 200: - print(f' add_adapter {adapter_name} -> {r.status_code} {r.text[:200]}') - return False - return True - - -async def sample(client: httpx.AsyncClient, session_id: str, *, max_tokens: int) -> int: - """POST /api/v1/sampler/mock/twinkle/sample — primary load. No adapter - registration needed (``adapter_name=''`` skips the resource check).""" - body = { - 'inputs': [{'input_ids': [random.randint(0, 100) for _ in range(8)]}], - 'sampling_params': {'max_tokens': max_tokens}, - 'adapter_name': '', - } - r = await client.post( - f'{SAMPLER_ROUTE}/twinkle/sample', - headers=_headers(session_id, request_id=uuid.uuid4().hex, multiplex_key=session_id), - json=body, - timeout=60.0, - ) - return r.status_code - - -async def forward_only(client: httpx.AsyncClient, adapter_name: str, session_id: str, request_id: str) -> int: - """POST /api/v1/model/mock/twinkle/forward_only against a registered adapter. - - ``request_id`` MUST be the same one used by the original ``add_adapter`` - call — the server prefixes ``request_id`` onto the adapter key for - sticky-LoRA routing, so reusing it lets ``assert_resource_exists`` find - the adapter we registered. - """ - body = { - 'inputs': [{'input_ids': [random.randint(0, 100) for _ in range(16)]}], - 'adapter_name': adapter_name, - } - r = await client.post( - f'{MODEL_ROUTE}/twinkle/forward_only', - headers=_headers(session_id, request_id=request_id, multiplex_key=adapter_name), - json=body, - timeout=30.0, - ) - return r.status_code - - -async def user_loop( - user_id: int, - base_url: str, - deadline: float, - interval: float, - max_tokens: int, -) -> None: - """Per-user driver: create_session + add_adapter once, then loop sample - (and occasional forward_only) until the deadline. Periodically heartbeats - the session so the adapter countdown loop doesn't expire registered - adapters mid-load (default adapter_timeout is 1800s, but a missing - heartbeat trips ``_is_session_alive`` long before that).""" - adapter_name = f'adapter-u{user_id}-{uuid.uuid4().hex[:6]}' - sticky_request_id = uuid.uuid4().hex - heartbeat_interval = 5.0 - - async with httpx.AsyncClient(base_url=base_url) as client: - # IMPORTANT: use the SERVER-issued session_id; sending our own client- - # side string would never match a stored session and registered - # adapters would expire within ~10s. - session_id = await create_session(client) - adapter_ok = False - if session_id: - adapter_ok = await add_adapter(client, adapter_name, session_id, sticky_request_id) - # Best-effort: bump the active_sampling_sessions gauge. - await create_sampling_session(client, session_id, model_path=f'mock://{adapter_name}') - - ok_n = err_n = 0 - last_hb = time.monotonic() - while time.monotonic() < deadline: - if session_id and time.monotonic() - last_hb >= heartbeat_interval: - await session_heartbeat(client, session_id) - last_hb = time.monotonic() - use_forward = adapter_ok and random.random() < 0.2 - try: - if use_forward: - status = await forward_only(client, adapter_name, session_id, sticky_request_id) - else: - status = await sample(client, session_id or '', max_tokens=max_tokens) - except Exception as exc: - err_n += 1 - print(f' user {user_id} request error: {exc!r}') - await asyncio.sleep(1.0) - continue - if 200 <= status < 300: - ok_n += 1 - else: - err_n += 1 - await asyncio.sleep(max(0.01, interval + random.uniform(-interval / 4, interval / 4))) - - print(f' user {user_id:>2} ok={ok_n:>4} err={err_n} ' - f'session={session_id or ""} adapter={adapter_name if adapter_ok else ""}') - - -async def main_async(args: argparse.Namespace) -> None: - deadline = time.monotonic() + args.duration - print(f'Load: base={args.base_url} concurrency={args.concurrency} ' - f'duration={args.duration}s interval={args.interval}s max_tokens={args.max_tokens}') - print(f'Hits: POST {GATEWAY_ROUTE}/twinkle/create_session') - print(f' POST {MODEL_ROUTE}/twinkle/add_adapter_to_model') - print(f' POST {MODEL_ROUTE}/twinkle/forward_only (~20%)') - print(f' POST {SAMPLER_ROUTE}/twinkle/sample (~80%)') - await asyncio.gather(*[ - user_loop(i, args.base_url, deadline, args.interval, args.max_tokens) for i in range(args.concurrency) - ]) - print('Done. Allow ~30s for the next OTLP export tick, then refresh Grafana ') - print('with the time window set to "Last 15 minutes".') - - -def main() -> None: - p = argparse.ArgumentParser(description=__doc__.splitlines()[0]) - p.add_argument('--base-url', default='http://localhost:8000', help='Server base URL (default: %(default)s)') - p.add_argument('--concurrency', type=int, default=4, help='Parallel users (default: %(default)s)') - p.add_argument('--duration', type=int, default=120, help='Total seconds to run (default: %(default)s)') - p.add_argument( - '--interval', - type=float, - default=0.5, - help='Mean seconds between requests per worker; ±25%% jitter applied. ' - 'Lower → higher RPS. Default: %(default)s') - p.add_argument( - '--max-tokens', - type=int, - default=64, - help='Mock sampler runtime scales with max_tokens. Bump to >= 64 so ' - 'task_execution P95 lifts off the bottom histogram bucket. ' - 'Default: %(default)s') - args = p.parse_args() - asyncio.run(main_async(args)) - - -if __name__ == '__main__': - main() diff --git a/docs/source_en/Usage Guide/Observability.md b/docs/source_en/Usage Guide/Observability.md deleted file mode 100644 index 01c64b910..000000000 --- a/docs/source_en/Usage Guide/Observability.md +++ /dev/null @@ -1,101 +0,0 @@ -# Observability - -Twinkle Server emits OpenTelemetry traces, metrics, and logs from every Ray -Serve deployment. This guide covers the standardized **correlation keys**, -the Ray Serve **trace-context propagation** mechanism, and an end-to-end -**LGTM** example using the Loki / Grafana / Tempo / Mimir docker-compose -stack shipped under `cookbook/observability/`. - -## Correlation keys - -Every business-layer span carries a subset of these attributes when the -corresponding identifier is known to the operation. All names share the -`twinkle.` prefix so you can filter Tempo / Loki by a single namespace. - -| Attribute | Set when the operation is associated with… | -|------------------------------|--------------------------------------------| -| `twinkle.session_id` | A client session | -| `twinkle.model_id` | A specific registered model | -| `twinkle.replica_id` | A specific Ray Serve replica | -| `twinkle.token_id` | A user authentication token | -| `twinkle.sampling_session_id`| A sampling session | -| `twinkle.base_model` | The base model behind a registered model | - -Constants live in `twinkle.server.telemetry.correlation`. Use -`set_correlation_attrs(span, {...})` to attach them — None values are -skipped, so partially-known operations never get empty attributes. - -```python -from twinkle.server.telemetry.correlation import ( - SESSION_ID, MODEL_ID, set_correlation_attrs, -) -from twinkle.server.telemetry.tracing import traced_operation - -with traced_operation('server_state.register_model', - attrs={SESSION_ID: sid, MODEL_ID: mid}): - ... -``` - -When the OpenTelemetry SDK is not installed, `traced_operation` becomes a -NoOp context manager: the body runs to completion and returns the same -result it would return when tracing is active. - -## Trace-context propagation across deployments - -The HTTP edge already injects context into outgoing headers in -`gateway/proxy.py`, and `create_tracing_middleware` extracts it on the -inbound side, so a Tinker request that passes through the Gateway proxy -shares one trace id end to end. - -The remaining gap is **Ray Serve `DeploymentHandle` calls** between -deployments — those don't go over HTTP. Use the trace-context carrier -helpers: - -```python -from twinkle.server.telemetry.context_carrier import make_carrier, activate_carrier - -# caller side (e.g. Model deployment) — pass the carrier with the call -carrier = make_carrier() -result = await sampler_handle.options(...).remote(payload, trace_context=carrier) - -# callee side (e.g. Sampler deployment handler) -async def handler(payload, trace_context: dict | None = None): - with activate_carrier(trace_context): - with traced_operation('sampler.handle'): - ... -``` - -`make_carrier()` returns an empty dict and `activate_carrier(None)` is a -no-op when OTel is missing or the carrier is empty, so the path stays -safe under graceful degradation. - -## End-to-end LGTM example - -The repository ships a docker-compose stack with Grafana, Tempo (traces), -Loki (logs), and Mimir (metrics) under `cookbook/observability/`. - -```bash -# 1. Start the LGTM stack. -docker compose -f cookbook/observability/docker-compose.yml up -d - -# 2. Launch the server with telemetry enabled. -cat > /tmp/srv.yaml <<'YAML' -telemetry: - enabled: true - service_name: twinkle-server - otlp_endpoint: http://localhost:4317 -persistence: { mode: memory } -applications: [] -YAML - -python -m twinkle.server launch --config /tmp/srv.yaml & - -# 3. Issue some traffic and open Grafana at http://localhost:3000. -# In Tempo, search by tag: `twinkle.session_id = `. -``` - -CPU / memory / GPU metrics show up automatically because the -`ResourceMetricsCollector` is started inside every Ray Serve worker by -`ensure_telemetry_initialized()`. When `psutil` or `pynvml` is missing -(or no GPU is present), the affected gauges report no data and the -worker keeps serving requests. diff --git a/docs/source_en/index.rst b/docs/source_en/index.rst index e04da376f..ef477f7fc 100644 --- a/docs/source_en/index.rst +++ b/docs/source_en/index.rst @@ -12,7 +12,6 @@ Twinkle DOCUMENTATION Usage Guide/Quick-Start.md Usage Guide/Installation.md Usage Guide/Server and Client/index.rst - Usage Guide/Observability.md Usage Guide/NPU-Support.md Usage Guide/Train-as-a-Service.md Usage Guide/Introduction-with-Qwen3.5.md diff --git a/docs/source_zh/index.rst b/docs/source_zh/index.rst index 50f1f32cd..363a0e2d2 100644 --- a/docs/source_zh/index.rst +++ b/docs/source_zh/index.rst @@ -13,7 +13,6 @@ Twinkle DOCUMENTATION 使用指引/安装.md 使用指引/服务端和客户端/index.rst 使用指引/服务配置.md - 使用指引/可观测化.md 使用指引/NPU的支持.md 使用指引/训练服务.md 使用指引/Qwen3.5最佳实践.md diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\217\257\350\247\202\346\265\213\345\214\226.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\217\257\350\247\202\346\265\213\345\214\226.md" deleted file mode 100644 index 9aeb3e245..000000000 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\217\257\350\247\202\346\265\213\345\214\226.md" +++ /dev/null @@ -1,94 +0,0 @@ -# 可观测化 - -Twinkle Server 在每个 Ray Serve 部署中都会发出 OpenTelemetry 追踪、指标与日志。 -本指南覆盖标准化的 **关联键**、Ray Serve 内的 **追踪上下文传递** 机制, -以及基于 `cookbook/observability/` 下 docker-compose 栈的端到端 -**LGTM**(Loki / Grafana / Tempo / Mimir)示例。 - -## 关联键(Correlation keys) - -业务层的每个 span 在已知对应标识符时都会附带下列属性。所有名称都使用 -`twinkle.` 前缀,便于在 Tempo / Loki 中按命名空间统一筛选。 - -| 属性 | 在以下场景下设置 | -|-------------------------------|--------------------------------------| -| `twinkle.session_id` | 关联到某个客户端 session | -| `twinkle.model_id` | 关联到某个已注册的模型 | -| `twinkle.replica_id` | 关联到某个 Ray Serve 副本 | -| `twinkle.token_id` | 关联到某个用户认证 token | -| `twinkle.sampling_session_id` | 关联到某个采样 session | -| `twinkle.base_model` | 关联到注册模型背后的 base model | - -常量定义在 `twinkle.server.telemetry.correlation`。通过 -`set_correlation_attrs(span, {...})` 一次性附加;None 值会被跳过, -部分已知的操作不会出现空属性。 - -```python -from twinkle.server.telemetry.correlation import ( - SESSION_ID, MODEL_ID, set_correlation_attrs, -) -from twinkle.server.telemetry.tracing import traced_operation - -with traced_operation('server_state.register_model', - attrs={SESSION_ID: sid, MODEL_ID: mid}): - ... -``` - -未安装 OpenTelemetry SDK 时,`traced_operation` 退化为 NoOp 上下文管理器: -代码块照常执行并返回与启用追踪时相同的结果。 - -## 跨部署的追踪上下文传递 - -HTTP 边界已经在 `gateway/proxy.py` 中将上下文注入到出站 header, -`create_tracing_middleware` 在入站侧提取——Tinker 经 Gateway 代理的请求在端到端 -共享同一个 trace id。 - -剩下的空缺是 **Ray Serve `DeploymentHandle` 内部调用**——这些调用不走 HTTP。 -使用追踪上下文 carrier 辅助函数: - -```python -from twinkle.server.telemetry.context_carrier import make_carrier, activate_carrier - -# 调用方(如 Model 部署)—— 把 carrier 一起传给被调方 -carrier = make_carrier() -result = await sampler_handle.options(...).remote(payload, trace_context=carrier) - -# 被调方(如 Sampler handler) -async def handler(payload, trace_context: dict | None = None): - with activate_carrier(trace_context): - with traced_operation('sampler.handle'): - ... -``` - -OTel 缺失或 carrier 为空时,`make_carrier()` 返回空字典, -`activate_carrier(None)` 是无操作的上下文管理器,调用路径在优雅降级时仍然安全。 - -## 端到端 LGTM 示例 - -本仓库在 `cookbook/observability/` 下提供一套 docker-compose 栈,包含 -Grafana、Tempo(traces)、Loki(logs)和 Mimir(metrics)。 - -```bash -# 1. 启动 LGTM 栈 -docker compose -f cookbook/observability/docker-compose.yml up -d - -# 2. 启用 telemetry 启动服务 -cat > /tmp/srv.yaml <<'YAML' -telemetry: - enabled: true - service_name: twinkle-server - otlp_endpoint: http://localhost:4317 -persistence: { mode: memory } -applications: [] -YAML - -python -m twinkle.server launch --config /tmp/srv.yaml & - -# 3. 发送一些请求,浏览器打开 http://localhost:3000 进入 Grafana -# 在 Tempo 中以 `twinkle.session_id = <你的 session>` 作为 tag 检索 -``` - -CPU / 内存 / GPU 指标会自动出现,因为 `ResourceMetricsCollector` 由 -`ensure_telemetry_initialized()` 在每个 Ray Serve worker 中启动。 -当 `psutil`、`pynvml` 缺失(或没有 GPU)时,对应 gauge 报告 no data, -worker 仍然继续服务请求。 diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" index 10b6bab8a..a2ba7e7ac 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" @@ -2,7 +2,7 @@ Twinkle Server 的所有运行时配置都来自一个 Pydantic 聚合根 `ServerConfig`,由 YAML 文件经 `ServerConfig.from_yaml(path)` 加载。 -本指南覆盖:字段一览、支持的环境变量、一份完整 YAML 示例,以及从旧字段 +本指南覆盖:字段一览、支持的环境变量、一份最小 YAML 示例,以及从旧字段 名到当前字段名的迁移表。 ## 字段总览 @@ -13,11 +13,6 @@ Twinkle Server 的所有运行时配置都来自一个 Pydantic 聚合根 | `proxy_location` (str\|null)| null | Ray Serve proxy 部署策略,多机推荐 "EveryNode" | | `http_options.host` (str) | "localhost" | HTTP 监听地址,多机部署设为 "0.0.0.0" | | `http_options.port` (int) | 8000 | HTTP 监听端口 | -| `telemetry.enabled` (bool) | false | 是否启用 OpenTelemetry,关闭时整条管线 NoOp | -| `telemetry.debug` (bool) | false | true 控制台输出,false 走 OTLP | -| `telemetry.service_name` | "twinkle-server" | OTEL `service.name` | -| `telemetry.otlp_endpoint` | "http://localhost:4317" | OTLP gRPC 端点 | -| `telemetry.export_interval_ms`| 30000 | 指标导出间隔(毫秒) | | `persistence.mode` (str) | "memory" | "memory" / "file" / "redis" | | `persistence.file_path` | null | mode=file 时必填 | | `persistence.redis_url` | null | mode=redis 时必填 | @@ -33,8 +28,8 @@ Twinkle Server 的所有运行时配置都来自一个 Pydantic 聚合根 应用条目的 `args` 模式由 `import_path` 决定: - `import_path=server` → `ServerArgs`(`server_config`、`supported_models`、`http_options`) -- `import_path=model` → `ModelArgs`(必填 `backend: mock|transformers|megatron`、`model_id`、`device_group`、`device_mesh`) -- `import_path=sampler`→ `SamplerArgs`(必填 `sampler_type: mock|vllm|torch`、`model_id`) +- `import_path=model` → `ModelArgs`(必填 `backend: transformers|megatron`、`model_id`、`device_group`、`device_mesh`) +- `import_path=sampler`→ `SamplerArgs`(必填 `sampler_type: vllm|torch`、`model_id`) - `import_path=processor`→ `ProcessorArgs` ## 支持的环境变量 @@ -49,31 +44,29 @@ CLI 选项均声明 `envvar=`,命令行未指定时回退到环境变量: 启动器额外读取(用于跨 Ray worker 传播): -- `TWINKLE_TELEMETRY_ENABLED` / `_DEBUG` / `_SERVICE` / `_ENDPOINT` / `_INTERVAL` - `TWINKLE_PERSISTENCE_MODE` / `_FILE_PATH` / `_REDIS_URL` / `_KEY_PREFIX` -## 完整 YAML 示例 +## 最小 YAML 示例 -参见 [`cookbook/client/server/server_config.example.yaml`](https://github.com/modelscope/twinkle/blob/main/cookbook/client/server/server_config.example.yaml), -该文件每个字段都附带类型、默认值与可选项。最小可执行示例: +参见 `cookbook/client/server/transformer/server_config.yaml` 和 +`cookbook/client/server/megatron/server_config.yaml`。最小可执行示例: ```yaml http_options: { host: 0.0.0.0, port: 8000 } -telemetry: { enabled: false } persistence: { mode: memory } applications: - name: server route_prefix: /api/v1 import_path: server - args: { supported_models: [mock-model] } + args: { supported_models: [Qwen/Qwen3-4B] } - name: models route_prefix: /api/v1/model import_path: model args: - backend: mock - model_id: mock-model - device_group: { name: model, ranks: 1, device_type: cpu } - device_mesh: { device_type: cpu, dp_size: 1 } + backend: transformers + model_id: Qwen/Qwen3-4B + device_group: { name: model, ranks: 1, device_type: cuda } + device_mesh: { device_type: cuda, dp_size: 1 } ``` ## 旧字段 → 当前字段迁移表 @@ -83,7 +76,6 @@ applications: | 旧字段名 | 当前字段名 | 备注 | |---------------------------|--------------------|------| -| `telemetry_config:` | `telemetry:` | 顶层 | | `persistence_config:` | `persistence:` | 顶层 | | 在 model `args` 中:`use_megatron: true` | `backend: megatron` | 模型后端切换 | | 在 model `args` 中:`use_megatron: false` | `backend: transformers` | 模型后端切换 | diff --git a/tests/docs/test_docs_smoke.py b/tests/docs/test_docs_smoke.py deleted file mode 100644 index d24c5c51c..000000000 --- a/tests/docs/test_docs_smoke.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Smoke checks for the Phase 5 documentation set (R8.3, R11.4, R17).""" -from __future__ import annotations - -import pytest -from pathlib import Path - -REPO_ROOT = Path(__file__).resolve().parents[2] - -# ---------- file presence ------------------------------------------------- # - -OBSERVABILITY_EN = REPO_ROOT / 'docs' / 'source_en' / 'Usage Guide' / 'Observability.md' -OBSERVABILITY_ZH = REPO_ROOT / 'docs' / 'source_zh' / '使用指引' / '可观测化.md' -CONFIG_GUIDE_ZH = REPO_ROOT / 'docs' / 'source_zh' / '使用指引' / '服务配置.md' - -INDEX_EN = REPO_ROOT / 'docs' / 'source_en' / 'index.rst' -INDEX_ZH = REPO_ROOT / 'docs' / 'source_zh' / 'index.rst' - - -@pytest.mark.parametrize( - 'path', - [OBSERVABILITY_EN, OBSERVABILITY_ZH, CONFIG_GUIDE_ZH, INDEX_EN, INDEX_ZH], -) -def test_doc_exists(path: Path) -> None: - assert path.exists(), f'missing doc: {path}' - - -# ---------- observability guide content (R11.4, R17.1, R17.2) ------------ # - -_CORRELATION_KEYS = ( - 'twinkle.session_id', - 'twinkle.model_id', - 'twinkle.replica_id', - 'twinkle.token_id', - 'twinkle.sampling_session_id', - 'twinkle.base_model', -) - - -@pytest.mark.parametrize('path', [OBSERVABILITY_EN, OBSERVABILITY_ZH]) -def test_observability_lists_all_correlation_keys(path: Path) -> None: - text = path.read_text() - for key in _CORRELATION_KEYS: - assert key in text, f'{path.name}: missing correlation key {key}' - - -@pytest.mark.parametrize('path', [OBSERVABILITY_EN, OBSERVABILITY_ZH]) -def test_observability_describes_propagation(path: Path) -> None: - text = path.read_text() - # Mentions the carrier helpers + the propagation surface. - assert 'make_carrier' in text - assert 'activate_carrier' in text - assert 'DeploymentHandle' in text - - -@pytest.mark.parametrize('path', [OBSERVABILITY_EN, OBSERVABILITY_ZH]) -def test_observability_has_lgtm_example(path: Path) -> None: - text = path.read_text() - assert 'docker compose' in text or 'docker-compose' in text - assert 'cookbook/observability' in text - - -# ---------- server-config guide content (R8.3, R17.3) -------------------- # - - -def test_config_guide_lists_top_level_fields() -> None: - text = CONFIG_GUIDE_ZH.read_text() - for field in ('telemetry', 'persistence', 'task_queue', 'applications', 'http_options'): - assert field in text - - -def test_config_guide_documents_envvars() -> None: - text = CONFIG_GUIDE_ZH.read_text() - assert 'TWINKLE_SERVER_CONFIG' in text - assert 'TWINKLE_RAY_NAMESPACE' in text - - -def test_config_guide_includes_yaml_example() -> None: - text = CONFIG_GUIDE_ZH.read_text() - assert 'applications:' in text - assert 'backend: mock' in text or 'backend:' in text - # Reference to the documented example file. - assert 'server_config.example.yaml' in text - - -def test_config_guide_has_migration_table() -> None: - text = CONFIG_GUIDE_ZH.read_text() - # Both legacy → current rows must be present (R8.3). - assert 'telemetry_config' in text and 'telemetry:' in text - assert 'persistence_config' in text and 'persistence:' in text - assert 'use_megatron' in text and 'backend:' in text - - -# ---------- index links (R17.4) ------------------------------------------ # - - -def test_index_zh_links_both_guides() -> None: - text = INDEX_ZH.read_text() - assert '可观测化.md' in text - assert '服务配置.md' in text - - -def test_index_en_links_observability() -> None: - text = INDEX_EN.read_text() - assert 'Observability.md' in text diff --git a/tests/integration/test_lgtm_telemetry.py b/tests/integration/test_lgtm_telemetry.py deleted file mode 100644 index 2a916fe86..000000000 --- a/tests/integration/test_lgtm_telemetry.py +++ /dev/null @@ -1,274 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""End-to-end OTLP telemetry tests against a local backend (R11.x, R13.3). - -Pushes traces via OTLP and reads them back through the trace backend's -HTTP API to verify: - -- correlation keys land on business spans (R11.2) -- the trace-context carrier round-trip places gateway/model/sampler spans - under one trace id, even across the OTLP pipeline (R13.3) - -The test auto-detects which trace backend is reachable on -``http://localhost:4317`` (OTLP gRPC): - -* **Tempo via Grafana** at ``http://localhost:3000`` — preferred. Bring - it up with the bundled stack: ``docker compose -f - cookbook/observability/docker-compose.yaml up -d``. -* **Jaeger** at ``http://localhost:16686`` — lighter fallback with the - same OTLP receiver. Start with ``docker run -d -e COLLECTOR_OTLP_ENABLED=true - -p 16686:16686 -p 4317:4317 jaegertracing/all-in-one:1.62.0``. - -Skips when neither is up. - -Resource-metric exposure (R12.1) and Grafana dashboard structure (R12.5) -are already covered by the in-process tests in -``tests/server/telemetry/test_tracing_and_correlation.py``; the OTLP-→-Mimir -hop is OTel SDK code, not Twinkle code, so it has no separate Twinkle test. -""" -from __future__ import annotations - -import httpx -import os -import pytest -import socket -import time -import urllib.parse -import uuid -from contextlib import contextmanager - -OTLP_ENDPOINT = os.environ.get('TWINKLE_TEST_OTLP_ENDPOINT', 'http://localhost:4317') -GRAFANA_URL = os.environ.get('TWINKLE_TEST_GRAFANA_URL', 'http://localhost:3000') -JAEGER_URL = os.environ.get('TWINKLE_TEST_JAEGER_URL', 'http://localhost:16686') - - -def _tcp_open(url: str, timeout: float = 1.0) -> bool: - parsed = urllib.parse.urlparse(url) - host = parsed.hostname or 'localhost' - port = parsed.port or (443 if parsed.scheme == 'https' else 80) - try: - with socket.create_connection((host, port), timeout=timeout): - return True - except OSError: - return False - - -def _grafana_ready() -> bool: - if not _tcp_open(GRAFANA_URL): - return False - try: - return httpx.get(f'{GRAFANA_URL}/api/health', timeout=2.0).status_code == 200 - except Exception: - return False - - -def _jaeger_ready() -> bool: - if not _tcp_open(JAEGER_URL): - return False - try: - return httpx.get(f'{JAEGER_URL}/', timeout=2.0).status_code == 200 - except Exception: - return False - - -def _detect_backend() -> str | None: - if not _tcp_open(OTLP_ENDPOINT): - return None - if _grafana_ready(): - return 'tempo' - if _jaeger_ready(): - return 'jaeger' - return None - - -_BACKEND = _detect_backend() - -pytestmark = pytest.mark.skipif( - _BACKEND is None, - reason=(f'No trace backend reachable. OTLP at {OTLP_ENDPOINT}, Grafana at {GRAFANA_URL}, ' - f'Jaeger at {JAEGER_URL}. Start one (cookbook/observability/docker-compose.yaml ' - 'or `docker run jaegertracing/all-in-one:1.62.0`).'), -) - -# ---------- helpers ------------------------------------------------------- # - - -def _force_replace_global_providers(tracer_provider, meter_provider) -> None: - """Force-replace the global OTel providers even if another test already set them. - - OTel's ``set_tracer_provider`` is one-shot per process — the conftest in - ``tests/server/telemetry/`` may have installed an in-memory exporter that - we'd otherwise inherit. Reset the underlying ``_TRACER_PROVIDER_SET_ONCE`` - guard so OTLP exporters become active for these tests. - """ - from opentelemetry import metrics, trace - from opentelemetry.util._once import Once - - # Replace tracer provider. - trace._TRACER_PROVIDER_SET_ONCE = Once() # type: ignore[attr-defined] - trace._TRACER_PROVIDER = None # type: ignore[attr-defined] - trace.set_tracer_provider(tracer_provider) - - # Replace meter provider. - metrics._METER_PROVIDER_SET_ONCE = Once() # type: ignore[attr-defined] - metrics._METER_PROVIDER = None # type: ignore[attr-defined] - metrics.set_meter_provider(meter_provider) - - -@contextmanager -def _telemetry_session(service_name: str): - """Initialize a fresh OTLP pipeline pointed at the local backend, force-flush at exit.""" - from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter - from opentelemetry.sdk.metrics import MeterProvider - from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader - from opentelemetry.sdk.resources import Resource - from opentelemetry.sdk.trace import TracerProvider - from opentelemetry.sdk.trace.export import BatchSpanProcessor - - resource = Resource.create({'service.name': service_name}) - tracer_provider = TracerProvider(resource=resource) - tracer_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint=OTLP_ENDPOINT))) - - metric_reader = PeriodicExportingMetricReader( - OTLPMetricExporter(endpoint=OTLP_ENDPOINT), - export_interval_millis=1000, - ) - meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) - - _force_replace_global_providers(tracer_provider, meter_provider) - try: - yield service_name - finally: - try: - tracer_provider.force_flush(timeout_millis=5000) - meter_provider.force_flush(timeout_millis=5000) - except Exception: - pass - - -def _query_trace(service: str, trace_id_hex: str, attempts: int = 30, delay: float = 1.0) -> dict | None: - """Poll the configured backend until ``trace_id_hex`` appears.""" - if _BACKEND == 'tempo': - url = f'{GRAFANA_URL}/api/datasources/proxy/uid/tempo/api/traces/{trace_id_hex}' - for _ in range(attempts): - try: - r = httpx.get(url, timeout=5.0) - if r.status_code == 200 and r.json().get('batches'): - return r.json() - except Exception: - pass - time.sleep(delay) - return None - - # Jaeger: GET /api/traces/{id} - url = f'{JAEGER_URL}/api/traces/{trace_id_hex}' - for _ in range(attempts): - try: - r = httpx.get(url, timeout=5.0) - if r.status_code == 200: - data = r.json().get('data') or [] - if data and data[0].get('spans'): - return data[0] - except Exception: - pass - time.sleep(delay) - return None - - -def _spans_in_trace(payload: dict) -> list[dict]: - """Return a normalized list of spans across both backends.""" - if _BACKEND == 'tempo': - out = [] - for batch in payload.get('batches', []): - for scope in batch.get('scopeSpans', []): - for span in scope.get('spans', []): - out.append({ - 'name': span.get('name'), - 'attributes': { - a['key']: a.get('value', {}).get('stringValue') - for a in span.get('attributes', []) - }, - }) - return out - # Jaeger trace JSON: top-level "spans" with operationName + tags. - return [{ - 'name': s['operationName'], - 'attributes': { - t['key']: t.get('value') - for t in s.get('tags', []) - }, - } for s in payload.get('spans', [])] - - -# ---------- 7.15: trace + correlation visible in the trace store --------- # - - -def test_business_span_with_correlation_visible_e2e() -> None: - """A business span carrying twinkle.session_id / twinkle.model_id is - retrievable from the trace store after going through the OTLP pipeline - (R11.2).""" - from opentelemetry import trace - - from twinkle.server.telemetry.correlation import MODEL_ID, SESSION_ID - from twinkle.server.telemetry.tracing import traced_operation - - service = f'twinkle-test-trace-{uuid.uuid4().hex[:6]}' - session_id = f'sess-{uuid.uuid4().hex[:8]}' - model_id = f'mid-{uuid.uuid4().hex[:8]}' - - with _telemetry_session(service): - tracer = trace.get_tracer('twinkle.test.trace') - with tracer.start_as_current_span('integration.parent') as parent: - with traced_operation( - 'server_state.register_model', - attrs={ - SESSION_ID: session_id, - MODEL_ID: model_id - }, - ): - pass - trace_id_hex = format(parent.get_span_context().trace_id, '032x') - - payload = _query_trace(service, trace_id_hex) - assert payload is not None, f'trace {trace_id_hex} not found in {_BACKEND}' - - attrs_per_span = [s['attributes'] for s in _spans_in_trace(payload)] - assert any(a.get(SESSION_ID) == session_id - for a in attrs_per_span), (f'{SESSION_ID} not on any span in {_BACKEND}: {attrs_per_span}') - assert any(a.get(MODEL_ID) == model_id - for a in attrs_per_span), (f'{MODEL_ID} not on any span in {_BACKEND}: {attrs_per_span}') - - -# ---------- 10.4: single-trace-id fan-out across deployments (R13.3) ----- # - - -def test_carrier_round_trip_shares_trace_id_e2e() -> None: - """Simulate the Gateway → Model → Sampler hop via the carrier helpers. - The trace store records all three spans under one trace id.""" - from opentelemetry import trace - - from twinkle.server.telemetry.context_carrier import activate_carrier, make_carrier - - service = f'twinkle-test-fanout-{uuid.uuid4().hex[:6]}' - with _telemetry_session(service): - tracer = trace.get_tracer('twinkle.test.fanout') - - with tracer.start_as_current_span('gateway.route') as parent: - trace_id = parent.get_span_context().trace_id - carrier = make_carrier() - - with activate_carrier(carrier): - with tracer.start_as_current_span('model.handle') as child: - assert child.get_span_context().trace_id == trace_id - downstream = make_carrier() - - with activate_carrier(downstream): - with tracer.start_as_current_span('sampler.handle') as grandchild: - assert grandchild.get_span_context().trace_id == trace_id - - trace_id_hex = format(trace_id, '032x') - payload = _query_trace(service, trace_id_hex) - assert payload is not None, f'fan-out trace {trace_id_hex} not found in {_BACKEND}' - - span_names = {s['name'] for s in _spans_in_trace(payload)} - assert {'gateway.route', 'model.handle', 'sampler.handle'}.issubset(span_names), span_names diff --git a/tests/integration/test_mock_mode_startup.py b/tests/integration/test_mock_mode_startup.py index 11f504c2d..320e95ab2 100644 --- a/tests/integration/test_mock_mode_startup.py +++ b/tests/integration/test_mock_mode_startup.py @@ -20,6 +20,7 @@ import pytest import time +from tests.server.fixtures import MOCK_SERVER_CONFIG from twinkle.server.config import ServerConfig pytestmark = pytest.mark.skipif( @@ -72,7 +73,7 @@ def test_mock_mode_reaches_ready_under_30s_and_is_deterministic(ray_cluster) -> from twinkle.server.model import build_model_app from twinkle.server.sampler import build_sampler_app - cfg = ServerConfig.from_yaml('cookbook/client/server/mock/server_config.yaml') + cfg = ServerConfig.from_yaml(MOCK_SERVER_CONFIG) # Use a randomized port so concurrent runs / leftover processes don't collide. port = 18000 + (os.getpid() % 1000) diff --git a/tests/server/cli/test_cli.py b/tests/server/cli/test_cli.py index 46f340574..8dc83e95d 100644 --- a/tests/server/cli/test_cli.py +++ b/tests/server/cli/test_cli.py @@ -13,6 +13,7 @@ from typer.testing import CliRunner from unittest import mock +from tests.server.fixtures import MOCK_SERVER_CONFIG from twinkle.server.cli.app import app, main from twinkle.server.config import ServerConfig from twinkle.server.exceptions import ConfigMismatchError @@ -20,9 +21,8 @@ from twinkle.server.state.backend.memory_backend import MemoryBackend from twinkle.server.state.config_signature import _SIGNATURE_KEY, compute_signature, validate_against_backend -REPO_ROOT = Path(__file__).resolve().parents[3] -EXAMPLE = REPO_ROOT / 'cookbook' / 'client' / 'server' / 'server_config.example.yaml' -MOCK_CFG = REPO_ROOT / 'cookbook' / 'client' / 'server' / 'mock' / 'server_config.yaml' +EXAMPLE = MOCK_SERVER_CONFIG +MOCK_CFG = MOCK_SERVER_CONFIG # ---------- 9.5 CLI subcommand existence + exit codes (R14.3, R14.4) ------ # diff --git a/tests/server/fixtures/__init__.py b/tests/server/fixtures/__init__.py new file mode 100644 index 000000000..bad4af55d --- /dev/null +++ b/tests/server/fixtures/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Shared on-disk fixtures for server tests.""" +from __future__ import annotations + +from pathlib import Path + +FIXTURES_DIR = Path(__file__).resolve().parent + +# All-mock CPU-only server config used by CLI tests and the in-process e2e +# Ray Serve startup test. The YAML is the abstraction — callers pass this +# path to ServerConfig.from_yaml / CLI --config flags. +MOCK_SERVER_CONFIG = FIXTURES_DIR / 'server_config_mock.yaml' diff --git a/cookbook/client/server/mock/server_config.yaml b/tests/server/fixtures/server_config_mock.yaml similarity index 59% rename from cookbook/client/server/mock/server_config.yaml rename to tests/server/fixtures/server_config_mock.yaml index 45ed71eb2..8370a2448 100644 --- a/cookbook/client/server/mock/server_config.yaml +++ b/tests/server/fixtures/server_config_mock.yaml @@ -1,32 +1,20 @@ -# Twinkle Server Configuration — Mock backend (CPU-only / no GPU) +# Test-only Twinkle Server config — CPU-only mock backends. # -# NOT FOR PRODUCTION. This config wires the all-mock model + sampler backends -# so the server starts in seconds on a CPU-only host with no torch / -# transformers / vllm / megatron installed. Mock backends return fixed -# numpy-derived results without performing real model computation or -# sampling — use it for local development, CI smoke tests, and HTTP-surface -# debugging. For real training/inference use one of the GPU configs in -# ../transformer/ or ../megatron/. +# This file is the single fixture consumed by CLI tests and the in-process +# Ray Serve e2e test. Mock model + mock sampler need no GPU / torch / vllm / +# megatron, so the launch path can run on any CI host. Not for production. proxy_location: EveryNode http_options: - host: 0.0.0.0 + host: 127.0.0.1 port: 8000 -telemetry: - enabled: false - debug: false - service_name: twinkle-server - otlp_endpoint: http://localhost:4317 - -# In-process MemoryBackend — no Redis required for the mock workflow. persistence: mode: memory applications: - # 1. Tinker-compatible gateway - name: server route_prefix: /api/v1 import_path: server @@ -45,18 +33,17 @@ applications: ray_actor_options: num_cpus: 0.1 - # 2. Mock model — numpy-only, deterministic. Skips twinkle.initialize. - name: models-mock route_prefix: /api/v1/model/mock import_path: model args: - backend: mock # Mock backend: numpy-only, CPU-only + backend: mock model_id: mock-model nproc_per_node: 1 device_group: name: model ranks: 1 - device_type: cpu # No GPU required + device_type: cpu device_mesh: device_type: cpu dp_size: 1 @@ -72,12 +59,11 @@ applications: ray_actor_options: num_cpus: 0.1 - # 3. Mock sampler — numpy-only, deterministic. No vllm import. - name: sampler-mock route_prefix: /api/v1/sampler/mock import_path: sampler args: - sampler_type: mock # Mock sampler: numpy-only, CPU-only + sampler_type: mock model_id: mock-model nproc_per_node: 1 device_group: diff --git a/tests/server/sampler/test_mock_sampler.py b/tests/server/sampler/test_mock_sampler.py index ec32858cb..269a768a9 100644 --- a/tests/server/sampler/test_mock_sampler.py +++ b/tests/server/sampler/test_mock_sampler.py @@ -142,26 +142,3 @@ def test_mock_sampler_module_does_not_directly_import_vllm() -> None: assert forbidden not in text, f'mock_sampler.py contains {forbidden!r}' -# ---------- Mock example config loads (R5.4) ------------------------------ # - - -def test_mock_example_config_loads_via_server_config() -> None: - from twinkle.server.config import ServerConfig - - repo_root = Path(__file__).resolve().parents[3] - cfg_path = repo_root / 'cookbook' / 'client' / 'server' / 'mock' / 'server_config.yaml' - cfg = ServerConfig.from_yaml(cfg_path) - backends = { - a.name: getattr(a.args, 'backend', None) or getattr(a.args, 'sampler_type', None) - for a in cfg.applications - } - assert backends.get('models-mock') == 'mock' - assert backends.get('sampler-mock') == 'mock' - - -def test_mock_readme_documents_launch_and_targets() -> None: - repo_root = Path(__file__).resolve().parents[3] - readme = (repo_root / 'cookbook' / 'client' / 'server' / 'mock' / 'README.md').read_text() - assert 'python -m twinkle.server' in readme - assert '30 seconds' in readme or '30s' in readme - assert 'Not for production' in readme or 'NOT FOR PRODUCTION' in readme.upper() diff --git a/tests/server/telemetry/test_tracing_and_correlation.py b/tests/server/telemetry/test_tracing_and_correlation.py index 636f8f87f..e4e0bf826 100644 --- a/tests/server/telemetry/test_tracing_and_correlation.py +++ b/tests/server/telemetry/test_tracing_and_correlation.py @@ -261,27 +261,3 @@ def test_pyproject_declares_telemetry_extras() -> None: assert 'telemetry =' in text assert 'psutil' in text assert 'pynvml' in text - - -def test_grafana_dashboard_includes_resource_panels() -> None: - """Grafana dashboard JSON ships CPU / Memory / GPU panels (R12.5).""" - import json - from pathlib import Path - - repo_root = Path(__file__).resolve().parents[3] - dashboard = json.loads( - (repo_root / 'cookbook' / 'observability' / 'grafana' / 'dashboards' / 'twinkle-overview.json').read_text()) - titles = ' | '.join(p['title'].lower() for p in dashboard['panels']) - for required in ('cpu', 'memory', 'gpu utilization', 'gpu memory'): - assert required in titles, f'dashboard missing panel containing {required!r}' - - # Each resource gauge name must be referenced by at least one panel target. - targets = ' | '.join(t.get('expr', '') for p in dashboard['panels'] for t in p.get('targets', [])) - for metric in ( - 'twinkle_system_cpu_utilization', - 'twinkle_system_memory_usage_bytes', - 'twinkle_process_memory_usage_bytes', - 'twinkle_gpu_utilization', - 'twinkle_gpu_memory_usage_bytes', - ): - assert metric in targets, f'dashboard does not query metric {metric!r}' From 8cd73a852e839c2a13a362f1303af664f85b1533 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 2 Jun 2026 17:31:52 +0800 Subject: [PATCH 34/34] =?UTF-8?q?docs(zh):=20drop=20=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E9=85=8D=E7=BD=AE.md=20too=20=E2=80=94=20same=20unfinalized=20?= =?UTF-8?q?refactor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirrors the en side, which has no Server-Configuration guide. The ServerConfig schema is part of the same not-yet-confirmed surface as mock / observability — pull it out of the published toctree now. --- docs/source_zh/index.rst | 1 - ...15\345\212\241\351\205\215\347\275\256.md" | 85 ------------------- 2 files changed, 86 deletions(-) delete mode 100644 "docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" diff --git a/docs/source_zh/index.rst b/docs/source_zh/index.rst index 363a0e2d2..3d07d4b2a 100644 --- a/docs/source_zh/index.rst +++ b/docs/source_zh/index.rst @@ -12,7 +12,6 @@ Twinkle DOCUMENTATION 使用指引/快速开始.md 使用指引/安装.md 使用指引/服务端和客户端/index.rst - 使用指引/服务配置.md 使用指引/NPU的支持.md 使用指引/训练服务.md 使用指引/Qwen3.5最佳实践.md diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" deleted file mode 100644 index a2ba7e7ac..000000000 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\351\205\215\347\275\256.md" +++ /dev/null @@ -1,85 +0,0 @@ -# 服务配置 - -Twinkle Server 的所有运行时配置都来自一个 Pydantic 聚合根 -`ServerConfig`,由 YAML 文件经 `ServerConfig.from_yaml(path)` 加载。 -本指南覆盖:字段一览、支持的环境变量、一份最小 YAML 示例,以及从旧字段 -名到当前字段名的迁移表。 - -## 字段总览 - -| 字段 (类型) | 默认值 | 含义 | -|-----------------------------|-----------------------------|------| -| `ray_namespace` (str\|null) | null → "twinkle_cluster" | Ray cluster namespace | -| `proxy_location` (str\|null)| null | Ray Serve proxy 部署策略,多机推荐 "EveryNode" | -| `http_options.host` (str) | "localhost" | HTTP 监听地址,多机部署设为 "0.0.0.0" | -| `http_options.port` (int) | 8000 | HTTP 监听端口 | -| `persistence.mode` (str) | "memory" | "memory" / "file" / "redis" | -| `persistence.file_path` | null | mode=file 时必填 | -| `persistence.redis_url` | null | mode=redis 时必填 | -| `persistence.key_prefix` | "" | 全局 key 前缀 | -| `task_queue.rps_limit` | 100.0 | 每用户 token 的 RPS,0 关闭 | -| `task_queue.tps_limit` | 16000.0 | 每用户 token 的 input tokens/秒,0 关闭 | -| `task_queue.window_seconds` | 1.0 | 滑动窗口宽度(秒),必须 > 0 | -| `task_queue.queue_timeout` | 300.0 | 任务排队最长等待(秒) | -| `task_queue.execution_timeout`| 120.0 | 任务执行最长(秒),0 关闭 | -| `task_queue.max_input_tokens`| 16000 | 单请求最大 input tokens | -| `applications` (list) | [] | 部署清单:每项含 `name` / `route_prefix` / `import_path` / `args` / `deployments` | - -应用条目的 `args` 模式由 `import_path` 决定: - -- `import_path=server` → `ServerArgs`(`server_config`、`supported_models`、`http_options`) -- `import_path=model` → `ModelArgs`(必填 `backend: transformers|megatron`、`model_id`、`device_group`、`device_mesh`) -- `import_path=sampler`→ `SamplerArgs`(必填 `sampler_type: vllm|torch`、`model_id`) -- `import_path=processor`→ `ProcessorArgs` - -## 支持的环境变量 - -CLI 选项均声明 `envvar=`,命令行未指定时回退到环境变量: - -| 选项 | 环境变量 | -|----------------------|-------------------------------| -| `--config / -c` | `TWINKLE_SERVER_CONFIG` | -| `--namespace` | `TWINKLE_RAY_NAMESPACE` | -| `--format`(print-config) | `TWINKLE_PRINT_FORMAT` | - -启动器额外读取(用于跨 Ray worker 传播): - -- `TWINKLE_PERSISTENCE_MODE` / `_FILE_PATH` / `_REDIS_URL` / `_KEY_PREFIX` - -## 最小 YAML 示例 - -参见 `cookbook/client/server/transformer/server_config.yaml` 和 -`cookbook/client/server/megatron/server_config.yaml`。最小可执行示例: - -```yaml -http_options: { host: 0.0.0.0, port: 8000 } -persistence: { mode: memory } -applications: - - name: server - route_prefix: /api/v1 - import_path: server - args: { supported_models: [Qwen/Qwen3-4B] } - - name: models - route_prefix: /api/v1/model - import_path: model - args: - backend: transformers - model_id: Qwen/Qwen3-4B - device_group: { name: model, ranks: 1, device_type: cuda } - device_mesh: { device_type: cuda, dp_size: 1 } -``` - -## 旧字段 → 当前字段迁移表 - -本次重构对运维侧字段名做了一次干净的破坏性变更。**不再支持旧名作为别名**, -请按下表更新 YAML: - -| 旧字段名 | 当前字段名 | 备注 | -|---------------------------|--------------------|------| -| `persistence_config:` | `persistence:` | 顶层 | -| 在 model `args` 中:`use_megatron: true` | `backend: megatron` | 模型后端切换 | -| 在 model `args` 中:`use_megatron: false` | `backend: transformers` | 模型后端切换 | -| 在 sampler `args` 中:(隐含 vllm) | `sampler_type: vllm` | 显式声明 | - -YAML 中保留旧名会触发 `pydantic.ValidationError`,并在错误消息中点出 -不被识别的字段,便于直接修复。