diff --git a/.gitignore b/.gitignore index 8cfd041ff..daf3fa743 100644 --- a/.gitignore +++ b/.gitignore @@ -146,6 +146,7 @@ images /custom/ megatron_output/ .qoder +.kiro/ # Pytorch *.pth diff --git a/Dockerfile b/Dockerfile index 8107ebcb3..668b339e0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ RUN pip install flash-linear-attention -U --no-cache-dir RUN pip install numpy==2.2 --no-cache-dir # Install tinker, ray, and other deps -RUN pip install --no-cache-dir tinker==0.16.1 "ray[serve]" transformers peft<=0.18 accelerate -U +RUN pip install --no-cache-dir tinker==0.16.1 "ray[serve]" transformers peft<=0.18 accelerate redis opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp -U # Clone and install twinkle, checkout to latest v-tag RUN git clone https://github.com/modelscope/twinkle.git diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index 696200200..7b18ed68c 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -9,6 +9,17 @@ http_options: host: 0.0.0.0 # Listen on all network interfaces port: 9000 # Port number for the server +# Persistence configuration for ServerState (sessions, models, futures, ...). +# Top-level placement makes the launcher propagate this to every Ray worker +# via env vars, so the configured backend is used regardless of which +# deployment initializes the ServerState actor first. +# mode: memory | file | redis +# file_path: required for `file` mode +# redis_url / key_prefix: required for `redis` mode +# persistence: +# mode: file +# file_path: /tmp/twinkle_state.json + # Applications: each entry defines a service component deployed on the server applications: @@ -84,7 +95,7 @@ applications: route_prefix: /api/v1/model/Qwen/Qwen3.6-27B import_path: model args: - use_megatron: true # Use Megatron-LM backend + backend: megatron # Use Megatron-LM backend model_id: "ms://Qwen/Qwen3.6-27B" # ModelScope model identifier max_length: 65536 # model max length max_loras: 3 # model max loras diff --git a/cookbook/client/server/megatron/server_config_4b.yaml b/cookbook/client/server/megatron/server_config_4b.yaml index 36adb3328..7eed4699d 100644 --- a/cookbook/client/server/megatron/server_config_4b.yaml +++ b/cookbook/client/server/megatron/server_config_4b.yaml @@ -38,7 +38,7 @@ applications: route_prefix: /api/v1/model/Qwen/Qwen3.5-4B import_path: model args: - use_megatron: true + backend: megatron model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier max_length: 10240 nproc_per_node: 2 # Number of GPU processes per node diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 570142afa..ee23cc33a 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -9,6 +9,17 @@ http_options: host: 0.0.0.0 # Listen on all network interfaces port: 8000 # Port number for the server +# Persistence configuration for ServerState (sessions, models, futures, ...). +# Top-level placement makes the launcher propagate this to every Ray worker +# via env vars, so the configured backend is used regardless of which +# deployment initializes the ServerState actor first. +# mode: memory | file | redis +# file_path: required for `file` mode +# redis_url / key_prefix: required for `redis` mode +persistence: + mode: file + file_path: /tmp/twinkle_state.json + # Applications: each entry defines a service component deployed on the server applications: @@ -38,7 +49,7 @@ applications: route_prefix: /api/v1/model/Qwen/Qwen3.5-4B import_path: model args: - use_megatron: false # Use HuggingFace Transformers backend + backend: transformers # Model backend: transformers | megatron model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier max_length: 10240 nproc_per_node: 1 # Number of GPU processes per node @@ -64,43 +75,43 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + TWINKLE_TRUST_REMOTE_CODE: "1" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). - - name: sampler-Qwen3.5-4B - route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - import_path: sampler - args: - model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - nproc_per_node: 2 # Number of GPU processes per node - sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - engine_args: # vLLM engine-specific settings - max_model_len: 4096 # Maximum sequence length the engine supports - gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - enable_lora: true # Allow loading LoRA adapters during inference - logprobs_mode: processed_logprobs # Logprobs mode for sampling results - device_group: # Logical device group for the sampler - name: sampler - ranks: 1 # Number of GPUs to use - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 1 - queue_config: - rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second - deployments: - - name: SamplerManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + # - name: sampler-Qwen3.5-4B + # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + # import_path: sampler + # args: + # model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + # nproc_per_node: 2 # Number of GPU processes per node + # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + # engine_args: # vLLM engine-specific settings + # max_model_len: 4096 # Maximum sequence length the engine supports + # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + # enable_lora: true # Allow loading LoRA adapters during inference + # logprobs_mode: processed_logprobs # Logprobs mode for sampling results + # device_group: # Logical device group for the sampler + # name: sampler + # ranks: 1 # Number of GPUs to use + # device_type: cuda + # device_mesh: + # device_type: cuda + # dp_size: 1 + # queue_config: + # rps_limit: 100 # Max requests per second + # tps_limit: 100000 # Max tokens per second + # deployments: + # - name: SamplerManagement + # autoscaling_config: + # min_replicas: 1 + # max_replicas: 1 + # target_ongoing_requests: 16 + # ray_actor_options: + # num_cpus: 0.1 + # runtime_env: + # env_vars: + # TWINKLE_TRUST_REMOTE_CODE: "1" # 4. Processor Service - name: processor diff --git a/cookbook/client/twinkle/self_host/self_cognition.py b/cookbook/client/twinkle/self_host/self_cognition.py index c5b771aac..997a77322 100644 --- a/cookbook/client/twinkle/self_host/self_cognition.py +++ b/cookbook/client/twinkle/self_host/self_cognition.py @@ -24,7 +24,7 @@ base_model = 'Qwen/Qwen3.5-4B' base_url = 'http://localhost:8000' api_key = 'EMPTY_API_KEY' -save_dir = '/model' +save_dir = '/tmp/twinkle_sft_output' # Step 2: Initialize the Twinkle client to communicate with the remote server. @@ -108,8 +108,10 @@ def train(): start_step = progress['cur_step'] # Step 7: Run the training loop + max_steps = 10 # Limit to 10 steps for quick verification logger.info(model.get_train_configs().model_dump()) + global_step = 0 for epoch in range(3): logger.info(f'Starting epoch {epoch}') for cur_step, batch in enumerate(dataloader, start=start_step + 1): @@ -128,12 +130,22 @@ def train(): # # Advance the learning rate scheduler by one step # model.lr_step() + global_step += 1 + # Log the loss every 2 steps (aligned with gradient accumulation) if cur_step % 2 == 0: # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric.result}') + # Stop after max_steps + if global_step >= max_steps: + logger.info(f'Reached max_steps={max_steps}, stopping training.') + break + + if global_step >= max_steps: + break + # Step 8: Save the trained checkpoint twinkle_path = model.save( name=f'twinkle-epoch-{epoch}', diff --git a/pyproject.toml b/pyproject.toml index 964a7548c..8949c8693 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,8 +14,12 @@ dependencies = [ "safetensors", "peft>=0.11.0,<=0.19.0", "transformers", + "typer>=0.9.0", ] +[project.scripts] +twinkle-server = "twinkle.server.cli:main" + [project.optional-dependencies] transformers = [ "accelerate", @@ -27,6 +31,9 @@ megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]", "mcore_bridg vllm = ["vllm>=0.11"] ray = ["ray[serve]"] tinker = ["tinker==0.14.0"] +test = ["hypothesis>=6.0"] +telemetry = ["psutil>=5.9.0", "pynvml>=11.0.0"] +redis = ["redis>=5.0"] docs = [ "sphinx>=5.3.0,<6.0.0", "docutils>=0.16.0,<0.17.0", diff --git a/src/twinkle/server/__main__.py b/src/twinkle/server/__main__.py index 8f97ef097..28e86ad15 100644 --- a/src/twinkle/server/__main__.py +++ b/src/twinkle/server/__main__.py @@ -1,117 +1,22 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -""" -CLI entry point for Twinkle Server. +"""CLI entry point for Twinkle Server. + +Thin shim — delegates to the typer-based :mod:`twinkle.server.cli` so the +``python -m twinkle.server`` command and the ``twinkle-server`` console +script share one implementation. + +Usage:: -Usage: - # From config file - python -m twinkle.server --config server_config.yaml + python -m twinkle.server launch --config server_config.yaml + python -m twinkle.server check-config --config server_config.yaml + python -m twinkle.server print-config --config server_config.yaml + python -m twinkle.server clear persistence --config server_config.yaml """ from __future__ import annotations -import argparse -import os import sys -from pathlib import Path - -from twinkle import get_logger - -logger = get_logger() - - -def create_parser() -> argparse.ArgumentParser: - """Create the argument parser.""" - parser = argparse.ArgumentParser( - prog='python -m twinkle.server', - description='Twinkle Server Launcher - Unified launcher supporting both Tinker and Twinkle clients', - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Start server from YAML config file - python -m twinkle.server --config server_config.yaml - """, - ) - - # Config file option - parser.add_argument( - '-c', - '--config', - type=str, - required=True, - metavar='PATH', - help='Path to YAML configuration file (required)', - ) - - # Ray options - parser.add_argument( - '--namespace', - type=str, - metavar='NS', - help="Ray namespace (default: 'twinkle_cluster')", - ) - - # Runtime options - parser.add_argument( - '--log-level', - type=str, - default='INFO', - choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], - metavar='LEVEL', - help='Logging level (default: INFO)', - ) - - return parser - - -def main(args: list[str] | None = None) -> int: - """ - Main entry point for the CLI. - - Args: - args: Command line arguments (uses sys.argv if None) - - Returns: - Exit code (0 for success, non-zero for error) - """ - parser = create_parser() - parsed_args = parser.parse_args(args) - - try: - from twinkle.server.launcher import launch_server - - # Apply log level so that all loggers (including those created later) - # pick up the user-specified level via the LOG_LEVEL env var that - # get_logger() already reads. - os.environ['LOG_LEVEL'] = parsed_args.log_level - - config_path = Path(parsed_args.config) - if not config_path.exists(): - logger.error(f'Config file not found: {config_path}') - return 1 - - launch_server( - config_path=config_path, - ray_namespace=parsed_args.namespace, - ) - - return 0 - - except KeyboardInterrupt: - logger.info('Server stopped by user') - return 0 - except FileNotFoundError as e: - logger.error(f'File not found: {e}') - return 1 - except ValueError as e: - logger.error(f'Configuration error: {e}') - return 1 - except ImportError as e: - logger.error(f'Import error: {e}') - logger.error('Make sure all required dependencies are installed') - return 1 - except Exception as e: - logger.exception(f'Unexpected error: {e}') - return 1 +from twinkle.server.cli import main if __name__ == '__main__': sys.exit(main()) diff --git a/src/twinkle/server/cli/__init__.py b/src/twinkle/server/cli/__init__.py new file mode 100644 index 000000000..4be1e0a1e --- /dev/null +++ b/src/twinkle/server/cli/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Twinkle Server CLI (typer).""" +from .app import app, main + +__all__ = ['app', 'main'] diff --git a/src/twinkle/server/cli/app.py b/src/twinkle/server/cli/app.py new file mode 100644 index 000000000..32076781d --- /dev/null +++ b/src/twinkle/server/cli/app.py @@ -0,0 +1,165 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Typer-based operations CLI (R14, R15). + +Provides four subcommands: + +- ``launch`` — start the Twinkle Server from a YAML config. + Validates the persistence config signature against + the persistence backend BEFORE ``ray.init`` so a + configuration drift fails fast (R15.1). +- ``check-config`` — validate a config file; exit 0 on success, non-zero + with the validation error on failure (R14.3, R14.4). +- ``print-config`` — emit the fully resolved + normalized ``ServerConfig`` + as YAML (R14.5). +- ``clear persistence``— delete persisted state for the namespace derived + from a config file (R14.2). +""" +from __future__ import annotations + +import asyncio +import json +import sys +import typer +import yaml +from pathlib import Path +from typing import Optional + +from twinkle.server.config import ServerConfig +from twinkle.server.exceptions import ConfigMismatchError, ConfigParseError + +app = typer.Typer( + add_completion=False, + no_args_is_help=True, + help='Operations CLI for Twinkle Server.', +) + +clear_app = typer.Typer( + no_args_is_help=True, + help='Clear server-side state.', +) +app.add_typer(clear_app, name='clear') + +CONFIG_OPTION = typer.Option( + ..., + '--config', + '-c', + envvar='TWINKLE_SERVER_CONFIG', + help='Path to the YAML configuration file.', + metavar='PATH', +) +NAMESPACE_OPTION = typer.Option( + None, + '--namespace', + envvar='TWINKLE_RAY_NAMESPACE', + help='Ray namespace (overrides ray_namespace in the config).', +) + + +def _load_config(path: Path) -> ServerConfig: + """Load + validate ``path``; print typed errors and exit non-zero on failure.""" + try: + return ServerConfig.from_yaml(path) + except FileNotFoundError as e: + typer.echo(f'error: {e}', err=True) + raise typer.Exit(code=2) + except ConfigParseError as e: + typer.echo(f'error: {e}', err=True) + raise typer.Exit(code=2) + except Exception as e: # pydantic.ValidationError + cross-field + typer.echo(f'error: invalid configuration\n{e}', err=True) + raise typer.Exit(code=2) + + +def _signature_payload(config: ServerConfig) -> dict: + """The dict whose hash drives signature validation. + + Restricted to the persistence-relevant fields so unrelated edits don't + invalidate the persisted state. Future phases can broaden this. + """ + return {'persistence': config.persistence.model_dump(mode='json')} + + +@app.command('launch') +def launch_cmd( + config: Path = CONFIG_OPTION, + namespace: str | None = NAMESPACE_OPTION, +) -> None: + """Start the Twinkle Server from a YAML config file (R14.1, R15.1).""" + cfg = _load_config(config) + + # Validate the persistence config signature BEFORE we touch Ray (R15.1). + try: + from twinkle.server.state.config_signature import validate_against_backend + + asyncio.run(validate_against_backend(cfg.persistence, _signature_payload(cfg))) + except ConfigMismatchError as e: + typer.echo(f'error: {e}', err=True) + raise typer.Exit(code=3) + + # Defer the heavy launcher import until after drift validation passes so + # the failure path stays cheap (and a missing Ray install doesn't block + # `check-config`). + from twinkle.server.launcher import ServerLauncher + + launcher = ServerLauncher(config=cfg, ray_namespace=namespace or cfg.ray_namespace) + launcher.launch() + + +@app.command('check-config') +def check_config_cmd(config: Path = CONFIG_OPTION) -> None: + """Validate ``config`` and exit 0 on success, non-zero on failure (R14.3, R14.4).""" + _load_config(config) + typer.echo('ok') + + +@app.command('print-config') +def print_config_cmd( + config: Path = CONFIG_OPTION, + fmt: str = typer.Option('yaml', '--format', envvar='TWINKLE_PRINT_FORMAT', help='yaml|json'), +) -> None: + """Emit the validated, normalized ``ServerConfig`` (R14.5).""" + cfg = _load_config(config) + payload = cfg.to_yaml_dict() + if fmt == 'json': + typer.echo(json.dumps(payload, indent=2, sort_keys=True)) + else: + typer.echo(yaml.safe_dump(payload, sort_keys=True).rstrip()) + + +@clear_app.command('persistence') +def clear_persistence_cmd(config: Path = CONFIG_OPTION) -> None: + """Remove persisted state for the namespace derived from ``config`` (R14.2).""" + cfg = _load_config(config) + from twinkle.server.state.backend.factory import create_backend + + async def _clear() -> int: + backend = create_backend(cfg.persistence) + keys = await backend.keys('*') + removed = 0 + for k in keys: + await backend.delete(k) + removed += 1 + return removed + + n = asyncio.run(_clear()) + typer.echo(f'cleared {n} keys from persistence backend (mode={cfg.persistence.mode})') + + +def main(argv: list[str] | None = None) -> int: + """Programmatic entry point used by ``__main__.py`` and tests. + + Runs the typer app in standalone mode and converts its ``SystemExit`` + into a plain return code so callers can react without re-trapping. + """ + try: + app(args=argv, standalone_mode=True) + except SystemExit as exc: + code = exc.code + if code is None: + return 0 + return int(code) if not isinstance(code, int) else code + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/src/twinkle/server/common/router.py b/src/twinkle/server/common/router.py index dee1bd36e..cc9b94915 100644 --- a/src/twinkle/server/common/router.py +++ b/src/twinkle/server/common/router.py @@ -4,7 +4,7 @@ RequestRouter, RunningReplica) from typing import Dict, List, Optional -from twinkle.server.utils.state import ServerStateProxy, get_server_state +from twinkle.server.state import ServerState, get_server_state from twinkle.utils.logger import get_logger logger = get_logger() @@ -15,7 +15,7 @@ class StickyLoraRequestRouter(FIFOMixin, MultiplexMixin, RequestRouter): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.state: ServerStateProxy = get_server_state() + self.state: ServerState = get_server_state() async def choose_replicas( self, diff --git a/src/twinkle/server/config/__init__.py b/src/twinkle/server/config/__init__.py new file mode 100644 index 000000000..8bb39eebe --- /dev/null +++ b/src/twinkle/server/config/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Server configuration package — aggregate root and per-deployment specs.""" + +from .application_spec import ApplicationSpec, HttpOptions, ModelArgs, ProcessorArgs, SamplerArgs, ServerArgs +from .server_config import ServerConfig + +__all__ = [ + 'ApplicationSpec', + 'HttpOptions', + 'ModelArgs', + 'ProcessorArgs', + 'SamplerArgs', + 'ServerArgs', + 'ServerConfig', +] diff --git a/src/twinkle/server/config/application_spec.py b/src/twinkle/server/config/application_spec.py new file mode 100644 index 000000000..171859299 --- /dev/null +++ b/src/twinkle/server/config/application_spec.py @@ -0,0 +1,154 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Per-deployment ``ApplicationSpec`` and typed argument schemas (R6, R3). + +Each deployment kind (``server | model | sampler | processor``) carries its +own ``args`` block with strict field validation. ``ApplicationSpec`` holds +the routing metadata plus the deployment kind and validates ``args`` against +the matching ``*Args`` schema in a model validator. + +The schemas use ``extra='forbid'`` so unknown args (typos, copy-paste from +other deployments) fail at load time instead of being silently dropped. +""" +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing import Any, Literal + +from twinkle.server.utils.task_queue.config import TaskQueueConfig + +# ---------- shared helpers ------------------------------------------------- # + + +class _ArgsBase(BaseModel): + """Base class for every per-deployment args schema (extra='forbid').""" + + model_config = ConfigDict(extra='forbid') + + +class HttpOptions(BaseModel): + """HTTP listener settings (host/port). + + Re-exported from ``server_config`` for convenience and so that + ``ServerArgs`` can carry it without importing the aggregate root. + """ + + model_config = ConfigDict(extra='forbid') + + host: str = 'localhost' + port: int = 8000 + + +# ---------- per-deployment args schemas (R3.x) ----------------------------- # + + +class ModelArgs(_ArgsBase): + """Args for the ``model`` deployment. + + The ``backend`` field selects the model implementation and replaces the + legacy ``use_megatron: bool`` flag. Phase 0c introduces this field; the + actual dispatch on its value is wired up in Phase 1 (R3.1-3.3, R3.9). + """ + + model_id: str + nproc_per_node: int = 1 + device_group: dict[str, Any] + device_mesh: dict[str, Any] + backend: Literal['mock', 'transformers', 'megatron'] + adapter_config: dict[str, Any] | None = None + queue_config: TaskQueueConfig = Field(default_factory=TaskQueueConfig) + max_loras: int = 5 + max_length: int | None = None + + +class SamplerArgs(_ArgsBase): + """Args for the ``sampler`` deployment. + + ``sampler_type`` selects the sampler implementation (R3.4-3.6, R3.10). + """ + + model_id: str + nproc_per_node: int = 1 + device_group: dict[str, Any] + device_mesh: dict[str, Any] + sampler_type: Literal['mock', 'vllm', 'torch'] + engine_args: dict[str, Any] | None = None + queue_config: TaskQueueConfig = Field(default_factory=TaskQueueConfig) + + +class ServerArgs(_ArgsBase): + """Args for the gateway ``server`` deployment.""" + + server_config: dict[str, Any] | None = None + supported_models: list[Any] | None = None + http_options: HttpOptions | None = None + route_prefix: str | None = None + + +class ProcessorArgs(_ArgsBase): + """Args for the ``processor`` deployment.""" + + ncpu_proc_per_node: int | None = None + device_group: dict[str, Any] | None = None + device_mesh: dict[str, Any] | None = None + queue_config: TaskQueueConfig = Field(default_factory=TaskQueueConfig) + + +_ARGS_SCHEMA: dict[str, type[_ArgsBase]] = { + 'server': ServerArgs, + 'model': ModelArgs, + 'sampler': SamplerArgs, + 'processor': ProcessorArgs, +} + +# ---------- ApplicationSpec ------------------------------------------------ # + + +class ApplicationSpec(BaseModel): + """One application entry under ``ServerConfig.applications``. + + The ``args`` block is validated against the schema selected by + ``import_path``: ``server`` → ``ServerArgs``, ``model`` → ``ModelArgs``, + etc. Unknown keys at this level (or inside ``args``) are rejected. A + missing ``args`` block is treated as ``{}`` and validated against the + matching schema, so any required field (e.g. ``backend`` on a model + deployment) raises with the offending field path instead of silently + falling back to a different schema's default. + """ + + model_config = ConfigDict(extra='forbid') + + name: str + route_prefix: str = '/' + import_path: Literal['server', 'model', 'sampler', 'processor'] + # ``args`` is filled in by the ``mode='before'`` validator below; the + # default of ``{}`` is only meaningful for kinds whose schema has no + # required fields (currently ``server``). + args: ServerArgs | ModelArgs | SamplerArgs | ProcessorArgs = Field(default=None) # type: ignore[assignment] + deployments: list[dict[str, Any]] = Field(default_factory=list) + + @model_validator(mode='before') + @classmethod + def _coerce_args_to_schema(cls, data: Any) -> Any: + """Validate the raw ``args`` block against the schema for ``import_path``. + + Keying off ``import_path`` makes the failure messages point at the + right schema and avoids the ambiguity Pydantic's structural Union + resolution would introduce when two schemas share field names. + """ + if not isinstance(data, dict): + return data + import_path = data.get('import_path') + if import_path not in _ARGS_SCHEMA: + # Let Pydantic's Literal validator handle bad import_path values. + return data + schema = _ARGS_SCHEMA[import_path] + raw_args = data.get('args') + if isinstance(raw_args, schema): + return data + if raw_args is None: + raw_args = {} + if not isinstance(raw_args, dict): + # Non-dict, non-instance args is a hard schema error — surface + # it through the matching schema for a clean error message. + raw_args = dict(raw_args) if hasattr(raw_args, 'keys') else raw_args + return {**data, 'args': schema.model_validate(raw_args)} diff --git a/src/twinkle/server/config/server_config.py b/src/twinkle/server/config/server_config.py new file mode 100644 index 000000000..0183b4739 --- /dev/null +++ b/src/twinkle/server/config/server_config.py @@ -0,0 +1,84 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Aggregate-root server configuration (R6, R7, R8). + +``ServerConfig`` is the single Pydantic model that nests every configuration +subsystem the launcher consumes (telemetry, persistence, task-queue, and the +list of ``ApplicationSpec``). Loading a YAML file and validating it goes +through one entry point — ``ServerConfig.from_yaml(path)`` — so the launcher +no longer reaches into a raw dict. + +Top-level fields use their current names with no aliases for legacy names +(``telemetry_config``, ``persistence_config``); the model is configured with +``extra='forbid'`` so a YAML that uses a legacy field is rejected with the +offending name pointed at (R8.1, R8.2). +""" +from __future__ import annotations + +from pathlib import Path +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing import Any + +from twinkle.server.exceptions import ConfigParseError +from twinkle.server.state.backend.factory import PersistenceConfig +from twinkle.server.telemetry.provider import TelemetryConfig +from twinkle.server.utils.task_queue.config import TaskQueueConfig +from .application_spec import ApplicationSpec, HttpOptions + + +class ServerConfig(BaseModel): + """Top-level server configuration aggregate root.""" + + model_config = ConfigDict(extra='forbid') + + ray_namespace: str | None = None + proxy_location: str | None = None + http_options: HttpOptions = Field(default_factory=HttpOptions) + telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig) + persistence: PersistenceConfig = Field(default_factory=PersistenceConfig) + task_queue: TaskQueueConfig = Field(default_factory=TaskQueueConfig) + applications: list[ApplicationSpec] = Field(default_factory=list) + + # ---- loading ---------------------------------------------------------- # + + @classmethod + def from_yaml(cls, path: str | Path) -> ServerConfig: + """Load and validate a YAML file into a ``ServerConfig``. + + Raises: + FileNotFoundError: ``path`` does not exist or cannot be read. + ConfigParseError: ``path`` exists but is not well-formed YAML. + pydantic.ValidationError: a field or cross-field constraint + fails — the error names every offending field. + """ + from omegaconf import OmegaConf + + p = Path(path) + if not p.exists(): + raise FileNotFoundError(f'Config file not found: {p}') + try: + raw = OmegaConf.to_container(OmegaConf.load(p), resolve=True) + except Exception as e: # malformed YAML / OmegaConf parse failure + raise ConfigParseError(f'Malformed YAML in {p}: {e}') from e + if raw is None: + raw = {} + if not isinstance(raw, dict): + raise ConfigParseError(f'Top-level YAML in {p} must be a mapping, got {type(raw).__name__}', ) + return cls.model_validate(raw) + + # ---- cross-field validation (R7) ------------------------------------- # + + @model_validator(mode='after') + def _validate_cross_field(self) -> ServerConfig: + # R7.1: redis mode requires redis_url + if self.persistence.mode == 'redis' and not self.persistence.redis_url: + raise ValueError("persistence.redis_url is required when persistence.mode == 'redis'", ) + # R7.2: file mode requires file_path + if self.persistence.mode == 'file' and not self.persistence.file_path: + raise ValueError("persistence.file_path is required when persistence.mode == 'file'", ) + return self + + # ---- round-trip / serialization (R6.7) ------------------------------- # + + def to_yaml_dict(self) -> dict[str, Any]: + """Return a JSON-mode dict suitable for ``yaml.safe_dump`` / round-trip.""" + return self.model_dump(mode='json') diff --git a/src/twinkle/server/exceptions.py b/src/twinkle/server/exceptions.py new file mode 100644 index 000000000..b99c19bf1 --- /dev/null +++ b/src/twinkle/server/exceptions.py @@ -0,0 +1,63 @@ +"""Twinkle Server unified exception hierarchy.""" + +from __future__ import annotations + + +class TwinkleServerError(Exception): + """Base class for all Twinkle Server exceptions.""" + pass + + +class StateBackendError(TwinkleServerError): + """State backend operation failed (connection lost, timeout, data serialization error, etc.).""" + pass + + +class ConfigMismatchError(TwinkleServerError): + """Configuration signature mismatch — config changed since last launch. + + Persisted data may be incompatible with the current configuration; the + operator must reconcile (revert the config change or clear persisted + state) before the server can start. + """ + pass + + +class ConfigError(TwinkleServerError): + """Invalid configuration value for a known field. + + Used when a field is present and parseable but its value is not in the + permitted set (e.g. ``backend`` is ``""`` or ``"hf"``). Carries enough + detail for the operator to find and fix the offending YAML entry without + re-running the server. + """ + + def __init__( + self, + field: str, + value: object, + allowed: list[str] | tuple[str, ...] | None = None, + message: str | None = None, + ) -> None: + self.field = field + self.value = value + self.allowed = list(allowed) if allowed is not None else None + if message is None: + allowed_part = f', allowed: {self.allowed}' if self.allowed is not None else '' + message = f'Invalid value for {field}: {value!r}{allowed_part}' + super().__init__(message) + + +class ConfigParseError(TwinkleServerError): + """The configuration source could not be parsed (malformed YAML, ...). + + Distinct from ``pydantic.ValidationError`` (which signals that a parsed + value violates a field/cross-field rule) and from ``FileNotFoundError`` + (which signals that the source could not be read at all). + """ + pass + + +class ResourceExhaustedError(TwinkleServerError): + """Resource exhausted — queue full, insufficient memory, connection pool exhausted, etc.""" + pass diff --git a/src/twinkle/server/gateway/proxy.py b/src/twinkle/server/gateway/proxy.py index 0978b8e03..14dc35424 100644 --- a/src/twinkle/server/gateway/proxy.py +++ b/src/twinkle/server/gateway/proxy.py @@ -12,6 +12,7 @@ from fastapi import Request, Response from typing import Any +from twinkle.server.telemetry.tracing import inject_context from twinkle.utils.logger import get_logger logger = get_logger() @@ -97,6 +98,10 @@ async def proxy_request( target_url = self._build_target_url(service_type, base_model, endpoint) headers = self._prepare_headers(request.headers) + # Inject current trace context into outgoing headers for distributed tracing. + # When telemetry is not initialized, this is a noop. + inject_context(headers) + try: logger.debug( 'proxy_request service=%s endpoint=%s target_url=%s request_id=%s', diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index 755c5d2b4..b0644f3c7 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -14,8 +14,9 @@ from typing import Any import twinkle_client.types as types +from twinkle.server.state import get_server_state +from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.utils.state import get_server_state from twinkle.server.utils.validation import verify_request_token from twinkle.utils.logger import get_logger from .proxy import ServiceProxy @@ -42,6 +43,24 @@ def __init__(self, types.SupportedModel(model_name='Qwen/Qwen3.6-27B'), ] self._modelscope_config_lock = asyncio.Lock() + self._state_cleanup_started = False + + async def _ensure_state_cleanup_started(self) -> None: + """Start ServerState cleanup + metrics loops on the first request. + + Ray Serve binds ``serve.get_replica_context().servable_object`` AFTER + FastAPI ``lifespan`` startup, so the cleanup task cannot run there + (``get_self()`` returns ``None`` during lifespan). Lazy-init here on + the first request instead. ``start_cleanup_task`` is idempotent via + its internal ``_cleanup_running`` guard. + """ + if self._state_cleanup_started: + return + try: + await self.state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') + self._state_cleanup_started = True def _normalize_models(self, supported_models): if not supported_models: @@ -96,6 +115,12 @@ def get_self() -> GatewayServer: @asynccontextmanager async def lifespan(app: FastAPI): + # Initialize telemetry in worker process (after deserialization) + from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized + ensure_telemetry_initialized() + # NOTE: ``state.start_cleanup_task()`` cannot run here — Ray Serve binds + # ``servable_object`` AFTER lifespan startup. Lazy-started from the + # first request via the ``ensure_state_cleanup_started`` middleware. yield try: await get_self().proxy.close() @@ -104,10 +129,26 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) + @app.middleware('http') + async def ensure_state_cleanup_started(request: Request, call_next): + # Lazy-init the state cleanup + metrics loops on first request — see + # GatewayServer._ensure_state_cleanup_started. Gateway has no per- + # handler hook, so a tiny middleware covers every route. + try: + await get_self()._ensure_state_cleanup_started() + except Exception as e: + logger.debug(f'state cleanup lazy-init skipped: {e}') + return await call_next(request) + @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) + # Registration order matters: FastAPI runs middleware in LIFO order, so the + # last-registered wraps the outermost layer. Tracing first → metrics last + # makes metrics the outermost wrapper and capture the full end-to-end + # latency including tracing overhead and auth. + app.middleware('http')(create_tracing_middleware('Gateway')) app.middleware('http')(create_metrics_middleware('Gateway')) _register_tinker_routes(app, get_self) diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index e2a6179ac..41662c5fb 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -21,12 +21,15 @@ """ from __future__ import annotations +import os import signal import threading from pathlib import Path from typing import Any, Callable, Dict, NoReturn, Optional, Union from twinkle import get_logger +from twinkle.server.config import ServerConfig +from twinkle.server.config.application_spec import ApplicationSpec from twinkle.server.utils.ray_serve_patch import apply_ray_serve_patches, get_runtime_env_for_patches logger = get_logger() @@ -53,22 +56,63 @@ class ServerLauncher: def __init__( self, - config: dict[str, Any] | None = None, + config: ServerConfig, ray_namespace: str | None = None, ): """ Initialize the server launcher. Args: - config: Configuration dictionary + config: A validated :class:`ServerConfig` instance. Raw dicts are + rejected — operators must build ``ServerConfig`` via + ``ServerConfig.from_yaml`` or its constructor so cross-field + validation runs before the launcher consumes anything (R6.6). ray_namespace: Ray namespace (default: 'twinkle_cluster') """ - self.config = config or {} + if not isinstance(config, ServerConfig): + raise TypeError('ServerLauncher requires a typed ServerConfig instance; ' + f'got {type(config).__name__}. Build one with ' + 'ServerConfig.from_yaml(path) or ServerConfig(...).') + self.config: ServerConfig = config self.ray_namespace = ray_namespace self._builders: dict[str, Callable] = {} self._ray_initialized = False self._serve_started = False + # Telemetry env var keys that need to be propagated to Ray worker processes + _TELEMETRY_ENV_KEYS: tuple[str, ...] = ( + 'TWINKLE_TELEMETRY_ENABLED', + 'TWINKLE_TELEMETRY_DEBUG', + 'TWINKLE_TELEMETRY_SERVICE', + 'TWINKLE_TELEMETRY_ENDPOINT', + 'TWINKLE_TELEMETRY_INTERVAL', + ) + + def _build_telemetry_env_vars(self) -> dict[str, str]: + """Collect telemetry env vars from os.environ for propagation to Ray workers. + + These vars are read by ``ensure_telemetry_initialized()`` inside the + FastAPI startup hook running in each worker process. + """ + return {k: os.environ[k] for k in self._TELEMETRY_ENV_KEYS if k in os.environ} + + def _build_persistence_env_vars(self) -> dict[str, str]: + """Collect persistence env vars from os.environ for propagation to Ray workers. + + These vars are read by ``PersistenceConfig.from_env()`` inside any + worker that calls ``get_server_state()`` without an explicit config, + which makes the chosen backend independent of deployment startup order. + """ + from twinkle.server.state.backend.factory import PERSISTENCE_ENV_KEYS + return {k: os.environ[k] for k in PERSISTENCE_ENV_KEYS if k in os.environ} + + def _build_propagated_env_vars(self) -> dict[str, str]: + """Aggregate all env vars that must reach Ray worker processes.""" + merged: dict[str, str] = {} + merged.update(self._build_telemetry_env_vars()) + merged.update(self._build_persistence_env_vars()) + return merged + def _get_builders(self) -> dict[str, Callable]: """Get the builder functions for all app types.""" if self._builders: @@ -123,12 +167,18 @@ def _init_ray(self) -> None: import ray - namespace = self.ray_namespace or self.config.get('ray_namespace') or 'twinkle_cluster' + namespace = self.ray_namespace or self.config.ray_namespace or 'twinkle_cluster' if not ray.is_initialized(): # Use runtime_env to apply patches in worker processes # This is required because Ray Serve's ProxyActor runs in separate processes runtime_env = get_runtime_env_for_patches() + # Propagate telemetry + persistence env vars to all Ray workers + propagated_env_vars = self._build_propagated_env_vars() + if propagated_env_vars: + merged_env_vars = dict(runtime_env.get('env_vars') or {}) + merged_env_vars.update(propagated_env_vars) + runtime_env['env_vars'] = merged_env_vars # Connect to existing cluster if available, otherwise start local instance ray.init( address='auto', @@ -155,30 +205,35 @@ def _start_serve(self) -> None: # Serve not running — nothing to shut down pass - http_options = self.config.get('http_options', {}) - if isinstance(http_options, dict): - http_options = dict(http_options) - else: - http_options = dict(http_options) if http_options else {} - - serve.start(http_options=http_options) - logger.info(f'Ray Serve started with http_options={http_options}') + http_options = self.config.http_options.model_dump() + serve_kwargs: dict[str, Any] = {'http_options': http_options} + # ``proxy_location`` controls where the Ray Serve HTTP proxy runs + # (``EveryNode`` / ``HeadOnly`` / ``Disabled``). The example configs + # set this field, so honour it here instead of silently ignoring. + if self.config.proxy_location: + serve_kwargs['proxy_location'] = self.config.proxy_location + serve.start(**serve_kwargs) + logger.info(f'Ray Serve started with http_options={http_options}, ' + f'proxy_location={self.config.proxy_location!r}') self._serve_started = True - def _deploy_application(self, app_config: dict[str, Any]) -> None: + def _deploy_application(self, app_spec: ApplicationSpec) -> None: """Deploy a single application. Args: - app_config: Application configuration dictionary + app_spec: Validated :class:`ApplicationSpec` from the typed config. """ from ray import serve - name = app_config.get('name', 'app') - route_prefix = app_config.get('route_prefix', '/') - import_path = app_config.get('import_path', 'server') - args = app_config.get('args', {}) or {} - deployments = app_config.get('deployments', []) + name = app_spec.name + route_prefix = app_spec.route_prefix + import_path = app_spec.import_path + # Re-serialize the typed args back to a kwargs dict for the builder. + # Using ``mode='python'`` keeps nested Pydantic models as dicts (which + # the legacy builders expect) without losing field-level validation. + args = app_spec.args.model_dump(mode='python', exclude_none=True) + deployments = list(app_spec.deployments or []) logger.info(f'Starting {name} at {route_prefix}...') @@ -193,12 +248,27 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: if isinstance(deploy_config, dict): deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'} + # Inject telemetry + persistence env vars into the deployment's + # runtime_env so that Ray Serve replicas (worker processes) can + # initialize telemetry and resolve the configured persistence backend + # regardless of deployment startup order. + # User-specified env_vars take precedence over our defaults. + propagated_env_vars = self._build_propagated_env_vars() + if propagated_env_vars: + ray_actor_options = dict(deploy_options.get('ray_actor_options') or {}) + runtime_env = dict(ray_actor_options.get('runtime_env') or {}) + env_vars = dict(runtime_env.get('env_vars') or {}) + for k, v in propagated_env_vars.items(): + env_vars.setdefault(k, v) + runtime_env['env_vars'] = env_vars + ray_actor_options['runtime_env'] = runtime_env + deploy_options['ray_actor_options'] = ray_actor_options + # Pass http_options to server apps for internal proxy routing - http_options = self.config.get('http_options', {}) - if import_path == 'server' and http_options: - args['http_options'] = http_options + if import_path == 'server': + args.setdefault('http_options', self.config.http_options.model_dump()) - app = builder(deploy_options=deploy_options, **{k: v for k, v in args.items()}) + app = builder(deploy_options=deploy_options, **args) serve.run(app, name=name, route_prefix=route_prefix) logger.info(f'Deployed {name} at {route_prefix}') @@ -213,30 +283,44 @@ def launch(self) -> None: # Apply Ray Serve patches before initializing Ray apply_ray_serve_patches() + # Initialize telemetry if configured + telemetry = self.config.telemetry + if telemetry.enabled: + from twinkle.server.telemetry import init_telemetry + init_telemetry(telemetry) + # Export config to env vars for Ray worker processes + os.environ['TWINKLE_TELEMETRY_ENABLED'] = '1' + os.environ['TWINKLE_TELEMETRY_DEBUG'] = '1' if telemetry.debug else '0' + os.environ['TWINKLE_TELEMETRY_SERVICE'] = telemetry.service_name + os.environ['TWINKLE_TELEMETRY_ENDPOINT'] = telemetry.otlp_endpoint + os.environ['TWINKLE_TELEMETRY_INTERVAL'] = str(telemetry.export_interval_ms) + + # Export top-level persistence to env vars so any worker + # (not just Gateway) can build the same backend on first call to + # get_server_state(). + persistence = self.config.persistence + for k, v in persistence.to_env_vars().items(): + os.environ[k] = v + logger.info(f'Persistence backend configured: mode={persistence.mode}') + self._init_ray() self._start_serve() - applications = self.config.get('applications', []) + applications = self.config.applications if not applications: logger.warning('No applications configured') return - for app_config in applications: - if isinstance(app_config, dict): - self._deploy_application(app_config) - else: - self._deploy_application(dict(app_config)) + for app_spec in applications: + self._deploy_application(app_spec) - http_options = self.config.get('http_options', {}) - host = http_options.get('host', 'localhost') - port = http_options.get('port', 8000) + host = self.config.http_options.host + port = self.config.http_options.port print('\nAll applications started!') print('Endpoints:') - for app_config in applications: - route_prefix = app_config.get('route_prefix', '/') if isinstance(app_config, - dict) else app_config.route_prefix - print(f' - http://{host}:{port}{route_prefix}') + for app_spec in applications: + print(f' - http://{host}:{port}{app_spec.route_prefix}') # Graceful shutdown via signal handling shutdown_event = threading.Event() @@ -265,59 +349,38 @@ def from_yaml( config_path: str | Path, ray_namespace: str | None = None, ) -> ServerLauncher: - """ - Create a ServerLauncher from a YAML config file. - - Args: - config_path: Path to the YAML config file - ray_namespace: Override Ray namespace from config + """Build a ``ServerLauncher`` from a YAML config file. - Returns: - Configured ServerLauncher instance + Thin wrapper over :meth:`ServerConfig.from_yaml`. ``FileNotFoundError`` + / ``ConfigParseError`` / ``pydantic.ValidationError`` propagate so the + caller can surface a precise message before the launcher is constructed. """ - from omegaconf import OmegaConf - - config_path = Path(config_path) - if not config_path.exists(): - raise FileNotFoundError(f'Config file not found: {config_path}') - - config = OmegaConf.load(config_path) - config_dict = OmegaConf.to_container(config, resolve=True) - + config = ServerConfig.from_yaml(config_path) return cls( - config=config_dict, - ray_namespace=ray_namespace or config_dict.get('ray_namespace'), + config=config, + ray_namespace=ray_namespace or config.ray_namespace, ) def launch_server( - config: dict[str, Any] | None = None, + config: ServerConfig | None = None, config_path: str | Path | None = None, ray_namespace: str | None = None, ) -> None: - """ - Launch a twinkle server with flexible configuration options. + """Launch a twinkle server. - This is the main entry point for launching servers programmatically. - The call blocks until a SIGINT/SIGTERM signal is received. - - Args: - config: Configuration dictionary (takes precedence over config_path) - config_path: Path to YAML config file - ray_namespace: Ray namespace + Exactly one of ``config`` (a :class:`ServerConfig` instance) or + ``config_path`` (a YAML file) must be provided. The call blocks until a + SIGINT/SIGTERM signal is received. Raises: - ValueError: If neither config nor config_path is provided + ValueError: neither ``config`` nor ``config_path`` was provided. + TypeError: ``config`` is not a :class:`ServerConfig` instance — raw + dicts are no longer accepted (R6.6). Examples: - # From YAML config launch_server(config_path="server_config.yaml") - - # From Python dict - launch_server(config={ - "http_options": {"host": "0.0.0.0", "port": 8000}, - "applications": [...] - }) + launch_server(config=ServerConfig(...)) """ if config is None and config_path is None: raise ValueError("Either 'config' or 'config_path' must be provided") @@ -325,7 +388,7 @@ def launch_server( if config is not None: launcher = ServerLauncher( config=config, - ray_namespace=ray_namespace or config.get('ray_namespace'), + ray_namespace=ray_namespace or config.ray_namespace, ) else: launcher = ServerLauncher.from_yaml( diff --git a/src/twinkle/server/model/__init__.py b/src/twinkle/server/model/__init__.py index 1a203083e..8499387b1 100644 --- a/src/twinkle/server/model/__init__.py +++ b/src/twinkle/server/model/__init__.py @@ -1,3 +1,19 @@ -from .app import build_model_app +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Model deployment package. + +``build_model_app`` is exposed lazily via ``__getattr__`` so importing the +mock backend (``twinkle.server.model.backends.mock_model``) on a CPU-only +host doesn't pull in torch/transformers via ``app.py`` at package-init time +(R1.2, R4.3). +""" +from __future__ import annotations __all__ = ['build_model_app'] + + +def __getattr__(name: str): + if name == 'build_model_app': + from .app import build_model_app + + return build_model_app + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 5d0bc2285..fca378c0e 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -15,9 +15,11 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh +from twinkle.server.exceptions import ConfigError +from twinkle.server.state import ServerState, get_server_state +from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.lifecycle import AdapterManagerMixin from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger @@ -28,6 +30,35 @@ logger = get_logger() +_MODEL_BACKENDS: tuple[str, ...] = ('mock', 'transformers', 'megatron') + + +def _validate_model_backend(backend: Any) -> str: + """Pure validation of the ``backend`` selector (R3.9). + + Raises :class:`ConfigError` (naming the field, value, and allowed set) + when ``backend`` is missing, empty, non-string, or not exactly one of + the permitted values. No imports or side effects. + """ + if not isinstance(backend, str) or backend == '' or backend not in _MODEL_BACKENDS: + raise ConfigError(field='backend', value=backend, allowed=list(_MODEL_BACKENDS)) + return backend + + +def _dispatch_model_backend(backend: str, ctor_kwargs: dict[str, Any]) -> Any: + """Instantiate the model backend selected by an already-validated ``backend``.""" + if backend == 'mock': + from .backends.mock_model import TwinkleCompatMockModel + + return TwinkleCompatMockModel(**ctor_kwargs) + if backend == 'megatron': + from .backends.megatron_model import TwinkleCompatMegatronModel + + return TwinkleCompatMegatronModel(**ctor_kwargs) + from .backends.transformers_model import TwinkleCompatTransformersModel + + return TwinkleCompatTransformersModel(**ctor_kwargs) + class ModelManagement(TaskQueueMixin, AdapterManagerMixin): """Unified model management service. @@ -45,41 +76,46 @@ def __init__(self, nproc_per_node: int, device_group: dict[str, Any], device_mesh: dict[str, Any], - use_megatron: bool = False, + backend: str, adapter_config: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, **kwargs): - self.device_group = DeviceGroup(**device_group) - twinkle.initialize(mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) + # R3.9: validate ``backend`` BEFORE any side effect (twinkle.initialize, + # DeviceGroup construction, replica registration). An invalid value + # never produces a partial backend nor reaches a ready state. + backend = _validate_model_backend(backend) + self.backend = backend + # Skip twinkle.initialize for the mock backend (R3.7) — the largest + # startup-time saving and the only way to start without CUDA/torch. + if backend != 'mock': + self.device_group = DeviceGroup(**device_group) + twinkle.initialize( + mode='ray', + nproc_per_node=nproc_per_node, + groups=[self.device_group], + lazy_collect=False, + ) + if 'mesh_dim_names' in device_mesh: + self.device_mesh = DeviceMesh(**device_mesh) + else: + self.device_mesh = DeviceMesh.from_sizes(**device_mesh) else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.use_megatron = use_megatron + self.device_group = None + self.device_mesh = None self.replica_id = serve.get_replica_context().replica_id.unique_id self.max_loras = kwargs.get('max_loras', 5) self.base_model = model_id - # Choose model backend - if use_megatron: - from ..model.backends.megatron_model import TwinkleCompatMegatronModel - - self.model = TwinkleCompatMegatronModel( - model_id=model_id, + ctor_kwargs: dict[str, Any] = {'model_id': model_id, **kwargs} + if backend != 'mock': + ctor_kwargs.update( device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=self.replica_id, - **kwargs) - else: - from ..model.backends.transformers_model import TwinkleCompatTransformersModel - self.model = TwinkleCompatTransformersModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=self.replica_id, - **kwargs) + ) + self.model = _dispatch_model_backend(backend, ctor_kwargs) - self.state: ServerStateProxy = get_server_state() + self.state: ServerState = get_server_state() self._replica_registered = False # Initialize mixins @@ -93,6 +129,24 @@ async def _ensure_replica_registered(self): await self.state.register_replica(self.replica_id, self.max_loras) self._replica_registered = True + async def _ensure_state_cleanup_started(self) -> None: + """Start ServerState cleanup + metrics loops on the first request. + + Cannot run in FastAPI ``lifespan``: Ray Serve binds + ``serve.get_replica_context().servable_object`` AFTER the lifespan + startup phase, so a lifespan call has no ``self`` to reach. By the + time a request arrives, the binding exists. ``state.start_cleanup_task`` + is itself idempotent via ``_cleanup_running``, but the per-instance + flag avoids the await on every subsequent request. + """ + if getattr(self, '_state_cleanup_started', False): + return + try: + await self.state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') + self._state_cleanup_started = True + @serve.multiplexed(max_num_models_per_replica=5) async def _sticky_entry(self, sticky_key: str): return sticky_key @@ -106,6 +160,7 @@ async def _ensure_sticky(self): async def _on_request_start(self, request: Request) -> str: await self._ensure_sticky() await self._ensure_replica_registered() + await self._ensure_state_cleanup_started() token = get_token_from_request(request) return token @@ -133,7 +188,7 @@ def build_model_app(model_id: str, device_group: dict[str, Any], device_mesh: dict[str, Any], deploy_options: dict[str, Any], - use_megatron: bool = False, + backend: str, adapter_config: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, **kwargs): @@ -147,7 +202,9 @@ def build_model_app(model_id: str, device_group: Device group configuration dict device_mesh: Device mesh configuration dict for tensor parallelism deploy_options: Ray Serve deployment options - use_megatron: Whether to use Megatron backend (vs Transformers) + backend: Model backend selector — ``mock`` | ``transformers`` | ``megatron`` + (R3.1-3.3, R3.9). Validated up front; bad values raise + :class:`ConfigError` before any side effect. adapter_config: Adapter lifecycle config (timeout, per-token limits) queue_config: Task queue configuration (rate limiting, etc.) **kwargs: Additional model initialization arguments @@ -155,18 +212,27 @@ def build_model_app(model_id: str, Returns: Configured Ray Serve deployment bound with parameters """ + # Fail fast on bad backend values at builder time (the launcher imports + # this builder at startup, so the error surfaces before deployment). + backend = _validate_model_backend(backend) # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). + def get_self() -> ModelManagement: return serve.get_replica_context().servable_object @asynccontextmanager async def lifespan(app: FastAPI): - try: - await get_self()._ensure_replica_registered() - except Exception as e: - logger.warning(f'Failed to register replica at startup: {e}') + # Initialize telemetry in worker process (after deserialization) + from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized + ensure_telemetry_initialized() + # NOTE: ``state.start_cleanup_task()`` and ``_ensure_replica_registered()`` + # cannot run here — Ray Serve binds ``servable_object`` AFTER lifespan + # startup, so ``get_self()`` returns ``None`` and the call would crash. + # They are lazy-started from the first request via + # ``_on_request_start`` → ``_ensure_state_cleanup_started`` and + # ``_ensure_replica_registered`` respectively. yield try: await get_self().shutdown() @@ -179,6 +245,10 @@ async def lifespan(app: FastAPI): async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) + # Registration order: FastAPI runs middleware LIFO. Tracing first → metrics + # last makes metrics the outermost wrapper, so its latency observation + # covers the full request path including tracing overhead. + app.middleware('http')(create_tracing_middleware('Model')) app.middleware('http')(create_metrics_middleware('Model')) _register_tinker_routes(app, get_self) @@ -190,8 +260,16 @@ async def verify_token(request: Request, call_next): request_router_config=RequestRouterConfig(request_router_class=StickyLoraRequestRouter), )( ModelManagementWithIngress) - return DeploymentClass.options(**deploy_options).bind(model_id, nproc_per_node, device_group, device_mesh, - use_megatron, adapter_config, queue_config, **kwargs) + return DeploymentClass.options(**deploy_options).bind( + model_id, + nproc_per_node, + device_group, + device_mesh, + backend, + adapter_config, + queue_config, + **kwargs, + ) build_model_app = wrap_builder_with_device_group_env(build_model_app) diff --git a/src/twinkle/server/model/backends/mock_model.py b/src/twinkle/server/model/backends/mock_model.py new file mode 100644 index 000000000..7215350b8 --- /dev/null +++ b/src/twinkle/server/model/backends/mock_model.py @@ -0,0 +1,241 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Numpy-only mock model backend (R1, R3, R4). + +Provides ``TwinkleCompatMockModel``, a stand-in for the real +``TwinkleCompatTransformersModel`` whose only purpose is to exercise the +server's HTTP and dispatch paths on a CPU-only host with no torch / +transformers / vllm / megatron installed. Determinism is keyed by +``(model_id, adapter_name, seed, input_shape)`` so repeated requests with the +same payload produce identical numpy-derived results. + +This module deliberately avoids importing ``torch``, ``transformers``, +``vllm``, ``megatron`` or any module whose own imports would pull them +in transitively (e.g. ``twinkle.server.model.backends.common``). The class is +duck-typed against ``TwinkleCompatModelBase`` rather than subclassing it — +the base class lives in a torch-importing module and would defeat the +import-isolation requirement (R1.2). +""" +from __future__ import annotations + +import hashlib +import numpy as np +from typing import Any + + +def _seed_for(model_id: str, adapter_name: str | None, seed: int, *extra: Any) -> int: + """Deterministic per-request RNG seed derived from string/int components. + + Uses SHA-256 over a canonical string form rather than Python's built-in + ``hash()``: the latter is salted per process (PYTHONHASHSEED) for tuples + containing strings, which would make identical requests on different + replicas / restarts produce different outputs and break R2.5 / R4.4 / R4.5 + across processes. + """ + parts = (str(model_id), str(adapter_name), str(int(seed)), *(repr(x) for x in extra)) + digest = hashlib.sha256('\x1f'.join(parts).encode('utf-8')).digest() + # numpy seeds must fit in uint32; take the first 4 bytes of the digest. + return int.from_bytes(digest[:4], 'big') + + +class TwinkleCompatMockModel: + """Numpy-only mock model. + + Public API mirrors the methods that the model FastAPI handlers call on + ``self.model``. Every method either returns a deterministic numpy-derived + payload or completes as a no-op without raising. + """ + + def __init__( + self, + model_id: str, + *, + hidden_size: int = 8, + vocab_size: int = 32, + seed: int = 0, + **kwargs: Any, + ) -> None: + self.model_id = model_id + self._hidden_size = int(hidden_size) + self._vocab_size = int(vocab_size) + self._rng_seed = int(seed) + # adapter_name -> arbitrary config payload + self._adapters: dict[str, dict[str, Any]] = {} + + # ----- Forward family (R1.3) ----------------------------------------- # + + def _build_forward_result( + self, + inputs: Any, + adapter_name: str | None, + *, + loss_value: float = 0.0, + ) -> list[dict[str, Any]]: + """Return one deterministic synthetic per-input record. + + Shapes are derived from the input so ``_tinker_build_output``-style + callers see correctly-sized arrays. + """ + seq_lens = _input_seq_lengths(inputs) + out: list[dict[str, Any]] = [] + for idx, seq_len in enumerate(seq_lens): + rng = np.random.default_rng(_seed_for(self.model_id, adapter_name, self._rng_seed, idx, seq_len)) + logprobs = rng.uniform(-2.0, 0.0, size=seq_len).astype(np.float32) + elementwise_loss = rng.uniform(0.0, 1.0, size=seq_len).astype(np.float32) + out.append({ + 'logprobs': logprobs.tolist(), + 'elementwise_loss': elementwise_loss.tolist(), + 'loss': float(loss_value), + }) + return out + + def tinker_forward_only(self, *, inputs: Any, adapter_name: str | None = None, **kwargs: Any) -> list[Any]: + return [self._build_forward_result(inputs, adapter_name), 0.0] + + def tinker_forward_backward(self, *, inputs: Any, adapter_name: str, loss_fn: str, **kwargs: Any) -> list[Any]: + loss_seed = _seed_for(self.model_id, adapter_name, self._rng_seed, 'loss', loss_fn) + loss = float(np.random.default_rng(loss_seed).uniform(0.0, 1.0)) + return [self._build_forward_result(inputs, adapter_name, loss_value=loss), loss] + + def forward(self, *, inputs: Any, **kwargs: Any) -> list[dict[str, Any]]: + return self._build_forward_result(inputs, kwargs.get('adapter_name')) + + def forward_only(self, *, inputs: Any, **kwargs: Any) -> list[dict[str, Any]]: + return self._build_forward_result(inputs, kwargs.get('adapter_name')) + + def forward_backward(self, *, inputs: Any, **kwargs: Any) -> list[Any]: + loss = float(np.random.default_rng(self._rng_seed).uniform(0.0, 1.0)) + return [self._build_forward_result(inputs, kwargs.get('adapter_name'), loss_value=loss), loss] + + def calculate_loss(self, *, inputs: Any, **kwargs: Any) -> float: + return float(np.random.default_rng(self._rng_seed).uniform(0.0, 1.0)) + + # ----- Backward / optimizer (R1.4) ----------------------------------- # + + def backward(self, *args: Any, **kwargs: Any) -> None: + return None + + def step(self, *args: Any, **kwargs: Any) -> None: + return None + + def zero_grad(self, *args: Any, **kwargs: Any) -> None: + return None + + def lr_step(self, *args: Any, **kwargs: Any) -> None: + return None + + def clip_grad_norm(self, *args: Any, **kwargs: Any) -> float: + return 0.0 + + def clip_grad_and_step(self, *args: Any, **kwargs: Any) -> None: + return None + + def tinker_step(self, *, adam_params: Any = None, **kwargs: Any) -> None: + return None + + def tinker_calculate_metric(self, is_training: bool, **kwargs: Any) -> dict[str, float]: + return {'loss': 0.5, 'grad_norm': 0.1} + + def calculate_metric(self, *args: Any, **kwargs: Any) -> dict[str, float]: + return {'loss': 0.5, 'grad_norm': 0.1} + + def tinker_load(self, checkpoint_dir: str, **kwargs: Any) -> None: + return None + + # ----- Configuration setters (R1.4) ---------------------------------- # + + def set_loss(self, *args: Any, **kwargs: Any) -> None: + return None + + def set_optimizer(self, *args: Any, **kwargs: Any) -> None: + return None + + def set_lr_scheduler(self, *args: Any, **kwargs: Any) -> None: + return None + + def set_template(self, *args: Any, **kwargs: Any) -> None: + return None + + def set_processor(self, *args: Any, **kwargs: Any) -> None: + return None + + def add_metric(self, *args: Any, **kwargs: Any) -> None: + return None + + def apply_patch(self, *args: Any, **kwargs: Any) -> None: + return None + + # ----- Persistence stubs (R1.4) -------------------------------------- # + + def save(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + return {'status': 'ok', 'path': None} + + def load(self, *args: Any, **kwargs: Any) -> None: + return None + + def resume_from_checkpoint(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + return {'status': 'ok', 'progress': {}} + + def get_state_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + return {} + + def get_train_configs(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + return {} + + # ----- Adapter management (R1.5, R1.6, R1.7) ------------------------- # + + def add_adapter(self, adapter_name: str, **cfg: Any) -> None: + """Record an adapter without loading real weights (R1.5).""" + self._adapters[adapter_name] = dict(cfg) + + def add_adapter_to_model(self, adapter_name: str, config: Any = None, **cfg: Any) -> None: + merged: dict[str, Any] = dict(cfg) + if config is not None: + merged.setdefault('config', config) + self._adapters[adapter_name] = merged + + def remove_adapter(self, adapter_name: str) -> None: + """Remove ``adapter_name`` (R1.6); raise on absent (R1.7).""" + if adapter_name not in self._adapters: + raise KeyError(f'adapter not present: {adapter_name}') + del self._adapters[adapter_name] + + def has_adapter(self, adapter_name: str) -> bool: + return adapter_name in self._adapters + + +def _input_seq_lengths(inputs: Any) -> list[int]: + """Best-effort recovery of per-datum sequence lengths from heterogeneous inputs. + + The real backend pulls lengths from ``Datum.loss_fn_inputs['target_tokens']``, + but we want to stay numpy-only and avoid importing the tinker types. Falls + back to ``[1]`` so callers always get at least one record back. + """ + if inputs is None: + return [1] + if isinstance(inputs, list): + if not inputs: + return [1] + out: list[int] = [] + for item in inputs: + length = _seq_length_of(item) + out.append(length) + return out + return [_seq_length_of(inputs)] + + +def _seq_length_of(item: Any) -> int: + # Datum-like: model_input.tokens or loss_fn_inputs['target_tokens'] + for attr in ('model_input', 'inputs', 'tokens'): + v = getattr(item, attr, None) + if v is None: + continue + tokens = getattr(v, 'tokens', v) + if hasattr(tokens, '__len__'): + return max(1, len(tokens)) + if isinstance(item, dict): + for k in ('input_ids', 'tokens', 'target_tokens'): + if k in item and hasattr(item[k], '__len__'): + return max(1, len(item[k])) + if hasattr(item, '__len__'): + return max(1, len(item)) + return 1 diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 40fdadbea..9b6e90b06 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -14,15 +14,17 @@ from __future__ import annotations import os +from contextlib import asynccontextmanager from fastapi import FastAPI, Request from ray import serve from typing import Any, Dict, Optional import twinkle from twinkle import DeviceGroup, DeviceMesh, get_logger +from twinkle.server.state import ServerState, get_server_state +from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.lifecycle import ProcessorManagerMixin from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.validation import verify_request_token from .twinkle_handlers import _register_processor_routes @@ -62,7 +64,7 @@ def __init__(self, # processor objects keyed by processor_id self.resource_dict: dict[str, Any] = {} - self.state: ServerStateProxy = get_server_state() + self.state: ServerState = get_server_state() _cfg = processor_config or {} _env_limit = int(os.environ.get('TWINKLE_PER_USER_PROCESSOR_LIMIT', 20)) @@ -81,6 +83,22 @@ async def _ensure_sticky(self): await self._sticky_entry(sticky_key) # Lazy-start countdown task on first request (requires running event loop) self._ensure_countdown_started() + await self._ensure_state_cleanup_started() + + async def _ensure_state_cleanup_started(self) -> None: + """Start ServerState cleanup + metrics loops on the first request. + + See ``model/app.py``: ``servable_object`` is unbound during lifespan, + so we lazy-init here. Processor has no ``_on_request_start`` hook; + every routed request flows through ``_ensure_sticky`` first. + """ + if getattr(self, '_state_cleanup_started', False): + return + try: + await self.state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') + self._state_cleanup_started = True def _on_processor_expired(self, processor_id: str) -> None: """Called by the countdown thread when a processor's session expires.""" @@ -117,19 +135,35 @@ def build_processor_app(ncpu_proc_per_node: int, Returns: Ray Serve deployment bound with configuration. """ + # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). - app = FastAPI() + + def get_self() -> ProcessorManagement: + return serve.get_replica_context().servable_object + + @asynccontextmanager + async def lifespan(app: FastAPI): + # Initialize telemetry in worker process (after deserialization) + from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized + ensure_telemetry_initialized() + # NOTE: ``state.start_cleanup_task()`` cannot run here — Ray Serve binds + # ``servable_object`` AFTER lifespan startup. Lazy-started from the + # first request via ``_ensure_sticky`` → ``_ensure_state_cleanup_started``. + yield + + app = FastAPI(lifespan=lifespan) @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) + # Registration order: FastAPI runs middleware LIFO. Tracing first → metrics + # last makes metrics the outermost wrapper, so its latency observation + # covers the full request path including tracing overhead. + app.middleware('http')(create_tracing_middleware('Processor')) app.middleware('http')(create_metrics_middleware('Processor')) - def get_self() -> ProcessorManagement: - return serve.get_replica_context().servable_object - _register_processor_routes(app, get_self) ProcessorManagementWithIngress = serve.ingress(app)(ProcessorManagement) diff --git a/src/twinkle/server/processor/twinkle_handlers.py b/src/twinkle/server/processor/twinkle_handlers.py index 66799ee82..7e6a1f7e3 100644 --- a/src/twinkle/server/processor/twinkle_handlers.py +++ b/src/twinkle/server/processor/twinkle_handlers.py @@ -18,6 +18,8 @@ from .app import ProcessorManagement import twinkle_client.types as types +from twinkle.server.telemetry.correlation import SESSION_ID, TOKEN_ID +from twinkle.server.telemetry.tracing import traced_operation from twinkle.server.utils.validation import get_session_id_from_request, get_token_from_request from twinkle.utils.logger import get_logger from twinkle_client.common.serialize import deserialize_object @@ -77,7 +79,14 @@ def _do_create(): return getattr(processor_module, class_type)( remote_group=_remote_group, device_mesh=_device_mesh, instance_id=processor_id, **resolved_kwargs) - processor = await asyncio.get_running_loop().run_in_executor(None, _do_create) + # R10.3: span the primary processor.create op with token + session correlation. + with traced_operation( + f'processor.create.{processor_type_name}.{class_type}', + attrs={ + TOKEN_ID: token, + SESSION_ID: session_id, + }): + processor = await asyncio.get_running_loop().run_in_executor(None, _do_create) self.resource_dict[processor_id] = processor return types.ProcessorCreateResponse(processor_id='pid:' + processor_id) @@ -117,7 +126,11 @@ def _do_call(): except StopIteration: return True, None - is_exhausted, result = await asyncio.get_running_loop().run_in_executor(None, _do_call) + # R10.3: span the primary processor.call op so each invocation is observable. + with traced_operation( + f'processor.call.{function_name}', + attrs={TOKEN_ID: get_token_from_request(request)}): + is_exhausted, result = await asyncio.get_running_loop().run_in_executor(None, _do_call) if function_name == '__next__': if is_exhausted: diff --git a/src/twinkle/server/sampler/__init__.py b/src/twinkle/server/sampler/__init__.py index 58db90983..3569f1056 100644 --- a/src/twinkle/server/sampler/__init__.py +++ b/src/twinkle/server/sampler/__init__.py @@ -1,3 +1,18 @@ -from .app import build_sampler_app +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Sampler deployment package. + +``build_sampler_app`` is exposed lazily via ``__getattr__`` so importing the +mock backend (``twinkle.server.sampler.backends.mock_sampler``) on a CPU-only +host doesn't pull in vllm via ``app.py`` at package-init time (R2.2, R4.3). +""" +from __future__ import annotations __all__ = ['build_sampler_app'] + + +def __getattr__(name: str): + if name == 'build_sampler_app': + from .app import build_sampler_app + + return build_sampler_app + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index 177344727..2cebf7166 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -7,14 +7,17 @@ """ from __future__ import annotations +from contextlib import asynccontextmanager from fastapi import FastAPI, Request from ray import serve from typing import Any, Dict, Optional import twinkle from twinkle import DeviceGroup, DeviceMesh +from twinkle.server.exceptions import ConfigError +from twinkle.server.state import ServerState, get_server_state +from twinkle.server.telemetry.tracing import create_tracing_middleware from twinkle.server.utils.metrics import create_metrics_middleware -from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger @@ -24,6 +27,40 @@ logger = get_logger() +_SAMPLER_TYPES: tuple[str, ...] = ('mock', 'vllm', 'torch') + + +def _validate_sampler_type(sampler_type: Any) -> str: + """Pure validation of the ``sampler_type`` selector (R3.10). + + Raises :class:`ConfigError` (naming the field, value, and allowed set) + when the value is missing, empty, non-string, or not exactly one of the + permitted values. No imports or side effects. + """ + if (not isinstance(sampler_type, str) or sampler_type == '' or sampler_type not in _SAMPLER_TYPES): + raise ConfigError(field='sampler_type', value=sampler_type, allowed=list(_SAMPLER_TYPES)) + return sampler_type + + +def _dispatch_sampler_backend(sampler_type: str, ctor_kwargs: dict[str, Any]) -> Any: + """Instantiate the sampler selected by an already-validated ``sampler_type``.""" + if sampler_type == 'mock': + from .backends.mock_sampler import MockSampler + + # MockSampler accepts only model_id/seed/vocab_size — strip extras silently. + return MockSampler( + model_id=ctor_kwargs.get('model_id'), + seed=ctor_kwargs.get('seed', 0), + vocab_size=ctor_kwargs.get('vocab_size', 32), + ) + if sampler_type == 'torch': + from twinkle.sampler import TorchSampler # type: ignore[attr-defined] + + return TorchSampler(**ctor_kwargs) + from twinkle.sampler import vLLMSampler + + return vLLMSampler(**ctor_kwargs) + class SamplerManagement(TaskQueueMixin): """Unified sampler management service. @@ -40,35 +77,49 @@ def __init__(self, nproc_per_node: int, device_group: dict[str, Any], device_mesh: dict[str, Any], - sampler_type: str = 'vllm', + sampler_type: str, engine_args: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, **kwargs): - self.device_group = DeviceGroup(**device_group) - twinkle.initialize(mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) + # R3.10: validate ``sampler_type`` BEFORE any side effect. + sampler_type = _validate_sampler_type(sampler_type) + # Skip twinkle.initialize for the mock backend (R3.8) — start without + # CUDA/torch/vllm. + if sampler_type != 'mock': + self.device_group = DeviceGroup(**device_group) + twinkle.initialize( + mode='ray', + nproc_per_node=nproc_per_node, + groups=[self.device_group], + lazy_collect=False, + ) + if 'mesh_dim_names' in device_mesh: + self.device_mesh = DeviceMesh(**device_mesh) + else: + self.device_mesh = DeviceMesh.from_sizes(**device_mesh) else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) + self.device_group = None + self.device_mesh = None self.sampler_type = sampler_type self.model_id = model_id replica_context = serve.get_replica_context() replica_id = replica_context.replica_id.unique_id - from twinkle.sampler import vLLMSampler - sampler_kwargs = engine_args or {} - self.sampler = vLLMSampler( - model_id=model_id, - engine_args=sampler_kwargs, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **{ - k: v - for k, v in kwargs.items() if k not in ['engine_args'] - }) - - self.state: ServerStateProxy = get_server_state() + sampler_kwargs: dict[str, Any] = {'model_id': model_id} + if sampler_type != 'mock': + sampler_kwargs.update( + engine_args=engine_args or {}, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + instance_id=replica_id, + **{ + k: v + for k, v in kwargs.items() if k not in ('engine_args', ) + }, + ) + self.sampler = _dispatch_sampler_backend(sampler_type, sampler_kwargs) + + self.state: ServerState = get_server_state() # Initialize task queue mixin self._init_task_queue(TaskQueueConfig.from_dict(queue_config), deployment_name='Sampler') @@ -81,8 +132,24 @@ async def _ensure_sticky(self): sticky_key = serve.get_multiplexed_model_id() await self._sticky_entry(sticky_key) + async def _ensure_state_cleanup_started(self) -> None: + """Start ServerState cleanup + metrics loops on the first request. + + See the matching helper in ``model/app.py`` for the lifespan-timing + rationale: ``servable_object`` is unbound during lifespan, so we lazy- + init here. + """ + if getattr(self, '_state_cleanup_started', False): + return + try: + await self.state.start_cleanup_task() + except Exception as e: + logger.warning(f'Failed to start ServerState cleanup task: {e}') + self._state_cleanup_started = True + async def _on_request_start(self, request: Request) -> str: await self._ensure_sticky() + await self._ensure_state_cleanup_started() token = get_token_from_request(request) return token @@ -92,7 +159,7 @@ def build_sampler_app(model_id: str, device_group: dict[str, Any], device_mesh: dict[str, Any], deploy_options: dict[str, Any], - sampler_type: str = 'vllm', + sampler_type: str, engine_args: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, **kwargs): @@ -107,7 +174,9 @@ def build_sampler_app(model_id: str, device_group: Device group configuration dict device_mesh: Device mesh configuration dict for parallelism deploy_options: Ray Serve deployment options - sampler_type: Type of sampler to use ('vllm' or 'torch') + sampler_type: Sampler selector — ``mock`` | ``vllm`` | ``torch`` (R3.4-3.6, + R3.10). Validated up front; bad values raise :class:`ConfigError` + before any side effect. engine_args: Additional engine arguments for the sampler queue_config: Task queue configuration dict (rps_limit, tps_limit, etc.) **kwargs: Additional arguments passed to the sampler @@ -115,22 +184,41 @@ def build_sampler_app(model_id: str, Returns: Ray Serve deployment bound with configuration """ + # Fail fast at builder time on bad sampler_type values. + sampler_type = _validate_sampler_type(sampler_type) + # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). + + def get_self() -> SamplerManagement: + return serve.get_replica_context().servable_object + + @asynccontextmanager + async def lifespan(app: FastAPI): + # Initialize telemetry in worker process (after deserialization) + from twinkle.server.telemetry.worker_init import ensure_telemetry_initialized + ensure_telemetry_initialized() + # NOTE: ``state.start_cleanup_task()`` cannot run here — Ray Serve binds + # ``servable_object`` AFTER lifespan startup. Lazy-started from the + # first request via ``_on_request_start`` → ``_ensure_state_cleanup_started``. + yield + app = FastAPI( title='Unified Sampler', description='REST API for distributed text generation inference (Tinker + Twinkle)', - version='1.0.0') + version='1.0.0', + lifespan=lifespan) @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) + # Registration order: FastAPI runs middleware LIFO. Tracing first → metrics + # last makes metrics the outermost wrapper, so its latency observation + # covers the full request path including tracing overhead. + app.middleware('http')(create_tracing_middleware('Sampler')) app.middleware('http')(create_metrics_middleware('Sampler')) - def get_self() -> SamplerManagement: - return serve.get_replica_context().servable_object - # Register routes BEFORE @serve.ingress so Ray Serve captures them at decoration time _register_tinker_sampler_routes(app, get_self) _register_twinkle_sampler_routes(app, get_self) diff --git a/src/twinkle/server/sampler/backends/__init__.py b/src/twinkle/server/sampler/backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/twinkle/server/sampler/backends/mock_sampler.py b/src/twinkle/server/sampler/backends/mock_sampler.py new file mode 100644 index 000000000..786591041 --- /dev/null +++ b/src/twinkle/server/sampler/backends/mock_sampler.py @@ -0,0 +1,118 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Numpy-only mock sampler backend (R2, R3). + +Implements the same surface as :class:`twinkle.sampler.base.Sampler` — +``sample``, ``apply_patch``, ``add_adapter_to_sampler`` — using only numpy. +The class is intentionally **duck-typed** rather than subclassed because +``twinkle.sampler.__init__`` eagerly imports the vLLM engine, which would +defeat the no-vllm-import requirement (R2.2) on a CPU-only host. + +Outputs are deterministic — keyed by ``(model_id, adapter_name, seed, +prompt_index, sample_index)`` — so repeated calls with the same parameters +produce identical token sequences and logprobs (R2.5). +""" +from __future__ import annotations + +import hashlib +import numpy as np +from typing import Any, List, Optional + +# These data containers don't pull torch / vllm. +from twinkle.data_format import SampledSequence, SampleResponse, SamplingParams + + +def _stable_seed(*parts: Any) -> int: + """Cross-process-stable numpy seed (uint32) derived from string parts. + + Python's built-in ``hash()`` of a tuple containing strings is salted per + process (PYTHONHASHSEED), which would make identical sample requests on + different replicas / restarts produce different outputs and violate R2.5. + Use a stable digest instead. + """ + canonical = '\x1f'.join(str(p) for p in parts).encode('utf-8') + digest = hashlib.sha256(canonical).digest() + return int.from_bytes(digest[:4], 'big') + + +class MockSampler: + """Deterministic numpy-only sampler. + + Provides the public methods callable from the sampler app and the Tinker / + Twinkle handlers; ``has_adapter`` is added for convenience and tests. + """ + + def __init__(self, model_id: str, *, seed: int = 0, vocab_size: int = 32, **kwargs: Any) -> None: + self.model_id = model_id + self._seed = int(seed) + self._vocab_size = int(vocab_size) + self._adapters: dict[str, Any] = {} + # Match the Sampler base attributes so duck-typed callers don't surprise. + self.engine = None + self.template = None + + # ----- Sampler interface (R2.1, R2.3, R2.4, R2.5, R2.6) -------------- # + + def sample( + self, + inputs: Any, + sampling_params: SamplingParams | None = None, + adapter_name: str = '', + *, + num_samples: int = 1, + **kwargs: Any, + ) -> list[SampleResponse]: + # The real ``vLLMSampler.sample`` accepts extra keyword arguments + # (``adapter_path``, ``adapter_uri``, etc.) that the Tinker / Twinkle + # handlers always forward. Swallow them here so the mock backend + # stays callable through the same handler call sites without a + # TypeError. + max_tokens = self._resolve_max_tokens(sampling_params) + if max_tokens is None or max_tokens < 1: + raise ValueError(f'max_tokens must be >= 1, got {max_tokens!r} ' + '(set sampling_params.max_tokens to a positive integer)') + + normalized = self._normalize_inputs(inputs) + responses: list[SampleResponse] = [] + for prompt_idx, _ in enumerate(normalized): + sequences: list[SampledSequence] = [] + for sample_idx in range(num_samples): + seed = _stable_seed(self.model_id, adapter_name, self._seed, prompt_idx, sample_idx) + rng = np.random.default_rng(seed) + tokens = [int(t) for t in rng.integers(low=0, high=max(1, self._vocab_size), size=max_tokens)] + logprobs_per_token = rng.uniform(-2.0, 0.0, size=max_tokens).astype(float).tolist() + # One logprob entry per emitted token (R2.4) — list of (id, logprob). + logprobs = [[(tok, float(lp))] for tok, lp in zip(tokens, logprobs_per_token)] + sequences.append(SampledSequence( + stop_reason='length', + tokens=tokens, + logprobs=logprobs, + )) + responses.append(SampleResponse(sequences=sequences)) + return responses + + def apply_patch(self, patch_cls: Any, **kwargs: Any) -> None: + return None + + # ----- Adapter management (R2.7) ------------------------------------- # + + def add_adapter_to_sampler(self, adapter_name: str, config: Any) -> None: + self._adapters[adapter_name] = config + + def has_adapter(self, adapter_name: str) -> bool: + return adapter_name in self._adapters + + # ----- Helpers ------------------------------------------------------- # + + @staticmethod + def _normalize_inputs(inputs: Any) -> list[Any]: + if inputs is None: + return [None] + if isinstance(inputs, list): + return inputs if inputs else [None] + return [inputs] + + @staticmethod + def _resolve_max_tokens(params: SamplingParams | None) -> int | None: + if params is None: + return None + return getattr(params, 'max_tokens', None) diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index b27ec23d0..2d4dc86e5 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -19,6 +19,8 @@ import twinkle_client.types as types from twinkle.data_format import InputFeature, SamplingParams, Trajectory +from twinkle.server.telemetry.correlation import MODEL_ID, TOKEN_ID +from twinkle.server.telemetry.tracing import traced_operation from twinkle.utils.logger import get_logger logger = get_logger() @@ -165,7 +167,8 @@ async def set_template( ) -> types.SetTemplateResponse: """Set the chat template for encoding Trajectory inputs.""" extra_kwargs = body.model_extra or {} - self.sampler.set_template(body.template_cls, **extra_kwargs) + with traced_operation('sampler.set_template'): + self.sampler.set_template(body.template_cls, **extra_kwargs) return types.SetTemplateResponse() @app.post('/twinkle/add_adapter_to_sampler', response_model=types.AddAdapterResponse) @@ -181,7 +184,8 @@ async def add_adapter_to_sampler( from peft import LoraConfig config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config - self.sampler.add_adapter_to_sampler(full_adapter_name, config) + with traced_operation('sampler.add_adapter_to_sampler', attrs={MODEL_ID: full_adapter_name}): + self.sampler.add_adapter_to_sampler(full_adapter_name, config) return types.AddAdapterResponse(adapter_name=full_adapter_name) @@ -193,4 +197,5 @@ async def apply_patch( ) -> None: extra_kwargs = body.model_extra or {} patch_cls = deserialize_object(body.patch_cls) - self.sampler.apply_patch(patch_cls, **extra_kwargs) + with traced_operation('sampler.apply_patch'): + self.sampler.apply_patch(patch_cls, **extra_kwargs) diff --git a/src/twinkle/server/utils/state/__init__.py b/src/twinkle/server/state/__init__.py similarity index 58% rename from src/twinkle/server/utils/state/__init__.py rename to src/twinkle/server/state/__init__.py index 0e34697ad..04af91e8d 100644 --- a/src/twinkle/server/utils/state/__init__.py +++ b/src/twinkle/server/state/__init__.py @@ -1,11 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from .backend import PersistenceConfig, create_backend from .base import BaseManager from .config_manager import ConfigManager +from .config_signature import SignatureMismatchPolicy, compute_signature, validate_config_signature from .future_manager import FutureManager from .model_manager import ModelManager from .models import FutureRecord, ModelRecord, SamplingSessionRecord, SessionRecord +from .replica_registry import ReplicaRegistry from .sampling_manager import SamplingSessionManager -from .server_state import ServerState, ServerStateProxy, get_server_state +from .server_state import ServerState, get_server_state, reset_server_state_cache from .session_manager import SessionManager __all__ = [ @@ -24,6 +27,14 @@ 'ConfigManager', # Server state 'ServerState', - 'ServerStateProxy', + 'ReplicaRegistry', 'get_server_state', + 'reset_server_state_cache', + # Persistence backend factory + 'PersistenceConfig', + 'create_backend', + # Config signature validation + 'compute_signature', + 'validate_config_signature', + 'SignatureMismatchPolicy', ] diff --git a/src/twinkle/server/state/backend/__init__.py b/src/twinkle/server/state/backend/__init__.py new file mode 100644 index 000000000..1007e53ae --- /dev/null +++ b/src/twinkle/server/state/backend/__init__.py @@ -0,0 +1,14 @@ +from .base import StateBackend +from .factory import PersistenceConfig, create_backend +from .file_backend import FileBackend +from .memory_backend import MemoryBackend +from .redis_backend import RedisBackend + +__all__ = [ + 'StateBackend', + 'FileBackend', + 'MemoryBackend', + 'RedisBackend', + 'PersistenceConfig', + 'create_backend', +] diff --git a/src/twinkle/server/state/backend/base.py b/src/twinkle/server/state/backend/base.py new file mode 100644 index 000000000..863f6e99e --- /dev/null +++ b/src/twinkle/server/state/backend/base.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + + +class StateBackend(ABC): + """Unified interface for state storage backends. + + All state management operations go through this interface, supporting + multiple backend implementations (memory, file, Redis). + """ + + @abstractmethod + async def set(self, key: str, value: Any, ttl: int | None = None) -> None: + """Store key-value pair with optional TTL in seconds.""" + ... + + @abstractmethod + async def get(self, key: str) -> Any | None: + """Retrieve value, return None if not found or expired.""" + ... + + @abstractmethod + async def delete(self, key: str) -> None: + """Delete key, silently ignore if not found.""" + ... + + @abstractmethod + async def exists(self, key: str) -> bool: + """Check if key exists and is not expired.""" + ... + + @abstractmethod + async def keys(self, pattern: str) -> list[str]: + """Return all key names matching the pattern. Supports * wildcard (e.g. 'session::*').""" + ... + + @abstractmethod + async def count(self, pattern: str) -> int: + """Count keys matching the pattern.""" + ... + + @abstractmethod + async def set_nx(self, key: str, value: Any) -> bool: + """Set if not exists. Return True if successfully set, False if key already exists.""" + ... + + @abstractmethod + async def close(self) -> None: + """Close backend connection / release resources.""" + ... + + @abstractmethod + async def health_check(self) -> bool: + """Check if backend is healthy and available.""" + ... diff --git a/src/twinkle/server/state/backend/factory.py b/src/twinkle/server/state/backend/factory.py new file mode 100644 index 000000000..aac2756a3 --- /dev/null +++ b/src/twinkle/server/state/backend/factory.py @@ -0,0 +1,94 @@ +"""Backend factory for creating StateBackend instances based on configuration.""" +from __future__ import annotations + +import logging +import os +from pydantic import BaseModel, ConfigDict +from typing import Literal + +from .base import StateBackend +from .memory_backend import MemoryBackend + +logger = logging.getLogger(__name__) + +# Env var keys propagated by the launcher so that any Ray worker can rebuild +# the same PersistenceConfig regardless of which deployment initializes the +# ServerState actor first. +PERSISTENCE_ENV_KEYS: tuple[str, ...] = ( + 'TWINKLE_PERSISTENCE_MODE', + 'TWINKLE_PERSISTENCE_FILE_PATH', + 'TWINKLE_PERSISTENCE_REDIS_URL', + 'TWINKLE_PERSISTENCE_KEY_PREFIX', +) + + +class PersistenceConfig(BaseModel): + """Configuration for state persistence backend.""" + + model_config = ConfigDict(extra='forbid') + + mode: Literal['memory', 'file', 'redis'] = 'memory' + file_path: str | None = None # required for file mode + redis_url: str | None = None # required for redis mode + key_prefix: str = '' # optional global key prefix + + def to_env_vars(self) -> dict[str, str]: + """Serialize this config to env var key/value pairs for worker propagation.""" + env: dict[str, str] = {'TWINKLE_PERSISTENCE_MODE': self.mode} + if self.file_path: + env['TWINKLE_PERSISTENCE_FILE_PATH'] = self.file_path + if self.redis_url: + env['TWINKLE_PERSISTENCE_REDIS_URL'] = self.redis_url + if self.key_prefix: + env['TWINKLE_PERSISTENCE_KEY_PREFIX'] = self.key_prefix + return env + + @classmethod + def from_env(cls) -> PersistenceConfig | None: + """Reconstruct a PersistenceConfig from env vars set by the launcher. + + Returns ``None`` when ``TWINKLE_PERSISTENCE_MODE`` is unset, so callers + can distinguish "no env-configured persistence" from "memory mode". + """ + mode = os.environ.get('TWINKLE_PERSISTENCE_MODE') + if not mode: + return None + return cls( + mode=mode, + file_path=os.environ.get('TWINKLE_PERSISTENCE_FILE_PATH'), + redis_url=os.environ.get('TWINKLE_PERSISTENCE_REDIS_URL'), + key_prefix=os.environ.get('TWINKLE_PERSISTENCE_KEY_PREFIX', ''), + ) + + +def create_backend(config: PersistenceConfig | None = None) -> StateBackend: + """Create a StateBackend instance based on persistence configuration. + + Args: + config: Persistence configuration. Defaults to memory mode if None. + + Returns: + A configured StateBackend instance. + + Raises: + ValueError: If required config fields are missing for the selected mode. + ImportError: If required packages are not installed. + """ + if config is None: + config = PersistenceConfig() + + match config.mode: + case 'memory': + return MemoryBackend() + case 'file': + if not config.file_path: + raise ValueError('file_path is required for file persistence mode') + from .file_backend import FileBackend + return FileBackend(config.file_path) + case 'redis': + if not config.redis_url: + raise ValueError('redis_url is required for redis persistence mode') + from .redis_backend import RedisBackend + return RedisBackend(config.redis_url, key_prefix=config.key_prefix) + case _: + raise ValueError(f'Unknown persistence mode: {config.mode}') diff --git a/src/twinkle/server/state/backend/file_backend.py b/src/twinkle/server/state/backend/file_backend.py new file mode 100644 index 000000000..747c1dc61 --- /dev/null +++ b/src/twinkle/server/state/backend/file_backend.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import asyncio +import fcntl +import json +import os +import tempfile +import time +from fnmatch import fnmatch +from typing import Any + +from .base import StateBackend + + +class FileBackend(StateBackend): + """Local JSON file-based persistent state backend implementation. + + Storage format is a single JSON file: + ``{key: {"value": ..., "expire_at": float|null}}``. + File I/O is wrapped with ``asyncio.to_thread`` to avoid blocking the + event loop. Writes use temp file + ``os.replace`` for atomic replacement, + protected by ``fcntl.flock`` against multi-process concurrent writes. + """ + + def __init__(self, file_path: str) -> None: + self._file_path = file_path + self._init_file() + + def _init_file(self) -> None: + """Auto-create file or directory if not exists.""" + dir_path = os.path.dirname(self._file_path) + if dir_path and not os.path.exists(dir_path): + os.makedirs(dir_path, exist_ok=True) + if not os.path.exists(self._file_path): + with open(self._file_path, 'w', encoding='utf-8') as f: + json.dump({}, f) + + def _load_sync(self) -> dict[str, dict[str, Any]]: + """Synchronously read JSON file, return complete data dict.""" + try: + with open(self._file_path, encoding='utf-8') as f: + data = json.load(f) + except (json.JSONDecodeError, FileNotFoundError): + data = {} + return data + + def _save_sync(self, data: dict[str, dict[str, Any]]) -> None: + """Synchronous write: clean expired keys -> write temp file -> flock -> os.replace.""" + # Clean expired keys before writing + now = time.time() + data = {k: v for k, v in data.items() if v.get('expire_at') is None or v['expire_at'] > now} + + dir_path = os.path.dirname(self._file_path) or '.' + fd = tempfile.NamedTemporaryFile( + mode='w', + suffix='.tmp', + dir=dir_path, + delete=False, + encoding='utf-8', + ) + try: + json.dump(data, fd, ensure_ascii=False) + fd.flush() + os.fsync(fd.fileno()) + fd.close() + + # Apply exclusive lock to temp file then atomic replace + with open(fd.name) as lock_f: + fcntl.flock(lock_f.fileno(), fcntl.LOCK_EX) + os.replace(fd.name, self._file_path) + fcntl.flock(lock_f.fileno(), fcntl.LOCK_UN) + except BaseException: + # Clean up temp file + if os.path.exists(fd.name): + os.unlink(fd.name) + raise + + async def _load(self) -> dict[str, dict[str, Any]]: + return await asyncio.to_thread(self._load_sync) + + async def _save(self, data: dict[str, dict[str, Any]]) -> None: + await asyncio.to_thread(self._save_sync, data) + + def _is_expired(self, entry: dict[str, Any]) -> bool: + """Check if entry is expired.""" + expire_at = entry.get('expire_at') + return expire_at is not None and time.time() >= expire_at + + async def set(self, key: str, value: Any, ttl: int | None = None) -> None: + """Store key-value pair with optional TTL in seconds.""" + expire_at = (time.time() + ttl) if ttl is not None else None + data = await self._load() + data[key] = {'value': value, 'expire_at': expire_at} + await self._save(data) + + async def get(self, key: str) -> Any | None: + """Retrieve value, return None if not found or expired.""" + data = await self._load() + entry = data.get(key) + if entry is None: + return None + if self._is_expired(entry): + del data[key] + await self._save(data) + return None + return entry['value'] + + async def delete(self, key: str) -> None: + """Delete key, silently ignore if not found.""" + data = await self._load() + if key in data: + del data[key] + await self._save(data) + + async def exists(self, key: str) -> bool: + """Check if key exists and is not expired.""" + data = await self._load() + entry = data.get(key) + if entry is None: + return False + if self._is_expired(entry): + del data[key] + await self._save(data) + return False + return True + + async def keys(self, pattern: str) -> list[str]: + """Return all key names matching the pattern. Supports * wildcard.""" + data = await self._load() + result: list[str] = [] + expired_keys: list[str] = [] + for key, entry in data.items(): + if self._is_expired(entry): + expired_keys.append(key) + continue + if fnmatch(key, pattern): + result.append(key) + if expired_keys: + for key in expired_keys: + del data[key] + await self._save(data) + return result + + async def count(self, pattern: str) -> int: + """Count keys matching the pattern.""" + return len(await self.keys(pattern)) + + async def set_nx(self, key: str, value: Any) -> bool: + """Set if not exists. Return True if successfully set, False if key already exists and is not expired.""" + data = await self._load() + entry = data.get(key) + if entry is not None and not self._is_expired(entry): + return False + data[key] = {'value': value, 'expire_at': None} + await self._save(data) + return True + + async def close(self) -> None: + """Close backend. File backend requires no persistent connection, no-op.""" + pass + + async def health_check(self) -> bool: + """Check if file path is writable.""" + try: + return os.access(self._file_path, os.W_OK) + except OSError: + return False diff --git a/src/twinkle/server/state/backend/memory_backend.py b/src/twinkle/server/state/backend/memory_backend.py new file mode 100644 index 000000000..3bbde4d15 --- /dev/null +++ b/src/twinkle/server/state/backend/memory_backend.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import time +from fnmatch import fnmatch +from typing import Any + +from .base import StateBackend + + +class MemoryBackend(StateBackend): + """In-memory dictionary-based state backend implementation. + + Uses ``dict[str, tuple[Any, float | None]]`` to store (value, expire_at). + Expiration is checked lazily during get/exists, suitable for Ray Actor single-threaded model. + """ + + def __init__(self) -> None: + self._store: dict[str, tuple[Any, float | None]] = {} + + def _is_expired(self, key: str) -> bool: + """Check if key is expired. If expired, delete and return True.""" + entry = self._store.get(key) + if entry is None: + return True + _, expire_at = entry + if expire_at is not None and time.time() >= expire_at: + del self._store[key] + return True + return False + + async def set(self, key: str, value: Any, ttl: int | None = None) -> None: + """Store key-value pair with optional TTL in seconds.""" + expire_at = (time.time() + ttl) if ttl is not None else None + self._store[key] = (value, expire_at) + + async def get(self, key: str) -> Any | None: + """Retrieve value, return None if not found or expired.""" + if self._is_expired(key): + return None + value, _ = self._store[key] + return value + + async def delete(self, key: str) -> None: + """Delete key, silently ignore if not found.""" + self._store.pop(key, None) + + async def exists(self, key: str) -> bool: + """Check if key exists and is not expired.""" + return not self._is_expired(key) + + async def keys(self, pattern: str) -> list[str]: + """Return all key names matching the pattern. Supports * wildcard.""" + result: list[str] = [] + # Collect expired keys during iteration to avoid modifying dict while iterating + expired_keys: list[str] = [] + for key, (_, expire_at) in self._store.items(): + if expire_at is not None and time.time() >= expire_at: + expired_keys.append(key) + continue + if fnmatch(key, pattern): + result.append(key) + for key in expired_keys: + del self._store[key] + return result + + async def count(self, pattern: str) -> int: + """Count keys matching the pattern.""" + return len(await self.keys(pattern)) + + async def set_nx(self, key: str, value: Any) -> bool: + """Set if not exists. Return True if successfully set, False if key already exists.""" + if not self._is_expired(key): + return False + self._store[key] = (value, None) + return True + + async def close(self) -> None: + """Close backend, clear storage.""" + self._store.clear() + + async def health_check(self) -> bool: + """Check if backend is healthy and available. Memory backend always returns True.""" + return True diff --git a/src/twinkle/server/state/backend/redis_backend.py b/src/twinkle/server/state/backend/redis_backend.py new file mode 100644 index 000000000..7f431924a --- /dev/null +++ b/src/twinkle/server/state/backend/redis_backend.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import json +from typing import Any + +from .base import StateBackend + +try: + import redis.asyncio as aioredis + + _REDIS_AVAILABLE = True +except ImportError: + _REDIS_AVAILABLE = False + + +class RedisBackend(StateBackend): + """Redis-based persistent state backend implementation. + + Uses ``redis.asyncio`` client, values are stored as Redis strings via JSON serialization. + TTL is managed by Redis native EXPIRE mechanism. + """ + + def __init__(self, redis_url: str, key_prefix: str = '') -> None: + if not _REDIS_AVAILABLE: + raise ImportError('redis package required. Install with: pip install redis') + self._client = aioredis.from_url(redis_url, decode_responses=True) + self._prefix = key_prefix + + def _make_key(self, key: str) -> str: + """Add namespace prefix to key.""" + return f'{self._prefix}{key}' if self._prefix else key + + def _strip_prefix(self, key: str) -> str: + """Remove namespace prefix from full key.""" + return key[len(self._prefix):] if self._prefix else key + + async def set(self, key: str, value: Any, ttl: int | None = None) -> None: + """Store key-value pair with optional TTL in seconds.""" + real_key = self._make_key(key) + data = json.dumps(value) + if ttl is not None: + await self._client.set(real_key, data, ex=ttl) + else: + await self._client.set(real_key, data) + + async def get(self, key: str) -> Any | None: + """Retrieve value, return None if not found or expired.""" + real_key = self._make_key(key) + raw = await self._client.get(real_key) + if raw is None: + return None + return json.loads(raw) + + async def delete(self, key: str) -> None: + """Delete key, silently ignore if not found.""" + real_key = self._make_key(key) + await self._client.delete(real_key) + + async def exists(self, key: str) -> bool: + """Check if key exists and is not expired.""" + real_key = self._make_key(key) + return bool(await self._client.exists(real_key)) + + async def keys(self, pattern: str) -> list[str]: + """Return all key names matching the pattern. Supports * wildcard. + + Note: For high key volumes in production, consider using SCAN to avoid blocking. + """ + real_pattern = self._make_key(pattern) + raw_keys = await self._client.keys(real_pattern) + return [self._strip_prefix(k) for k in raw_keys] + + async def count(self, pattern: str) -> int: + """Count keys matching the pattern.""" + return len(await self.keys(pattern)) + + async def set_nx(self, key: str, value: Any, ttl: int | None = None) -> bool: + """Set if not exists. Return True if successfully set, False if key already exists.""" + real_key = self._make_key(key) + data = json.dumps(value) + if ttl is not None: + result = await self._client.set(real_key, data, nx=True, ex=ttl) + else: + result = await self._client.set(real_key, data, nx=True) + return result is not None + + async def close(self) -> None: + """Close Redis connection.""" + await self._client.aclose() + + async def health_check(self) -> bool: + """Check if Redis is healthy and available.""" + try: + return await self._client.ping() + except Exception: + return False diff --git a/src/twinkle/server/state/base.py b/src/twinkle/server/state/base.py new file mode 100644 index 000000000..97c2bf968 --- /dev/null +++ b/src/twinkle/server/state/base.py @@ -0,0 +1,105 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +import logging +import time +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from pydantic import BaseModel +from typing import Generic, TypeVar + +from twinkle.server.state.backend.base import StateBackend + +T = TypeVar('T', bound=BaseModel) +logger = logging.getLogger(__name__) + + +class BaseManager(ABC, Generic[T]): + """Abstract base class for resource managers using StateBackend. + + Provides common async CRUD operations and timestamp parsing. + Subclasses must implement `cleanup_expired`. + """ + + def __init__(self, backend: StateBackend, key_prefix: str, record_type: type[T], expiration_timeout: float): + self._backend = backend + self._prefix = key_prefix # e.g. 'session::', 'model::', 'sampling::', 'future::' + self._record_type = record_type + self.expiration_timeout = expiration_timeout + + def _make_key(self, resource_id: str) -> str: + return f'{self._prefix}{resource_id}' + + def _strip_prefix(self, key: str) -> str: + return key[len(self._prefix):] + + # ----- CRUD ----- + + async def add(self, resource_id: str, record: T) -> None: + """Store a record in the backend.""" + await self._backend.set(self._make_key(resource_id), record.model_dump()) + + async def get(self, resource_id: str) -> T | None: + """Retrieve a record by ID.""" + data = await self._backend.get(self._make_key(resource_id)) + if data is None: + return None + return self._record_type.model_validate(data) + + async def remove(self, resource_id: str) -> bool: + """Remove a record. Returns True if it existed.""" + key = self._make_key(resource_id) + exists = await self._backend.exists(key) + if exists: + await self._backend.delete(key) + return exists + + async def count(self) -> int: + """Count all records managed by this manager.""" + return await self._backend.count(f'{self._prefix}*') + + async def keys(self) -> list[str]: + """Get all resource IDs (without prefix).""" + raw_keys = await self._backend.keys(f'{self._prefix}*') + return [self._strip_prefix(k) for k in raw_keys] + + async def get_all(self) -> dict[str, T]: + """Load all records from backend. Used for index rebuilding.""" + all_keys = await self._backend.keys(f'{self._prefix}*') + result = {} + for key in all_keys: + data = await self._backend.get(key) + if data is not None: + resource_id = self._strip_prefix(key) + result[resource_id] = self._record_type.model_validate(data) + return result + + # ----- Cleanup ----- + + @abstractmethod + async def cleanup_expired(self, cutoff_time: float, **kwargs) -> int: + """Remove all records older than cutoff_time. + + Args: + cutoff_time: Unix timestamp; records with activity before this are removed. + + Returns: + Number of records removed. + """ + ... + + # ----- Helpers ----- + + def _parse_timestamp(self, timestamp_str: str) -> float: + """Parse an ISO-format timestamp string to a Unix timestamp. + + Falls back to the current time so that unparseable entries are + never accidentally kept alive forever. + """ + try: + dt = datetime.fromisoformat(timestamp_str) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.timestamp() + except (ValueError, TypeError, AttributeError): + return time.time() diff --git a/src/twinkle/server/state/config_manager.py b/src/twinkle/server/state/config_manager.py new file mode 100644 index 000000000..ecf76f63b --- /dev/null +++ b/src/twinkle/server/state/config_manager.py @@ -0,0 +1,81 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +from typing import Any + +from .backend import StateBackend + +# Key prefix used to namespace configuration entries inside the backend. +_CONFIG_PREFIX = 'config::' +_CONFIG_PATTERN = f'{_CONFIG_PREFIX}*' + + +class ConfigManager: + """ + Manages key-value configuration entries via a :class:`StateBackend`. + + Configuration entries have no expiry; they persist until explicitly removed + or cleared. This manager does not inherit from BaseManager because config + values are arbitrary Python objects rather than Pydantic models, and all + storage is delegated to the injected backend. + + Methods are ``async`` because :class:`StateBackend` operations are async. + ConfigManager is expected to run inside a single-threaded Ray Actor, so + there is no need to add additional locking on top of the backend. + """ + + def __init__(self, backend: StateBackend) -> None: + self._backend = backend + + @staticmethod + def _make_key(key: str) -> str: + return f'{_CONFIG_PREFIX}{key}' + + # ----- CRUD ----- + + async def add(self, key: str, value: Any) -> None: + """Add or overwrite a configuration value.""" + await self._backend.set(self._make_key(key), value) + + async def add_or_get(self, key: str, value: Any) -> Any: + """Add a value if the key does not exist; otherwise return the existing value. + + Args: + key: Configuration key. + value: Value to store if the key is absent. + + Returns: + The existing or newly stored value. + """ + backend_key = self._make_key(key) + existing = await self._backend.get(backend_key) + if existing is not None: + return existing + # Use set_nx for atomicity within a single backend; if another + # writer already populated the key we return the winning value. + if await self._backend.set_nx(backend_key, value): + return value + return await self._backend.get(backend_key) + + async def get(self, key: str) -> Any | None: + """Return the configuration value for key, or None.""" + return await self._backend.get(self._make_key(key)) + + async def pop(self, key: str) -> Any | None: + """Remove and return the configuration value for key, or None.""" + backend_key = self._make_key(key) + value = await self._backend.get(backend_key) + if value is None: + return None + await self._backend.delete(backend_key) + return value + + async def clear(self) -> None: + """Remove all configuration entries.""" + keys = await self._backend.keys(_CONFIG_PATTERN) + for backend_key in keys: + await self._backend.delete(backend_key) + + async def count(self) -> int: + """Return the number of stored configuration entries.""" + return await self._backend.count(_CONFIG_PATTERN) diff --git a/src/twinkle/server/state/config_signature.py b/src/twinkle/server/state/config_signature.py new file mode 100644 index 000000000..f38fa3cb5 --- /dev/null +++ b/src/twinkle/server/state/config_signature.py @@ -0,0 +1,156 @@ +"""Configuration signature validation for state persistence integrity.""" +from __future__ import annotations + +import hashlib +import json +import logging +from enum import Enum +from typing import Any + +from twinkle.server.exceptions import ConfigMismatchError +from twinkle.server.state.backend.base import StateBackend + +logger = logging.getLogger(__name__) + +_SIGNATURE_KEY = '_meta::config_signature' +_PAYLOAD_KEY = '_meta::config_payload' + + +class SignatureMismatchPolicy(str, Enum): + """Policy for handling config signature mismatches.""" + WARN = 'warn' # Log warning and continue + CLEAR = 'clear' # Clear all backend data and continue + ABORT = 'abort' # Raise error, refuse to start + + +def compute_signature(config: dict[str, Any]) -> str: + """Compute a SHA256 hash of the configuration dictionary. + + Args: + config: Configuration dictionary to hash. + + Returns: + Hex string of SHA256 hash. + """ + # Sort keys for deterministic serialization + serialized = json.dumps(config, sort_keys=True, default=str) + return hashlib.sha256(serialized.encode()).hexdigest() + + +async def validate_config_signature( + backend: StateBackend, + current_config: dict[str, Any], + policy: SignatureMismatchPolicy = SignatureMismatchPolicy.WARN, +) -> bool: + """Validate configuration signature against stored value. + + Compares the current config's hash with the previously stored hash. + On first run (no stored hash), stores the current hash and returns True. + + Args: + backend: State backend to read/write signature. + current_config: Current configuration dictionary. + policy: Action to take on mismatch. + + Returns: + True if signature matches or is new. False if mismatch with WARN/CLEAR policy. + + Raises: + ConfigMismatchError: If policy is ABORT and signature doesn't match. + """ + current_sig = compute_signature(current_config) + stored_sig = await backend.get(_SIGNATURE_KEY) + + if stored_sig is None: + # First run — store signature + logger.info('No previous config signature found. Storing current signature.') + await backend.set(_SIGNATURE_KEY, current_sig) + return True + + if stored_sig == current_sig: + logger.debug('Config signature matches stored value.') + return True + + # Mismatch detected + logger.warning(f'Config signature mismatch! ' + f'Stored: {stored_sig[:12]}..., Current: {current_sig[:12]}... ' + f'Policy: {policy.value}') + + if policy == SignatureMismatchPolicy.WARN: + # Update to new signature and continue + await backend.set(_SIGNATURE_KEY, current_sig) + return False + + elif policy == SignatureMismatchPolicy.CLEAR: + # Clear all data except meta keys, store new signature + logger.warning('Clearing all backend data due to config signature mismatch.') + all_keys = await backend.keys('*') + for key in all_keys: + if not key.startswith('_meta::'): + await backend.delete(key) + await backend.set(_SIGNATURE_KEY, current_sig) + return False + + elif policy == SignatureMismatchPolicy.ABORT: + raise ConfigMismatchError(f'Configuration signature mismatch. ' + f'Stored: {stored_sig[:12]}..., Current: {current_sig[:12]}... ' + f"Use policy='warn' or 'clear' to allow startup with changed config.") + + return False + + +# --------------------------------------------------------------------------- +# CLI startup hook (R15) +# --------------------------------------------------------------------------- + + +def _format_diff(stored: dict[str, Any] | None, current: dict[str, Any]) -> str: + """Render a stored-vs-current diff suitable for a remediation hint.""" + lines: list[str] = [] + keys = sorted(set((stored or {}).keys()) | set(current.keys())) + for k in keys: + s = (stored or {}).get(k, '') + c = current.get(k, '') + if s != c: + lines.append(f' - {k}: stored={s!r} current={c!r}') + return '\n'.join(lines) if lines else ' (no field-level diff — values differ at nested level)' + + +async def validate_against_backend(persistence_config: Any, current_config: dict[str, Any]) -> None: + """Validate the persistence config signature on launcher startup (R15). + + Builds a backend from ``persistence_config`` (a :class:`PersistenceConfig`), + computes the current signature, and compares it to the stored value: + - if no signature is stored, store the current one and continue (R15.4); + - if signatures match, return cleanly; + - if they differ, raise :class:`ConfigMismatchError` with a stored-vs- + current diff and a remediation hint (R15.2, R15.3). + + Designed to be called BEFORE ``ray.init`` so the launcher can fail fast + without spinning up the cluster (R15.1). + """ + from twinkle.server.state.backend.factory import create_backend + + backend = create_backend(persistence_config) + current_sig = compute_signature(current_config) + stored_sig = await backend.get(_SIGNATURE_KEY) + + if stored_sig is None: + await backend.set(_SIGNATURE_KEY, current_sig) + # Persist the payload alongside the signature so a future drift diff can + # render real ``stored vs current`` field-level differences (R15.3). + await backend.set(_PAYLOAD_KEY, current_config) + logger.info('No previous config signature found. Stored current signature.') + return + + if stored_sig == current_sig: + return + + stored_payload = await backend.get(_PAYLOAD_KEY) + diff = _format_diff(stored_payload if isinstance(stored_payload, dict) else None, current_config) + raise ConfigMismatchError('Persistence configuration drifted since the last launch. ' + f'Stored signature: {stored_sig[:12]}..., current signature: {current_sig[:12]}...\n' + f'Differences:\n{diff}\n' + 'Remediation: either revert the persistence section to match the stored ' + 'value, or clear the persisted state with ' + '`python -m twinkle.server clear persistence --config ` and relaunch.') diff --git a/src/twinkle/server/utils/state/future_manager.py b/src/twinkle/server/state/future_manager.py similarity index 81% rename from src/twinkle/server/utils/state/future_manager.py rename to src/twinkle/server/state/future_manager.py index 0af069a86..6b9cf016f 100644 --- a/src/twinkle/server/utils/state/future_manager.py +++ b/src/twinkle/server/state/future_manager.py @@ -4,20 +4,23 @@ from datetime import datetime from typing import Any +from .backend.base import StateBackend from .base import BaseManager from .models import FutureRecord class FutureManager(BaseManager[FutureRecord]): - """ - Manages async task futures / request statuses. + """Manages async task futures / request statuses. Expiry is based on `updated_at` (falls back to `created_at`). """ + def __init__(self, backend: StateBackend, expiration_timeout: float) -> None: + super().__init__(backend, 'future::', FutureRecord, expiration_timeout) + # ----- Future-specific operations ----- - def store_status( + async def store_status( self, request_id: str, status: str, @@ -45,7 +48,7 @@ def store_status( result = result.model_dump() now = datetime.now().isoformat() - existing = self._store.get(request_id) + existing = await self.get(request_id) if existing is not None: existing.status = status @@ -59,8 +62,9 @@ def store_status( existing.queue_state = queue_state if queue_state_reason is not None: existing.queue_state_reason = queue_state_reason + await self.add(request_id, existing) else: - self._store[request_id] = FutureRecord( + record = FutureRecord( status=status, model_id=model_id, reason=reason, @@ -70,10 +74,11 @@ def store_status( created_at=now, updated_at=now, ) + await self.add(request_id, record) # ----- Cleanup ----- - def cleanup_expired(self, cutoff_time: float) -> int: + async def cleanup_expired(self, cutoff_time: float, **kwargs) -> int: """Remove futures whose last update is older than cutoff_time. Args: @@ -82,14 +87,15 @@ def cleanup_expired(self, cutoff_time: float) -> int: Returns: Number of futures removed. """ + all_records = await self.get_all() expired_ids = [] - for request_id, record in self._store.items(): + for request_id, record in all_records.items(): timestamp_str = record.updated_at or record.created_at timestamp = self._parse_timestamp(timestamp_str) if timestamp < cutoff_time: expired_ids.append(request_id) for request_id in expired_ids: - del self._store[request_id] + await self.remove(request_id) return len(expired_ids) diff --git a/src/twinkle/server/state/model_manager.py b/src/twinkle/server/state/model_manager.py new file mode 100644 index 000000000..091f4a74d --- /dev/null +++ b/src/twinkle/server/state/model_manager.py @@ -0,0 +1,151 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Backend-backed model manager (R19). + +Every index this manager exposes — token → model count, replica → loaded model +count, replica capacity — is computed from the persisted ``model::*`` and +``replica::*`` records in the shared :class:`StateBackend`. There is no +in-process cache, so two workers connected to the same backend see one +consistent view of the cluster's model registry without going through a +detached Ray Actor. +""" +from __future__ import annotations + +from .backend.base import StateBackend +from .base import BaseManager +from .models import ModelRecord +from .replica_registry import ReplicaRegistry + + +class ModelManager(BaseManager[ModelRecord]): + """Manages registered models with backend-derived per-token / per-replica indexes. + + Expiry is based on ``created_at``. A model is also considered expired if + its owning session has already been removed (cascade expiry). + Enforces a per-token model limit across all model instances (server-global). + """ + + def __init__(self, backend: StateBackend, expiration_timeout: float, per_token_model_limit: int = 30) -> None: + super().__init__(backend, 'model::', ModelRecord, expiration_timeout) + self._per_token_model_limit = per_token_model_limit + self._replicas = ReplicaRegistry(backend) + + # ----- Index Rebuild -------------------------------------------------- # + + async def rebuild_indexes(self) -> None: + """Compatibility shim — indexes are now derived from the backend per call.""" + return None + + # ----- Capacity ------------------------------------------------------- # + + async def get_capacity_info(self) -> dict[str, int]: + """Return global LoRA capacity across all registered replicas.""" + replicas = await self._replicas.get_all() + loaded_per_replica = await self._loaded_per_replica() + total_max = sum(replicas.values()) + total_used = sum(loaded_per_replica.get(rid, 0) for rid in replicas) + return { + 'max_loras': total_max, + 'used_loras': total_used, + 'free_loras': max(0, total_max - total_used), + } + + # ----- Replica Registration ------------------------------------------ # + + async def register_replica(self, replica_id: str, max_loras: int) -> None: + """Register a replica and its LoRA capacity in the shared backend.""" + await self._replicas.register(replica_id, max_loras) + + async def unregister_replica(self, replica_id: str) -> None: + """Remove a replica's capacity entry and any models it owns.""" + loaded = await self._models_for_replica(replica_id) + for model_id in loaded: + await self.remove(model_id) + await self._replicas.unregister(replica_id) + + async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: + """Return the subset of ``candidate_ids`` that still have capacity. + + A replica has capacity when its persisted loaded-model count is strictly + less than its declared ``max_loras``. Unknown replicas (no capacity row + in the backend) are included conservatively, matching the previous + actor-based behavior (R19.3). + """ + if not candidate_ids: + return [] + replicas = await self._replicas.get_all() + loaded_per_replica = await self._loaded_per_replica(replica_filter=set(candidate_ids) | replicas.keys()) + available: list[str] = [] + for rid in candidate_ids: + max_loras = replicas.get(rid) + if max_loras is None: + # Unknown replica — include conservatively. + available.append(rid) + continue + if loaded_per_replica.get(rid, 0) < max_loras: + available.append(rid) + return available + + # ----- CRUD ----------------------------------------------------------- # + + async def add(self, model_id: str, record: ModelRecord) -> None: + """Store a record, enforcing the per-token model limit. + + Raises: + RuntimeError: when adding ``record`` would exceed + ``per_token_model_limit`` for ``record.token``. + """ + token = record.token + current = await self._count_models_for_token(token) + if current >= self._per_token_model_limit: + raise RuntimeError(f'Model limit exceeded: {current}/{self._per_token_model_limit} models') + await super().add(model_id, record) + + async def remove(self, model_id: str) -> bool: + """Remove a record by ID.""" + record = await self.get(model_id) + if record is None: + return False + await super().remove(model_id) + return True + + # ----- Cleanup -------------------------------------------------------- # + + async def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None, **kwargs) -> int: + """Remove models older than ``cutoff_time`` or whose owning session expired.""" + session_set = set(expired_session_ids or []) + all_records = await self.get_all() + expired_ids: list[str] = [] + for model_id, record in all_records.items(): + if record.session_id and record.session_id in session_set: + expired_ids.append(model_id) + continue + created_at = self._parse_timestamp(record.created_at) + if created_at < cutoff_time: + expired_ids.append(model_id) + for model_id in expired_ids: + await self.remove(model_id) + return len(expired_ids) + + # ----- Backend-derived helpers --------------------------------------- # + + async def _count_models_for_token(self, token: str | None) -> int: + if not token: + return 0 + all_records = await self.get_all() + return sum(1 for r in all_records.values() if r.token == token) + + async def _models_for_replica(self, replica_id: str) -> list[str]: + all_records = await self.get_all() + return [mid for mid, r in all_records.items() if r.replica_id == replica_id] + + async def _loaded_per_replica(self, replica_filter: set[str] | None = None) -> dict[str, int]: + all_records = await self.get_all() + counts: dict[str, int] = {} + for record in all_records.values(): + rid = record.replica_id + if rid is None: + continue + if replica_filter is not None and rid not in replica_filter: + continue + counts[rid] = counts.get(rid, 0) + 1 + return counts diff --git a/src/twinkle/server/utils/state/models.py b/src/twinkle/server/state/models.py similarity index 100% rename from src/twinkle/server/utils/state/models.py rename to src/twinkle/server/state/models.py diff --git a/src/twinkle/server/state/replica_registry.py b/src/twinkle/server/state/replica_registry.py new file mode 100644 index 000000000..9eb8e6cbf --- /dev/null +++ b/src/twinkle/server/state/replica_registry.py @@ -0,0 +1,70 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Backend-backed registry of replica capacity (R19). + +Replaces the in-process ``_replica_max_loras`` dict that lived inside the +detached Ray Actor. Each entry persists to ``replica::::max_loras`` +in the configured :class:`StateBackend`, so two workers on the same shared +backend (Redis in multi-node, MemoryBackend in single-node) see one consistent +view of the cluster's capacity. + +The registry knows *only* about declared capacity. The current loaded-model +count is derived by querying the persisted ``model::*`` records directly — +nothing here caches that count, so concurrent writes from different workers +cannot drift into an inconsistent local index. +""" +from __future__ import annotations + +from .backend.base import StateBackend + +REPLICA_PREFIX = 'replica::' +_MAX_LORAS_SUFFIX = '::max_loras' + + +def _make_key(replica_id: str) -> str: + return f'{REPLICA_PREFIX}{replica_id}{_MAX_LORAS_SUFFIX}' + + +def _replica_id_from_key(key: str) -> str | None: + if not key.startswith(REPLICA_PREFIX) or not key.endswith(_MAX_LORAS_SUFFIX): + return None + return key[len(REPLICA_PREFIX):-len(_MAX_LORAS_SUFFIX)] + + +class ReplicaRegistry: + """Read/write replica capacity through the shared :class:`StateBackend`.""" + + def __init__(self, backend: StateBackend) -> None: + self._backend = backend + + async def register(self, replica_id: str, max_loras: int) -> None: + """Store / overwrite the declared LoRA capacity for ``replica_id``.""" + await self._backend.set(_make_key(replica_id), int(max_loras)) + + async def unregister(self, replica_id: str) -> None: + """Remove the capacity entry for ``replica_id`` (idempotent).""" + await self._backend.delete(_make_key(replica_id)) + + async def get_max_loras(self, replica_id: str) -> int | None: + """Return the declared capacity, or ``None`` if the replica is unknown.""" + value = await self._backend.get(_make_key(replica_id)) + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + async def get_all(self) -> dict[str, int]: + """Return every registered replica's declared capacity.""" + keys = await self._backend.keys(f'{REPLICA_PREFIX}*{_MAX_LORAS_SUFFIX}') + out: dict[str, int] = {} + for key in keys: + rid = _replica_id_from_key(key) + if rid is None: + continue + value = await self._backend.get(key) + try: + out[rid] = int(value) + except (TypeError, ValueError): + continue + return out diff --git a/src/twinkle/server/utils/state/sampling_manager.py b/src/twinkle/server/state/sampling_manager.py similarity index 74% rename from src/twinkle/server/utils/state/sampling_manager.py rename to src/twinkle/server/state/sampling_manager.py index ff3111a6f..7dd535a5e 100644 --- a/src/twinkle/server/utils/state/sampling_manager.py +++ b/src/twinkle/server/state/sampling_manager.py @@ -1,21 +1,24 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations +from .backend.base import StateBackend from .base import BaseManager from .models import SamplingSessionRecord class SamplingSessionManager(BaseManager[SamplingSessionRecord]): - """ - Manages sampling sessions. + """Manages sampling sessions. Expiry is based on `created_at`. A sampling session is also considered expired if its owning session has already been removed (cascade expiry). """ + def __init__(self, backend: StateBackend, expiration_timeout: float) -> None: + super().__init__(backend, 'sampling::', SamplingSessionRecord, expiration_timeout) + # ----- Cleanup ----- - def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None) -> int: + async def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None, **kwargs) -> int: """Remove sampling sessions that are older than cutoff_time, or whose owning session has already been expired. @@ -29,9 +32,10 @@ def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | N Number of sampling sessions removed. """ session_set = set(expired_session_ids or []) + all_records = await self.get_all() expired_ids = [] - for sampling_id, record in self._store.items(): + for sampling_id, record in all_records.items(): # Cascade: owner session was expired if record.session_id and record.session_id in session_set: expired_ids.append(sampling_id) @@ -42,6 +46,6 @@ def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | N expired_ids.append(sampling_id) for sampling_id in expired_ids: - del self._store[sampling_id] + await self.remove(sampling_id) return len(expired_ids) diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/state/server_state.py similarity index 50% rename from src/twinkle/server/utils/state/server_state.py rename to src/twinkle/server/state/server_state.py index fd3a76269..4002e3754 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/state/server_state.py @@ -2,15 +2,20 @@ from __future__ import annotations import asyncio -import ray import re import time import uuid from datetime import datetime from typing import Any -from twinkle.server.utils.metrics import get_resource_metrics +from twinkle.server.exceptions import ConfigMismatchError +from twinkle.server.telemetry import MetricsRegistry +from twinkle.server.telemetry.correlation import (BASE_MODEL, MODEL_ID, REPLICA_ID, SAMPLING_SESSION_ID, SESSION_ID, + TOKEN_ID) +from twinkle.server.telemetry.tracing import traced_operation from twinkle.utils.logger import get_logger +from .backend import StateBackend +from .backend.factory import PersistenceConfig, create_backend from .config_manager import ConfigManager from .future_manager import FutureManager from .model_manager import ModelManager @@ -32,33 +37,48 @@ class ServerState: - FutureManager — async task futures - ConfigManager — key-value configuration - All methods are designed to be used with Ray actors for distributed state. + Bound directly to a shared :class:`StateBackend` (R19); no detached + Ray Actor is created. Each Ray Serve worker owns one process-local + instance and the cleanup loop is started from the deployment's + FastAPI ``lifespan`` startup hook. """ def __init__( self, + backend: StateBackend | None = None, + persistence_config: PersistenceConfig | None = None, expiration_timeout: float = 86400.0, # 24 hours in seconds cleanup_interval: float = 3600.0, # 1 hour in seconds per_token_model_limit: int = 30, + signature_config: dict[str, Any] | None = None, + signature_policy: str = 'warn', **kwargs) -> None: - self._session_mgr = SessionManager(expiration_timeout) - self._model_mgr = ModelManager(expiration_timeout, per_token_model_limit) - self._sampling_mgr = SamplingSessionManager(expiration_timeout) - self._future_mgr = FutureManager(expiration_timeout) - self._config_mgr = ConfigManager() + if backend is not None: + self._backend: StateBackend = backend + else: + self._backend = create_backend(persistence_config) + self._session_mgr = SessionManager(self._backend, expiration_timeout) + self._model_mgr = ModelManager(self._backend, expiration_timeout, per_token_model_limit) + self._sampling_mgr = SamplingSessionManager(self._backend, expiration_timeout) + self._future_mgr = FutureManager(self._backend, expiration_timeout) + self._config_mgr = ConfigManager(self._backend) self.expiration_timeout = expiration_timeout self.cleanup_interval = cleanup_interval self._cleanup_task: asyncio.Task | None = None self._cleanup_running = False + # Config signature validation state + self._signature_config = signature_config + self._signature_policy = signature_policy + # Metrics loop state self._metrics_task: asyncio.Task | None = None self._metrics_running = False self._metrics_update_interval: float = float(kwargs.get('metrics_update_interval', 15.0)) async def get_capacity_info(self) -> dict[str, int]: - return self._model_mgr.get_capacity_info() + return await self._model_mgr.get_capacity_info() # ----- Session Management ----- @@ -72,13 +92,17 @@ async def create_session(self, payload: dict[str, Any]) -> str: The session_id for the created session. """ session_id = payload.get('session_id') or f'session_{uuid.uuid4().hex}' - record = SessionRecord( - tags=list(payload.get('tags') or []), - user_metadata=payload.get('user_metadata') or {}, - sdk_version=payload.get('sdk_version'), - ) - self._session_mgr.add(session_id, record) - return session_id + with traced_operation( + 'server_state.create_session', + attrs={SESSION_ID: session_id}, + ): + record = SessionRecord( + tags=list(payload.get('tags') or []), + user_metadata=payload.get('user_metadata') or {}, + sdk_version=payload.get('sdk_version'), + ) + await self._session_mgr.add(session_id, record) + return session_id async def touch_session(self, session_id: str) -> bool: """Update session heartbeat timestamp. @@ -86,7 +110,7 @@ async def touch_session(self, session_id: str) -> bool: Returns: True if the session exists and was touched, False otherwise. """ - return self._session_mgr.touch(session_id) + return await self._session_mgr.touch(session_id) async def get_session_last_heartbeat(self, session_id: str) -> float | None: """Get the last heartbeat timestamp for a session. @@ -94,7 +118,7 @@ async def get_session_last_heartbeat(self, session_id: str) -> float | None: Returns: Last heartbeat timestamp, or None if the session does not exist. """ - return self._session_mgr.get_last_heartbeat(session_id) + return await self._session_mgr.get_last_heartbeat(session_id) # ----- Model Registration ----- @@ -122,17 +146,27 @@ async def register_model(self, 'model_id') or f"{_time}-{payload.get('base_model', 'model')}-{uuid.uuid4().hex[:8]}" _model_id = re.sub(r'[^\w\-]', '_', _model_id) - record = ModelRecord( - session_id=session_id or payload.get('session_id'), - model_seq_id=payload.get('model_seq_id'), - base_model=payload.get('base_model'), - user_metadata=payload.get('user_metadata') or {}, - lora_config=payload.get('lora_config'), - token=token, - replica_id=replica_id, - ) - self._model_mgr.add(_model_id, record) - return _model_id + with traced_operation( + 'server_state.register_model', + attrs={ + MODEL_ID: _model_id, + BASE_MODEL: payload.get('base_model'), + REPLICA_ID: replica_id, + TOKEN_ID: token, + SESSION_ID: session_id or payload.get('session_id'), + }, + ): + record = ModelRecord( + session_id=session_id or payload.get('session_id'), + model_seq_id=payload.get('model_seq_id'), + base_model=payload.get('base_model'), + user_metadata=payload.get('user_metadata') or {}, + lora_config=payload.get('lora_config'), + token=token, + replica_id=replica_id, + ) + await self._model_mgr.add(_model_id, record) + return _model_id async def unload_model(self, model_id: str) -> bool: """Remove a model from the registry. @@ -140,11 +174,11 @@ async def unload_model(self, model_id: str) -> bool: Returns: True if the model was found and removed, False otherwise. """ - return self._model_mgr.remove(model_id) + return await self._model_mgr.remove(model_id) async def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: """Get metadata for a registered model as a plain dict.""" - record = self._model_mgr.get(model_id) + record = await self._model_mgr.get(model_id) return record.model_dump() if record is not None else None # ----- Replica Management ----- @@ -156,7 +190,11 @@ async def register_replica(self, replica_id: str, max_loras: int) -> None: replica_id: Unique identifier for the replica. max_loras: Maximum number of LoRA adapters the replica can hold. """ - self._model_mgr.register_replica(replica_id, max_loras) + with traced_operation( + 'server_state.register_replica', + attrs={REPLICA_ID: replica_id}, + ): + await self._model_mgr.register_replica(replica_id, max_loras) async def unregister_replica(self, replica_id: str) -> None: """Remove a replica from the registry. @@ -164,7 +202,7 @@ async def unregister_replica(self, replica_id: str) -> None: Args: replica_id: Unique identifier for the replica to remove. """ - self._model_mgr.unregister_replica(replica_id) + await self._model_mgr.unregister_replica(replica_id) async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: """Return candidate replica IDs that have not reached their max_loras limit. @@ -175,7 +213,7 @@ async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str] Returns: Filtered list of replica IDs with remaining capacity. """ - return self._model_mgr.get_available_replica_ids(candidate_ids) + return await self._model_mgr.get_available_replica_ids(candidate_ids) # ----- Sampling Session Management ----- @@ -191,25 +229,33 @@ async def create_sampling_session(self, payload: dict[str, Any], sampling_sessio """ _sampling_session_id: str = sampling_session_id or payload.get( 'sampling_session_id') or f'sampling_{uuid.uuid4().hex}' - record = SamplingSessionRecord( - session_id=payload.get('session_id'), - seq_id=payload.get('sampling_session_seq_id'), - base_model=payload.get('base_model'), - model_path=payload.get('model_path'), - ) - self._sampling_mgr.add(_sampling_session_id, record) - return _sampling_session_id + with traced_operation( + 'server_state.create_sampling_session', + attrs={ + SAMPLING_SESSION_ID: _sampling_session_id, + SESSION_ID: payload.get('session_id'), + BASE_MODEL: payload.get('base_model'), + }, + ): + record = SamplingSessionRecord( + session_id=payload.get('session_id'), + seq_id=payload.get('sampling_session_seq_id'), + base_model=payload.get('base_model'), + model_path=payload.get('model_path'), + ) + await self._sampling_mgr.add(_sampling_session_id, record) + return _sampling_session_id async def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: """Get a sampling session by ID as a plain dict.""" - record = self._sampling_mgr.get(sampling_session_id) + record = await self._sampling_mgr.get(sampling_session_id) return record.model_dump() if record is not None else None # ----- Future Management ----- async def get_future(self, request_id: str) -> dict[str, Any] | None: """Retrieve a stored future result as a plain dict.""" - record = self._future_mgr.get(request_id) + record = await self._future_mgr.get(request_id) return record.model_dump() if record is not None else None async def store_future_status( @@ -241,7 +287,7 @@ async def store_future_status( queue_state: Optional queue state for tinker client (active/paused_rate_limit/paused_capacity). queue_state_reason: Optional reason for the queue state. """ - self._future_mgr.store_status( + await self._future_mgr.store_status( request_id=request_id, status=status, model_id=model_id, @@ -251,6 +297,32 @@ async def store_future_status( queue_state_reason=queue_state_reason, ) + # ----- Configuration Management ----- + + async def add_config(self, key: str, value: Any) -> None: + """Add or overwrite a configuration value.""" + await self._config_mgr.add(key, value) + + async def add_or_get_config(self, key: str, value: Any) -> Any: + """Add a config value if absent; otherwise return the existing value.""" + return await self._config_mgr.add_or_get(key, value) + + async def get_config(self, key: str) -> Any | None: + """Return the configuration value for key, or None.""" + return await self._config_mgr.get(key) + + async def pop_config(self, key: str) -> Any | None: + """Remove and return the configuration value for key, or None.""" + return await self._config_mgr.pop(key) + + async def clear_config(self) -> None: + """Remove all configuration entries.""" + await self._config_mgr.clear() + + async def count_config(self) -> int: + """Return the number of stored configuration entries.""" + return await self._config_mgr.count() + # ----- Resource Cleanup ----- async def cleanup_expired_resources(self) -> dict[str, int]: @@ -267,13 +339,14 @@ async def cleanup_expired_resources(self) -> dict[str, int]: cutoff_time = current_time - self.expiration_timeout # Collect expired session IDs first for cascade logic - expired_session_ids = self._session_mgr.get_expired_ids(cutoff_time) + expired_session_ids = await self._session_mgr.get_expired_ids(cutoff_time) # Perform actual cleanup in dependency order - sessions_removed = self._session_mgr.cleanup_expired(cutoff_time) - models_removed = self._model_mgr.cleanup_expired(cutoff_time, expired_session_ids) - samplings_removed = self._sampling_mgr.cleanup_expired(cutoff_time, expired_session_ids) - futures_removed = self._future_mgr.cleanup_expired(cutoff_time) + sessions_removed = await self._session_mgr.cleanup_expired(cutoff_time) + models_removed = await self._model_mgr.cleanup_expired(cutoff_time, expired_session_ids=expired_session_ids) + samplings_removed = await self._sampling_mgr.cleanup_expired( + cutoff_time, expired_session_ids=expired_session_ids) + futures_removed = await self._future_mgr.cleanup_expired(cutoff_time) return { 'sessions': sessions_removed, @@ -297,15 +370,29 @@ async def _cleanup_loop(self) -> None: continue async def _metrics_loop(self) -> None: - """Background task that updates resource gauge metrics every N seconds.""" - resource_metrics = get_resource_metrics() + """Background task that updates resource gauge metrics every N seconds. + + OTEL up/down counters take *deltas*, not absolute values, so we keep + track of the last reported count per resource and emit only the + difference on each tick. + """ + registry = MetricsRegistry.get() + sources = ( + ('active_sessions', self._session_mgr), + ('active_models', self._model_mgr), + ('active_sampling_sessions', self._sampling_mgr), + ('active_futures', self._future_mgr), + ) + last_values: dict[str, int] = {name: 0 for name, _ in sources} while self._metrics_running: try: await asyncio.sleep(self._metrics_update_interval) - resource_metrics.active_sessions.set(self._session_mgr.count()) - resource_metrics.active_models.set(self._model_mgr.count()) - resource_metrics.active_sampling_sessions.set(self._sampling_mgr.count()) - resource_metrics.active_futures.set(self._future_mgr.count()) + for name, mgr in sources: + current = await mgr.count() + delta = current - last_values[name] + if delta != 0: + getattr(registry, name).add(delta) + last_values[name] = current except asyncio.CancelledError: break except Exception as e: @@ -320,6 +407,8 @@ async def start_cleanup_task(self) -> bool: """ if self._cleanup_running: return False + # Rebuild in-memory indexes from backend data + await self._rebuild_indexes() self._cleanup_running = True self._cleanup_task = asyncio.create_task(self._cleanup_loop()) if not self._metrics_running: @@ -327,6 +416,21 @@ async def start_cleanup_task(self) -> bool: self._metrics_task = asyncio.create_task(self._metrics_loop()) return True + async def _rebuild_indexes(self) -> None: + """Rebuild in-memory indexes from backend data after startup. + + Also validates config signature when ``signature_config`` was provided + at construction time. + """ + # Validate config signature if provided + if self._signature_config is not None: + from twinkle.server.state.config_signature import SignatureMismatchPolicy, validate_config_signature + policy = SignatureMismatchPolicy(self._signature_policy) + await validate_config_signature(self._backend, self._signature_config, policy) + + # Rebuild model indexes + await self._model_mgr.rebuild_indexes() + async def stop_cleanup_task(self) -> bool: """Stop the background cleanup task. @@ -356,143 +460,84 @@ async def get_cleanup_stats(self) -> dict[str, Any]: 'cleanup_interval': self.cleanup_interval, 'cleanup_running': self._cleanup_running, 'resource_counts': { - 'sessions': self._session_mgr.count(), - 'models': self._model_mgr.count(), - 'sampling_sessions': self._sampling_mgr.count(), - 'futures': self._future_mgr.count(), + 'sessions': await self._session_mgr.count(), + 'models': await self._model_mgr.count(), + 'sampling_sessions': await self._sampling_mgr.count(), + 'futures': await self._future_mgr.count(), }, } # --------------------------------------------------------------------------- -# Ray proxy +# Direct-backend factory (R19) # --------------------------------------------------------------------------- - - -class ServerStateProxy: - """ - Proxy for interacting with a ServerState Ray actor. - - Wraps Ray remote calls to provide a synchronous-looking API for - interacting with the distributed ServerState actor. - """ - - def __init__(self, actor_handle) -> None: - self._actor = actor_handle - - async def get_capacity_info(self) -> dict[str, int]: - return await self._actor.get_capacity_info.remote() - - # ----- Session Management ----- - - async def create_session(self, payload: dict[str, Any]) -> str: - return await self._actor.create_session.remote(payload) - - async def touch_session(self, session_id: str) -> bool: - return await self._actor.touch_session.remote(session_id) - - async def get_session_last_heartbeat(self, session_id: str) -> float | None: - return await self._actor.get_session_last_heartbeat.remote(session_id) - - # ----- Model Registration ----- - - async def register_model(self, - payload: dict[str, Any], - token: str, - model_id: str | None = None, - replica_id: str | None = None, - session_id: str | None = None) -> str: - return await self._actor.register_model.remote(payload, token, model_id, replica_id, session_id) - - async def unload_model(self, model_id: str) -> bool: - return await self._actor.unload_model.remote(model_id) - - async def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: - return await self._actor.get_model_metadata.remote(model_id) - - # ----- Replica Management ----- - - async def register_replica(self, replica_id: str, max_loras: int) -> None: - await self._actor.register_replica.remote(replica_id, max_loras) - - async def unregister_replica(self, replica_id: str) -> None: - await self._actor.unregister_replica.remote(replica_id) - - async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: - return await self._actor.get_available_replica_ids.remote(candidate_ids) - - # ----- Sampling Session Management ----- - - async def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: - return await self._actor.create_sampling_session.remote(payload, sampling_session_id) - - async def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: - return await self._actor.get_sampling_session.remote(sampling_session_id) - - # ----- Future Management ----- - - async def get_future(self, request_id: str) -> dict[str, Any] | None: - return await self._actor.get_future.remote(request_id) - - async def store_future_status( - self, - request_id: str, - status: str, - model_id: str | None, - reason: str | None = None, - result: Any = None, - queue_state: str | None = None, - queue_state_reason: str | None = None, - ) -> None: - """Store task status with optional result.""" - await self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state, - queue_state_reason) - - # ----- Resource Cleanup ----- - - async def cleanup_expired_resources(self) -> dict[str, int]: - return await self._actor.cleanup_expired_resources.remote() - - async def start_cleanup_task(self) -> bool: - return await self._actor.start_cleanup_task.remote() - - async def stop_cleanup_task(self) -> bool: - return await self._actor.stop_cleanup_task.remote() - - async def get_cleanup_stats(self) -> dict[str, Any]: - return await self._actor.get_cleanup_stats.remote() - - -# --------------------------------------------------------------------------- -# Factory -# --------------------------------------------------------------------------- - - -def get_server_state(actor_name: str = 'twinkle_server_state', **kwargs) -> ServerStateProxy: - """Get or create the ServerState Ray actor. - - Ensures only one ServerState actor exists with the given name. Uses a - detached actor so the state persists across driver restarts. +# +# The detached Ray Actor and the per-method ``self._actor.X.remote(...)`` +# forwarding are gone: every worker now binds a process-local +# :class:`ServerState` directly to the shared :class:`StateBackend`, and the +# existing ``await state.X(...)`` call sites awaited a coroutine before and +# still do. + +_PROCESS_STATE_CACHE: dict[str, ServerState] = {} + + +def get_server_state(actor_name: str = 'twinkle_server_state', + backend: StateBackend | None = None, + persistence_config: PersistenceConfig | None = None, + signature_config: dict[str, Any] | None = None, + signature_policy: str = 'warn', + **kwargs) -> ServerState: + """Return a process-local :class:`ServerState` bound directly to the backend. + + No detached Ray Actor is created — every deployment / worker accesses the + shared persistence backend directly, which removes the actor as a + single-point bottleneck (R19.1, R19.2). Within one process the same + ``actor_name`` is cached so repeated callers share one ``ServerState`` + instance and the cleanup loop is started exactly once. Args: - actor_name: Name for the Ray actor (default: 'twinkle_server_state'). - **kwargs: Additional keyword arguments passed to ServerState constructor - (e.g., expiration_timeout, cleanup_interval). - - Returns: - A ServerStateProxy for interacting with the actor. + actor_name: Cache key for the per-process ``ServerState`` instance. + The legacy parameter name is kept for call-site compatibility. + backend: Optional :class:`StateBackend` to inject. When ``None`` a + backend is built from ``persistence_config`` (or env vars) via + :func:`create_backend`. + persistence_config: Optional :class:`PersistenceConfig`. Accepted as a + raw dict for YAML compatibility. + signature_config: Optional dict whose hash is validated against the + stored config signature on first access. + signature_policy: ``warn`` | ``error`` | ``ignore`` for signature drift. + **kwargs: Forwarded to the :class:`ServerState` constructor + (``expiration_timeout``, ``cleanup_interval``, ...). """ - try: - actor = ray.get_actor(actor_name) - except ValueError: - try: - _ServerState = ray.remote(ServerState) - actor = _ServerState.options(name=actor_name, lifetime='detached').remote(**kwargs) - try: - ray.get(actor.start_cleanup_task.remote()) - except Exception as e: - logger.debug(f'[ServerState] Warning: Failed to start cleanup task: {e}') - except ValueError: - actor = ray.get_actor(actor_name) - assert actor is not None - return ServerStateProxy(actor) + if isinstance(persistence_config, dict): + persistence_config = PersistenceConfig(**persistence_config) + + if backend is None and persistence_config is None: + persistence_config = PersistenceConfig.from_env() + + cached = _PROCESS_STATE_CACHE.get(actor_name) + if cached is not None: + return cached + + state = ServerState( + backend=backend, + persistence_config=persistence_config, + signature_config=signature_config, + signature_policy=signature_policy, + **kwargs, + ) + _PROCESS_STATE_CACHE[actor_name] = state + # Cleanup task is started by the deployment's FastAPI ``lifespan`` hook + # via ``await state.start_cleanup_task()`` — that's the single async + # entry point each worker has, so we don't need any sync-context + # detection here. + return state + + +def reset_server_state_cache() -> None: + """Clear the per-process ServerState cache. + + Test-only helper. Production code should never need to reset state across + requests — workers reuse one instance for the lifetime of the process. + """ + _PROCESS_STATE_CACHE.clear() diff --git a/src/twinkle/server/utils/state/session_manager.py b/src/twinkle/server/state/session_manager.py similarity index 68% rename from src/twinkle/server/utils/state/session_manager.py rename to src/twinkle/server/state/session_manager.py index e7b154cbe..630c33d2f 100644 --- a/src/twinkle/server/utils/state/session_manager.py +++ b/src/twinkle/server/state/session_manager.py @@ -3,42 +3,46 @@ import time +from .backend.base import StateBackend from .base import BaseManager from .models import SessionRecord class SessionManager(BaseManager[SessionRecord]): - """ - Manages client sessions. + """Manages client sessions. Expiry is based on `last_heartbeat`; falls back to `created_at` if no heartbeat has been recorded yet. """ + def __init__(self, backend: StateBackend, expiration_timeout: float) -> None: + super().__init__(backend, 'session::', SessionRecord, expiration_timeout) + # ----- Session-specific operations ----- - def touch(self, session_id: str) -> bool: + async def touch(self, session_id: str) -> bool: """Update the heartbeat timestamp for a session. Returns: True if the session exists and was updated, False otherwise. """ - record = self._store.get(session_id) + record = await self.get(session_id) if record is None: return False record.last_heartbeat = time.time() + await self.add(session_id, record) return True - def get_last_heartbeat(self, session_id: str) -> float | None: + async def get_last_heartbeat(self, session_id: str) -> float | None: """Return the last heartbeat timestamp, or None if the session does not exist.""" - record = self._store.get(session_id) + record = await self.get(session_id) if record is None: return None return record.last_heartbeat # ----- Cleanup ----- - def cleanup_expired(self, cutoff_time: float) -> int: + async def cleanup_expired(self, cutoff_time: float, **kwargs) -> int: """Remove sessions whose last activity is older than cutoff_time. Args: @@ -47,28 +51,29 @@ def cleanup_expired(self, cutoff_time: float) -> int: Returns: Number of sessions removed. """ + all_records = await self.get_all() expired_ids = [] - for session_id, record in self._store.items(): + for session_id, record in all_records.items(): last_activity = record.last_heartbeat if last_activity == 0.0: - # Fallback: parse created_at last_activity = self._parse_timestamp(record.created_at) if last_activity < cutoff_time: expired_ids.append(session_id) for session_id in expired_ids: - del self._store[session_id] + await self.remove(session_id) return len(expired_ids) - def get_expired_ids(self, cutoff_time: float) -> list[str]: + async def get_expired_ids(self, cutoff_time: float) -> list[str]: """Return IDs of sessions that would be removed at the given cutoff. Used by ServerState to cascade-expire dependent resources before actually deleting the sessions. """ + all_records = await self.get_all() expired_ids = [] - for session_id, record in self._store.items(): + for session_id, record in all_records.items(): last_activity = record.last_heartbeat if last_activity == 0.0: last_activity = self._parse_timestamp(record.created_at) diff --git a/src/twinkle/server/telemetry/__init__.py b/src/twinkle/server/telemetry/__init__.py new file mode 100644 index 000000000..7976dfd3e --- /dev/null +++ b/src/twinkle/server/telemetry/__init__.py @@ -0,0 +1,17 @@ +from .metrics import MetricsRegistry +from .provider import TelemetryConfig, get_meter, init_telemetry, shutdown_telemetry +from .tracing import extract_context, get_current_span, get_tracer, inject_context +from .worker_init import ensure_telemetry_initialized + +__all__ = [ + 'MetricsRegistry', + 'TelemetryConfig', + 'get_meter', + 'init_telemetry', + 'shutdown_telemetry', + 'get_tracer', + 'inject_context', + 'extract_context', + 'get_current_span', + 'ensure_telemetry_initialized', +] diff --git a/src/twinkle/server/telemetry/context_carrier.py b/src/twinkle/server/telemetry/context_carrier.py new file mode 100644 index 000000000..1cb3f7714 --- /dev/null +++ b/src/twinkle/server/telemetry/context_carrier.py @@ -0,0 +1,87 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Trace context carrier helpers for cross-deployment propagation (R13). + +``DeploymentHandle`` calls between Ray Serve deployments do not go through +HTTP, so the existing ``inject_context(headers)`` path doesn't apply. To keep +a single trace continuous across an internal handle hop, callers serialize +the active OpenTelemetry context into a small dict using :func:`make_carrier` +and pass it as a kwarg; the receiving side calls :func:`activate_carrier` +inside its handler so subsequent spans attach as children of the propagated +context. + +When the OTEL SDK is missing, both helpers degrade to a NoOp: ``make_carrier`` +returns an empty dict and ``activate_carrier`` is a no-op context manager +that runs the body without raising (R13.4 / R18.3). + +Note on current wiring (R13.3): the present server topology routes every +cross-deployment hop through the Gateway's localhost HTTP proxy, which already +carries the trace context in HTTP headers via :func:`twinkle.server.telemetry. +tracing.inject_context`. There are therefore no in-process ``DeploymentHandle`` +call sites in ``src/`` today to thread the carrier through. These helpers are +the supported integration point for any future deployment-to-deployment handle +calls (e.g. Model → Sampler over a non-HTTP path); add ``trace_context: dict | +None = None`` to the handle signature, build it on the caller with +:func:`make_carrier`, and wrap the receiver body in :func:`activate_carrier` +to preserve a single trace id across the hop. +""" +from __future__ import annotations + +from contextlib import contextmanager +from typing import Any, Iterator, Mapping + +try: + from opentelemetry import context as _otel_context # type: ignore + from opentelemetry.propagate import extract as _otel_extract # type: ignore + from opentelemetry.propagate import inject as _otel_inject + + _OTEL_AVAILABLE = True +except Exception: + _OTEL_AVAILABLE = False + + +def make_carrier() -> dict[str, Any]: + """Inject the active trace context into a fresh dict (R13.1). + + Returns an empty dict when OTEL is not installed; the receiving side's + :func:`activate_carrier` handles that case gracefully. + """ + carrier: dict[str, Any] = {} + if not _OTEL_AVAILABLE: + return carrier + try: + _otel_inject(carrier) + except Exception: + # NEVER let a tracing failure block the underlying RPC. + return {} + return carrier + + +@contextmanager +def activate_carrier(carrier: Mapping[str, Any] | None) -> Iterator[None]: + """Attach the trace context from ``carrier`` for the lifetime of the block. + + Subsequent spans started inside the block become children of the + propagated context (R13.2). When ``carrier`` is ``None`` or empty, or + when OTEL is not installed, the block runs without attaching any + context and a fresh trace is started (R13.4). + """ + if not _OTEL_AVAILABLE or not carrier: + yield + return + + try: + ctx = _otel_extract(dict(carrier)) + except Exception: + yield + return + + token = None + try: + token = _otel_context.attach(ctx) + yield + finally: + if token is not None: + try: + _otel_context.detach(token) + except Exception: + pass diff --git a/src/twinkle/server/telemetry/correlation.py b/src/twinkle/server/telemetry/correlation.py new file mode 100644 index 000000000..bddfe745c --- /dev/null +++ b/src/twinkle/server/telemetry/correlation.py @@ -0,0 +1,47 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Correlation key constants for business-layer spans (R11). + +All correlation attributes share the ``twinkle.`` prefix so operators can +filter Tempo / Loki by ``twinkle.session_id``, ``twinkle.model_id``, etc. +The ``set_correlation_attrs`` helper attaches only present (non-None) values +so partially-known operations don't end up with empty attributes. +""" +from __future__ import annotations + +from typing import Any, Mapping + +PREFIX = 'twinkle.' + +SESSION_ID = f'{PREFIX}session_id' +MODEL_ID = f'{PREFIX}model_id' +REPLICA_ID = f'{PREFIX}replica_id' +TOKEN_ID = f'{PREFIX}token_id' +SAMPLING_SESSION_ID = f'{PREFIX}sampling_session_id' +BASE_MODEL = f'{PREFIX}base_model' + +CORRELATION_KEYS: tuple[str, ...] = ( + SESSION_ID, + MODEL_ID, + REPLICA_ID, + TOKEN_ID, + SAMPLING_SESSION_ID, + BASE_MODEL, +) + + +def set_correlation_attrs(span: Any, values: Mapping[str, Any] | None) -> None: + """Attach the given correlation attributes to ``span``. + + Skips ``None`` values entirely (R11.2: ``when that value is available``) + and is a no-op for NoOp spans returned when the OTEL SDK is not + installed. + """ + if not values or span is None: + return + setter = getattr(span, 'set_attribute', None) + if setter is None: + return + for key, value in values.items(): + if value is None: + continue + setter(key, value) diff --git a/src/twinkle/server/telemetry/metrics.py b/src/twinkle/server/telemetry/metrics.py new file mode 100644 index 000000000..fbc3de9a0 --- /dev/null +++ b/src/twinkle/server/telemetry/metrics.py @@ -0,0 +1,86 @@ +"""Twinkle Server metrics registry — low-invasiveness facade over OpenTelemetry metrics.""" + +from __future__ import annotations + +from .provider import get_meter + + +class MetricsRegistry: + """Centrally declares all metrics. Business code retrieves singleton via MetricsRegistry.get(). + + When telemetry is not initialized, OTEL returns a NoOp meter and all recording operations are silently no-op. + """ + + _instance: MetricsRegistry | None = None + + def __init__(self) -> None: + meter = get_meter('twinkle-server') + + # === HTTP Requests === + self.requests_total = meter.create_counter( + 'twinkle.http.requests.total', + description='Total HTTP requests received', + ) + self.request_duration_seconds = meter.create_histogram( + 'twinkle.http.request.duration_seconds', + description='HTTP request duration in seconds', + unit='s', + ) + + # === Task Queue === + self.queue_depth = meter.create_up_down_counter( + 'twinkle.queue.depth', + description='Current task queue depth', + ) + self.task_execution_seconds = meter.create_histogram( + 'twinkle.task.execution_seconds', + description='Task execution duration in seconds', + unit='s', + ) + self.task_wait_seconds = meter.create_histogram( + 'twinkle.task.wait_seconds', + description='Task wait time in queue before execution', + unit='s', + ) + self.rate_limit_rejections = meter.create_counter( + 'twinkle.rate_limit.rejections.total', + description='Total requests rejected by rate limiter', + ) + self.tasks_total = meter.create_counter( + 'twinkle.tasks.total', + description='Total task completions, partitioned by status', + ) + self.rate_limiter_active_tokens = meter.create_up_down_counter( + 'twinkle.rate_limiter.active_tokens', + description='Number of tokens currently tracked by the rate limiter', + ) + + # === Resources === + self.active_sessions = meter.create_up_down_counter( + 'twinkle.sessions.active', + description='Number of active client sessions', + ) + self.active_models = meter.create_up_down_counter( + 'twinkle.models.active', + description='Number of registered models', + ) + self.active_sampling_sessions = meter.create_up_down_counter( + 'twinkle.sampling_sessions.active', + description='Number of active sampling sessions', + ) + self.active_futures = meter.create_up_down_counter( + 'twinkle.futures.active', + description='Number of pending futures/tasks', + ) + + @classmethod + def get(cls) -> MetricsRegistry: + """Retrieve global MetricsRegistry singleton. Created on first call.""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def reset(cls) -> None: + """Reset singleton (for testing or telemetry re-initialization).""" + cls._instance = None diff --git a/src/twinkle/server/telemetry/provider.py b/src/twinkle/server/telemetry/provider.py new file mode 100644 index 000000000..a8ec970f5 --- /dev/null +++ b/src/twinkle/server/telemetry/provider.py @@ -0,0 +1,274 @@ +"""OpenTelemetry provider initialization for Twinkle server. + +Bootstraps the three OTEL pillars (traces, metrics, logs) with either an OTLP +or a console exporter. Designed to be a thin, side-effect-driven module that +exposes: + +- ``TelemetryConfig``: pydantic configuration model +- ``init_telemetry``: entry point that wires up global providers +- ``shutdown_telemetry``: graceful teardown for the global providers +- ``get_meter``: convenience accessor used by ``MetricsRegistry`` +""" + +from __future__ import annotations + +import logging +from pydantic import BaseModel, ConfigDict, Field +from typing import Any, Optional + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Optional OTEL imports — keep them lazy/guarded so that a missing optional +# dependency does not break the rest of the server. +# --------------------------------------------------------------------------- +try: + from opentelemetry import metrics, trace + from opentelemetry._logs import set_logger_provider + from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler + from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, ConsoleLogExporter + from opentelemetry.sdk.metrics import MeterProvider + from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter + + _OTEL_AVAILABLE = True + _OTEL_IMPORT_ERROR: BaseException | None = None +except Exception as exc: # pragma: no cover - defensive fallback + _OTEL_AVAILABLE = False + _OTEL_IMPORT_ERROR = exc + +# OTLP exporters are a separate optional dependency from the SDK itself. +try: + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + + _OTLP_AVAILABLE = True +except Exception: # pragma: no cover - defensive fallback + _OTLP_AVAILABLE = False + +# Logging instrumentor is also optional. +try: + from opentelemetry.instrumentation.logging import LoggingInstrumentor + + _LOGGING_INSTRUMENTOR_AVAILABLE = True +except Exception: # pragma: no cover - defensive fallback + _LOGGING_INSTRUMENTOR_AVAILABLE = False + +# --------------------------------------------------------------------------- +# Module-level state for shutdown. +# --------------------------------------------------------------------------- +_tracer_provider: Any | None = None +_meter_provider: Any | None = None +_logger_provider: Any | None = None +_logging_handler: Any | None = None +_initialized: bool = False + + +class _LoggingWriter: + """IO adapter that routes writes to Python logging. + + Used to redirect ConsoleExporter output through the logging system + so that Ray Serve worker logs are properly captured. Without this, + ConsoleSpanExporter / ConsoleMetricExporter write directly to + ``sys.stdout`` which Ray reroutes into its internal log files, + making telemetry output invisible to the standard logging pipeline. + """ + + def __init__(self, logger_name: str = 'twinkle.server.telemetry.export'): + self._logger = logging.getLogger(logger_name) + + def write(self, text: str) -> int: + text = text.strip() + if text: + self._logger.info(text) + return len(text) + + def flush(self) -> None: + pass + + +class TelemetryConfig(BaseModel): + """Configuration for the OpenTelemetry pipeline.""" + + model_config = ConfigDict(extra='forbid') + + enabled: bool = False + service_name: str = 'twinkle-server' + otlp_endpoint: str = 'http://localhost:4317' + debug: bool = False # True: Console Exporter; False: OTLP Exporter + export_interval_ms: int = 30000 + resource_attributes: dict = Field(default_factory=dict) + + +def init_telemetry(config: TelemetryConfig) -> None: + """Initialize the three OTEL pillars (traces, metrics, logs). + + No-op when ``config.enabled`` is False or the OTEL SDK is missing. + """ + global _tracer_provider, _meter_provider, _logger_provider + global _logging_handler, _initialized + + if not config.enabled: + return + + if not _OTEL_AVAILABLE: + logger.warning( + 'OpenTelemetry SDK not available, skipping telemetry init: %s', + _OTEL_IMPORT_ERROR, + ) + return + + if _initialized: + logger.debug('Telemetry already initialized; skipping re-init.') + return + + # ---- Resource ------------------------------------------------------- + resource_attrs: dict = {'service.name': config.service_name} + if config.resource_attributes: + resource_attrs.update(config.resource_attributes) + resource = Resource.create(resource_attrs) + + use_console = config.debug or not _OTLP_AVAILABLE + if config.debug is False and not _OTLP_AVAILABLE: + logger.warning('OTLP exporters not available; falling back to console exporters.') + + # When using console exporters, route their output through the Python + # logging system so that Ray Serve workers actually surface the data. + _console_writer = _LoggingWriter() if use_console else None + + # ---- Traces --------------------------------------------------------- + if use_console: + span_exporter = ConsoleSpanExporter(out=_console_writer) + else: + span_exporter = OTLPSpanExporter(endpoint=config.otlp_endpoint) + + tracer_provider = TracerProvider(resource=resource) + tracer_provider.add_span_processor(BatchSpanProcessor(span_exporter)) + trace.set_tracer_provider(tracer_provider) + _tracer_provider = tracer_provider + + # ---- Metrics -------------------------------------------------------- + if use_console: + metric_exporter = ConsoleMetricExporter(out=_console_writer) + else: + metric_exporter = OTLPMetricExporter(endpoint=config.otlp_endpoint) + + metric_reader = PeriodicExportingMetricReader( + metric_exporter, + export_interval_millis=config.export_interval_ms, + ) + meter_provider = MeterProvider( + resource=resource, + metric_readers=[metric_reader], + ) + metrics.set_meter_provider(meter_provider) + _meter_provider = meter_provider + + # ---- Logs ----------------------------------------------------------- + if use_console: + log_exporter = ConsoleLogExporter() + else: + log_exporter = OTLPLogExporter(endpoint=config.otlp_endpoint) + + logger_provider = LoggerProvider(resource=resource) + logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter)) + set_logger_provider(logger_provider) + _logger_provider = logger_provider + + # Bridge Python logging -> OTEL logs. + if _LOGGING_INSTRUMENTOR_AVAILABLE: + try: + LoggingInstrumentor().instrument(set_logging_format=True) + except Exception as exc: # pragma: no cover - defensive + logger.warning('LoggingInstrumentor failed to instrument: %s', exc) + + handler = LoggingHandler(level=logging.NOTSET, logger_provider=logger_provider) + # Attach to BOTH the root logger and the ``twinkle`` namespace logger. + # ``twinkle.utils.logger`` configures the ``twinkle`` logger with + # ``propagate=False`` and its own StreamHandler, so log records emitted + # under ``twinkle.*`` (which is the entire server codebase) never bubble + # up to root and would be invisible to an OTLP handler bound there only. + logging.getLogger().addHandler(handler) + logging.getLogger('twinkle').addHandler(handler) + _logging_handler = handler + + _initialized = True + logger.info( + 'Telemetry initialized (service=%s, debug=%s, otlp_endpoint=%s)', + config.service_name, + config.debug, + config.otlp_endpoint, + ) + + +def shutdown_telemetry() -> None: + """Shutdown all OTEL providers and detach the logging handler.""" + global _tracer_provider, _meter_provider, _logger_provider + global _logging_handler, _initialized + + if _logging_handler is not None: + for logger_name in ('', 'twinkle'): + try: + logging.getLogger(logger_name).removeHandler(_logging_handler) + except Exception as exc: # pragma: no cover - defensive + logger.warning('Failed to detach logging handler from %r: %s', logger_name, exc) + _logging_handler = None + + if _tracer_provider is not None: + try: + _tracer_provider.shutdown() + except Exception as exc: # pragma: no cover - defensive + logger.warning('TracerProvider shutdown failed: %s', exc) + _tracer_provider = None + + if _meter_provider is not None: + try: + _meter_provider.shutdown() + except Exception as exc: # pragma: no cover - defensive + logger.warning('MeterProvider shutdown failed: %s', exc) + _meter_provider = None + + if _logger_provider is not None: + try: + _logger_provider.shutdown() + except Exception as exc: # pragma: no cover - defensive + logger.warning('LoggerProvider shutdown failed: %s', exc) + _logger_provider = None + + _initialized = False + + +class _NoopInstrument: + """No-op instrument for when OTEL SDK is not available.""" + + def add(self, *args, **kwargs): + pass + + def record(self, *args, **kwargs): + pass + + +class _NoopMeter: + """No-op meter for when OTEL SDK is not available.""" + + def create_counter(self, *args, **kwargs): + return _NoopInstrument() + + def create_up_down_counter(self, *args, **kwargs): + return _NoopInstrument() + + def create_histogram(self, *args, **kwargs): + return _NoopInstrument() + + +_noop_meter = _NoopMeter() + + +def get_meter(name: str = 'twinkle-server'): + """Return an OTEL meter. Returns NoOp meter if OTEL SDK is not available.""" + if not _OTEL_AVAILABLE: + return _noop_meter + return metrics.get_meter(name) diff --git a/src/twinkle/server/telemetry/resource_metrics.py b/src/twinkle/server/telemetry/resource_metrics.py new file mode 100644 index 000000000..112367160 --- /dev/null +++ b/src/twinkle/server/telemetry/resource_metrics.py @@ -0,0 +1,193 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Resource (CPU / Memory / GPU) observable gauges (R12). + +Registers OTEL ``observable_gauge`` instruments for CPU utilization, memory +usage (system + process), GPU utilization, and GPU memory. The data sources +(``psutil`` for CPU/memory, ``pynvml`` for GPU) are **optional telemetry +dependencies** — when either is missing, the corresponding gauges report no +data and the collector does not raise (R12.3, R18.3). + +The collector is started by :func:`worker_init.ensure_telemetry_initialized` +in each Ray Serve worker process so per-replica resource usage shows up in +Mimir / Grafana. +""" +from __future__ import annotations + +import os +from typing import Any, Iterator + +from .provider import get_meter + +try: + import psutil # type: ignore + _PSUTIL_AVAILABLE = True +except Exception: + _PSUTIL_AVAILABLE = False + +try: + import pynvml # type: ignore + _PYNVML_AVAILABLE = True +except Exception: + _PYNVML_AVAILABLE = False + +_NVML_INITIALIZED = False + + +def _nvml_handle_count() -> int: + """Return GPU count, initializing pynvml lazily. ``0`` if unavailable.""" + global _NVML_INITIALIZED + if not _PYNVML_AVAILABLE: + return 0 + try: + if not _NVML_INITIALIZED: + pynvml.nvmlInit() + _NVML_INITIALIZED = True + return int(pynvml.nvmlDeviceGetCount()) + except Exception: + return 0 + + +class ResourceMetricsCollector: + """Owns the observable-gauge callbacks for CPU/Mem/GPU.""" + + def __init__(self) -> None: + self._started = False + # Track which gauges were registered so callers (and tests) can + # introspect what's exported in this process. + self.registered_gauges: list[str] = [] + + def maybe_start(self) -> None: + """Register the available gauges; idempotent and never raises.""" + if self._started: + return + self._started = True + try: + meter = get_meter('twinkle.server.resource') + except Exception: + return + + # CPU + memory require psutil. + if _PSUTIL_AVAILABLE: + try: + meter.create_observable_gauge( + 'twinkle.system.cpu.utilization', + description='System CPU utilization (0..1)', + callbacks=[self._cpu_utilization_callback], + ) + self.registered_gauges.append('twinkle.system.cpu.utilization') + meter.create_observable_gauge( + 'twinkle.system.memory.usage_bytes', + description='System memory used in bytes', + callbacks=[self._memory_usage_callback], + ) + self.registered_gauges.append('twinkle.system.memory.usage_bytes') + meter.create_observable_gauge( + 'twinkle.process.memory.usage_bytes', + description='Resident-set memory of this process in bytes', + callbacks=[self._process_memory_callback], + ) + self.registered_gauges.append('twinkle.process.memory.usage_bytes') + except Exception: + pass + + # GPU requires pynvml AND at least one GPU device — without either, + # we silently skip GPU gauges (R12.3 / R18.3). + if _nvml_handle_count() > 0: + try: + meter.create_observable_gauge( + 'twinkle.gpu.utilization', + description='Per-GPU utilization (0..1)', + callbacks=[self._gpu_utilization_callback], + ) + self.registered_gauges.append('twinkle.gpu.utilization') + meter.create_observable_gauge( + 'twinkle.gpu.memory.usage_bytes', + description='Per-GPU memory used in bytes', + callbacks=[self._gpu_memory_callback], + ) + self.registered_gauges.append('twinkle.gpu.memory.usage_bytes') + except Exception: + pass + + # ----- callbacks ----------------------------------------------------- # + + @staticmethod + def _cpu_utilization_callback(_options: Any) -> Iterator[Any]: + from opentelemetry.metrics import Observation # type: ignore + + if not _PSUTIL_AVAILABLE: + return iter(()) + try: + value = float(psutil.cpu_percent(interval=None)) / 100.0 + return iter([Observation(value)]) + except Exception: + return iter(()) + + @staticmethod + def _memory_usage_callback(_options: Any) -> Iterator[Any]: + from opentelemetry.metrics import Observation # type: ignore + + if not _PSUTIL_AVAILABLE: + return iter(()) + try: + return iter([Observation(int(psutil.virtual_memory().used))]) + except Exception: + return iter(()) + + @staticmethod + def _process_memory_callback(_options: Any) -> Iterator[Any]: + from opentelemetry.metrics import Observation # type: ignore + + if not _PSUTIL_AVAILABLE: + return iter(()) + try: + rss = psutil.Process(os.getpid()).memory_info().rss + return iter([Observation(int(rss))]) + except Exception: + return iter(()) + + @staticmethod + def _gpu_utilization_callback(_options: Any) -> Iterator[Any]: + from opentelemetry.metrics import Observation # type: ignore + + count = _nvml_handle_count() + out: list[Any] = [] + for i in range(count): + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(i) + util = pynvml.nvmlDeviceGetUtilizationRates(handle) + out.append(Observation(float(util.gpu) / 100.0, {'gpu_index': i})) + except Exception: + continue + return iter(out) + + @staticmethod + def _gpu_memory_callback(_options: Any) -> Iterator[Any]: + from opentelemetry.metrics import Observation # type: ignore + + count = _nvml_handle_count() + out: list[Any] = [] + for i in range(count): + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(i) + mem = pynvml.nvmlDeviceGetMemoryInfo(handle) + out.append(Observation(int(mem.used), {'gpu_index': i})) + except Exception: + continue + return iter(out) + + +_GLOBAL_COLLECTOR: ResourceMetricsCollector | None = None + + +def get_collector() -> ResourceMetricsCollector: + global _GLOBAL_COLLECTOR + if _GLOBAL_COLLECTOR is None: + _GLOBAL_COLLECTOR = ResourceMetricsCollector() + return _GLOBAL_COLLECTOR + + +def reset_collector_for_tests() -> None: + """Clear the module-global collector. Test-only helper.""" + global _GLOBAL_COLLECTOR + _GLOBAL_COLLECTOR = None diff --git a/src/twinkle/server/telemetry/tracing.py b/src/twinkle/server/telemetry/tracing.py new file mode 100644 index 000000000..770b4b3a1 --- /dev/null +++ b/src/twinkle/server/telemetry/tracing.py @@ -0,0 +1,171 @@ +"""Twinkle Server tracing utilities — thin wrapper over OpenTelemetry tracing.""" + +from __future__ import annotations + +from contextlib import contextmanager +from fastapi import Request +from typing import Any, Iterator, Mapping + +try: + from opentelemetry import trace + from opentelemetry.context import Context + from opentelemetry.propagate import extract, inject + _OTEL_AVAILABLE = True +except Exception: + _OTEL_AVAILABLE = False + +from .correlation import set_correlation_attrs + + +def get_tracer(name: str = 'twinkle-server'): + """Retrieve tracer instance. Returns NoOp tracer when OTEL is not installed.""" + if not _OTEL_AVAILABLE: + return _NoopTracer() + return trace.get_tracer(name) + + +def inject_context(carrier: dict) -> None: + """Inject current trace context into carrier. Noop when OTEL is not installed.""" + if not _OTEL_AVAILABLE: + return + inject(carrier) + + +def extract_context(carrier: dict): + """Extract trace context from carrier. Returns empty context when OTEL is not installed.""" + if not _OTEL_AVAILABLE: + return None + return extract(carrier) + + +def get_current_span(): + """Get current active span. Returns noop span when OTEL is not installed.""" + if not _OTEL_AVAILABLE: + return _NoopSpan() + return trace.get_current_span() + + +class _NoopSpan: + """Minimal noop span for when OTEL is not available.""" + + def set_attribute(self, *args, **kwargs): + pass + + def set_status(self, *args, **kwargs): + pass + + def add_event(self, *args, **kwargs): + pass + + def end(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + +class _NoopTracer: + """Minimal noop tracer for when OTEL is not available.""" + + def start_as_current_span(self, name, **kwargs): + return _NoopSpan() + + def start_span(self, name, **kwargs): + return _NoopSpan() + + +@contextmanager +def traced_operation( + name: str, + *, + attrs: Mapping[str, Any] | None = None, + tracer_name: str = 'twinkle.server.business', +) -> Iterator[Any]: + """Run a business-layer block under one OTEL span (R10). + + The span starts before the block runs and ends after it returns. If the + block raises, the exception is recorded on the span, the span status is + set to ERROR, the span is ended, and the original exception is re-raised + to the caller (R10.4). When the OTEL SDK is missing, the context manager + degrades to a NoOp that runs the block normally and returns the same + result it would return when tracing is active (R10.5 / R18.3). + """ + if not _OTEL_AVAILABLE: + yield _NoopSpan() + return + + tracer = trace.get_tracer(tracer_name) + with tracer.start_as_current_span(name) as span: + if attrs: + set_correlation_attrs(span, attrs) + try: + yield span + except Exception as exc: + try: + span.record_exception(exc) + span.set_status(trace.Status(trace.StatusCode.ERROR, str(exc))) + except Exception: + # NEVER let a tracing error mask the underlying exception. + pass + raise + + +def create_tracing_middleware(service_component: str): + """Create an HTTP tracing middleware compatible with Ray Serve pickling. + + Unlike ``FastAPIInstrumentor.instrument_app`` which attaches unpicklable + references (e.g. ``_thread.lock``) to the FastAPI app and breaks Ray Serve + deployment pickling, the returned middleware is a plain async function + closing only over the ``service_component`` string. + + When OpenTelemetry is not installed, returns a passthrough middleware so + the server still works without the optional dependency. + + Args: + service_component: Logical service name used as the tracer name suffix + and recorded as a span attribute (e.g. ``Gateway``, ``Model``, + ``Processor``, ``Sampler``). + + Returns: + An async FastAPI HTTP middleware function. + """ + if not _OTEL_AVAILABLE: + + async def passthrough_middleware(request: Request, call_next): + return await call_next(request) + + return passthrough_middleware + + async def tracing_middleware(request: Request, call_next): + tracer = trace.get_tracer(f'twinkle.server.{service_component}') + + method = request.method + path = request.url.path + span_name = f'{method} {path}' + + with tracer.start_as_current_span( + span_name, + kind=trace.SpanKind.SERVER, + attributes={ + 'http.method': method, + 'http.url': str(request.url), + 'http.route': path, + 'http.scheme': request.url.scheme, + 'service.component': service_component, + }, + ) as span: + try: + response = await call_next(request) + span.set_attribute('http.status_code', response.status_code) + if response.status_code >= 400: + span.set_status(trace.Status(trace.StatusCode.ERROR)) + return response + except Exception as exc: + span.set_status(trace.Status(trace.StatusCode.ERROR, str(exc))) + span.record_exception(exc) + raise + + return tracing_middleware diff --git a/src/twinkle/server/telemetry/worker_init.py b/src/twinkle/server/telemetry/worker_init.py new file mode 100644 index 000000000..f4ac60cfd --- /dev/null +++ b/src/twinkle/server/telemetry/worker_init.py @@ -0,0 +1,71 @@ +"""Telemetry initialization for Ray worker processes. + +Ray Serve deployments run in separate worker processes that do not inherit +the driver process's OTEL global state. This module provides a function +to re-initialize telemetry in each worker process using environment variables +set by the launcher. +""" +from __future__ import annotations + +import logging +import os + +logger = logging.getLogger(__name__) + +_worker_initialized = False + + +def ensure_telemetry_initialized() -> None: + """Initialize telemetry in the current worker process if not already done. + + Reads configuration from environment variables set by the launcher process. + Safe to call multiple times - only initializes once per process. + """ + global _worker_initialized + if _worker_initialized: + return + + _worker_initialized = True + + telemetry_enabled = os.environ.get('TWINKLE_TELEMETRY_ENABLED') == '1' + + if not telemetry_enabled: + # Even with telemetry disabled, register the resource collector so + # graceful-degradation behavior matches the enabled path (R12.2/R18.3). + _start_resource_collector() + return + + try: + from twinkle.server.telemetry import TelemetryConfig, init_telemetry + from twinkle.server.telemetry.metrics import MetricsRegistry + + config = TelemetryConfig( + enabled=True, + debug=os.environ.get('TWINKLE_TELEMETRY_DEBUG', '0') == '1', + service_name=os.environ.get('TWINKLE_TELEMETRY_SERVICE', 'twinkle-server'), + otlp_endpoint=os.environ.get('TWINKLE_TELEMETRY_ENDPOINT', 'http://localhost:4317'), + export_interval_ms=int(os.environ.get('TWINKLE_TELEMETRY_INTERVAL', '30000')), + ) + init_telemetry(config) + # Reset MetricsRegistry singleton so it picks up the real MeterProvider + MetricsRegistry.reset() + logger.info(f'Worker telemetry initialized (service={config.service_name}, debug={config.debug})') + except Exception as e: + logger.warning(f'Failed to initialize worker telemetry: {e}') + + _start_resource_collector() + + +def _start_resource_collector() -> None: + """Start the resource (CPU / Memory / GPU) metrics collector. + + Safe to call even when telemetry init was skipped or failed — the + collector picks up the NoOp meter and silently records no observations + (R12.2 / R18.3). + """ + try: + from twinkle.server.telemetry import resource_metrics + + resource_metrics.get_collector().maybe_start() + except Exception as e: + logger.debug(f'Resource metrics collector start failed: {e}') diff --git a/src/twinkle/server/utils/lifecycle/base.py b/src/twinkle/server/utils/lifecycle/base.py index 6c1c1b57a..394cd2200 100644 --- a/src/twinkle/server/utils/lifecycle/base.py +++ b/src/twinkle/server/utils/lifecycle/base.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from twinkle.server.utils.state import ServerStateProxy + from twinkle.server.state import ServerState from twinkle.utils.logger import get_logger @@ -38,7 +38,7 @@ class SessionResourceMixin: """ # Type hint for state attribute that inheriting classes must provide - state: ServerStateProxy + state: ServerState # Resource type name for logging (override in subclass) _resource_type: str = 'resource' diff --git a/src/twinkle/server/utils/metrics.py b/src/twinkle/server/utils/metrics.py index eee915d78..ff1fe7f29 100644 --- a/src/twinkle/server/utils/metrics.py +++ b/src/twinkle/server/utils/metrics.py @@ -2,98 +2,109 @@ """ Central metrics module for Twinkle server observability. -Provides ray.util.metrics instruments that feed both the Ray Dashboard -(port 8265) and Prometheus (via /api/prometheus). +This module is a *thin adapter layer* over the OpenTelemetry instruments +declared in :class:`twinkle.server.telemetry.metrics.MetricsRegistry`. +It preserves the legacy Ray-style API (``.inc()`` / ``.set()`` / ``.observe()`` +with ``tags=`` keyword) so that existing call sites do not need to change +their call patterns, while routing every measurement through OTEL. -All metric names use the ``twinkle_`` prefix. Metric instances are -cached per deployment to avoid duplicate registration. - -Public entry-points: +Public entry-points (unchanged signatures): * ``create_metrics_middleware(deployment)`` – FastAPI HTTP middleware -* ``get_task_metrics(deployment)`` – task-queue / rate-limit gauges -* ``get_resource_metrics()`` – ServerState resource gauges +* ``get_task_metrics(deployment)`` – task-queue / rate-limit gauges """ from __future__ import annotations import time from pydantic import BaseModel, ConfigDict -from ray.util.metrics import Counter, Gauge, Histogram from typing import Any, Callable +from twinkle.server.telemetry import MetricsRegistry from twinkle.utils.logger import get_logger logger = get_logger() # --------------------------------------------------------------------------- -# Histogram bucket boundaries (seconds) – shared by all histograms -# --------------------------------------------------------------------------- -_HISTOGRAM_BOUNDARIES = [ - 0.01, - 0.05, - 0.1, - 0.25, - 0.5, - 1.0, - 2.5, - 5.0, - 10.0, - 30.0, - 60.0, - 120.0, - 300.0, -] - -# --------------------------------------------------------------------------- -# Lazy caches – populated on first call per deployment / globally +# Lazy caches – populated on first call per deployment # --------------------------------------------------------------------------- _task_metrics_cache: dict[str, TaskMetrics] = {} -_resource_metrics_cache: ResourceMetrics | None = None _request_metrics_cache: dict[str, _RequestMetrics] = {} # --------------------------------------------------------------------------- -# Pydantic models for structured metric access +# Adapter classes – wrap OTEL instruments to expose the legacy Ray-style API +# (``.inc(tags=...)`` / ``.set(value, tags=...)`` / ``.observe(value, tags=...)``) +# while delegating all measurements to OpenTelemetry. # --------------------------------------------------------------------------- -class TaskMetrics(BaseModel): - """Task queue metrics container. +class _Counter: + """Adapter mapping ``.inc(value, tags=...)`` to ``otel_counter.add()``.""" - Attributes: - queue_depth: Current number of queued tasks. - tasks_total: Total task completions. - execution_seconds: Pure task execution time in seconds. - queue_wait_seconds: Time from enqueue to execution start. - rate_limit_rejections: Total rate-limit rejections. - rate_limiter_active_tokens: Tokens tracked by rate limiter. + def __init__(self, instrument: Any) -> None: + self._instrument = instrument + + def inc(self, value: float = 1.0, tags: dict[str, str] | None = None) -> None: + self._instrument.add(value, attributes=tags or {}) + + +class _Histogram: + """Adapter mapping ``.observe(value, tags=...)`` to ``otel_histogram.record()``.""" + + def __init__(self, instrument: Any) -> None: + self._instrument = instrument + + def observe(self, value: float, tags: dict[str, str] | None = None) -> None: + self._instrument.record(value, attributes=tags or {}) + + +class _Gauge: + """Adapter mapping ``.set(value, tags=...)`` onto an OTEL UpDownCounter. + + OpenTelemetry up/down counters take *deltas*, not absolute values, so we + track the last reported value per attribute combination and emit the + incremental change. State is held per adapter instance (= per deployment), + keyed by the frozen attribute tuple. """ - model_config = ConfigDict(arbitrary_types_allowed=True) + def __init__(self, instrument: Any) -> None: + self._instrument = instrument + self._last: dict[tuple, float] = {} - queue_depth: Gauge - tasks_total: Counter - execution_seconds: Histogram - queue_wait_seconds: Histogram - rate_limit_rejections: Counter - rate_limiter_active_tokens: Gauge + def set(self, value: float, tags: dict[str, str] | None = None) -> None: + attrs = tags or {} + key = tuple(sorted(attrs.items())) + last = self._last.get(key, 0.0) + delta = value - last + if delta != 0: + self._instrument.add(delta, attributes=attrs) + self._last[key] = value -class ResourceMetrics(BaseModel): - """Resource gauge metrics container. +# --------------------------------------------------------------------------- +# Pydantic containers for structured metric access +# --------------------------------------------------------------------------- + + +class TaskMetrics(BaseModel): + """Task queue metrics container. Attributes: - active_sessions: Current active session count. - active_models: Current registered model count. - active_sampling_sessions: Current sampling session count. - active_futures: Current future/request count. + queue_depth: Current number of queued tasks (gauge). + tasks_total: Total task completions (counter). + execution_seconds: Pure task execution time in seconds (histogram). + queue_wait_seconds: Time from enqueue to execution start (histogram). + rate_limit_rejections: Total rate-limit rejections (counter). + rate_limiter_active_tokens: Tokens tracked by rate limiter (gauge). """ model_config = ConfigDict(arbitrary_types_allowed=True) - active_sessions: Gauge - active_models: Gauge - active_sampling_sessions: Gauge - active_futures: Gauge + queue_depth: _Gauge + tasks_total: _Counter + execution_seconds: _Histogram + queue_wait_seconds: _Histogram + rate_limit_rejections: _Counter + rate_limiter_active_tokens: _Gauge class _RequestMetrics(BaseModel): @@ -101,8 +112,8 @@ class _RequestMetrics(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - requests_total: Counter - request_duration_seconds: Histogram + requests_total: _Counter + request_duration_seconds: _Histogram # --------------------------------------------------------------------------- @@ -111,22 +122,14 @@ class _RequestMetrics(BaseModel): def _get_request_metrics(deployment: str) -> _RequestMetrics: - """Return (or create) per-deployment HTTP request metrics.""" + """Return (or create) per-deployment HTTP request metric adapters.""" if deployment in _request_metrics_cache: return _request_metrics_cache[deployment] + reg = MetricsRegistry.get() metrics = _RequestMetrics( - requests_total=Counter( - 'twinkle_requests_total', - description='Total HTTP requests.', - tag_keys=('deployment', 'method', 'status'), - ), - request_duration_seconds=Histogram( - 'twinkle_request_duration_seconds', - description='End-to-end HTTP request latency in seconds.', - boundaries=_HISTOGRAM_BOUNDARIES, - tag_keys=('deployment', 'method'), - ), + requests_total=_Counter(reg.requests_total), + request_duration_seconds=_Histogram(reg.request_duration_seconds), ) _request_metrics_cache[deployment] = metrics return metrics @@ -138,12 +141,16 @@ def create_metrics_middleware(deployment: str) -> Callable: Usage inside a ``build_*_app()`` function:: from twinkle.server.utils.metrics import create_metrics_middleware - metrics_mw = create_metrics_middleware("Model") - app.middleware('http')(metrics_mw) + from twinkle.server.telemetry.tracing import create_tracing_middleware + + app.middleware('http')(verify_token) + app.middleware('http')(create_tracing_middleware("Model")) + app.middleware('http')(create_metrics_middleware("Model")) # outermost - Because FastAPI executes middleware in LIFO order, registering this - **after** ``verify_token`` means it wraps the outermost layer and - captures full end-to-end latency including auth. + FastAPI executes middleware in LIFO order, so the **last** middleware + registered is the outermost wrapper. Register metrics last so its + latency observation covers the full request path including tracing + overhead and authentication. """ async def metrics_middleware(request: Any, call_next: Callable) -> Any: @@ -174,94 +181,25 @@ async def metrics_middleware(request: Any, call_next: Callable) -> Any: def get_task_metrics(deployment: str) -> TaskMetrics: - """Return (or create) per-deployment task-queue metrics. + """Return (or create) per-deployment task-queue metric adapters. - Returns a :class:`TaskMetrics` Pydantic model with: - - - ``queue_depth`` – Gauge - - ``tasks_total`` – Counter - - ``execution_seconds`` – Histogram - - ``queue_wait_seconds`` – Histogram - - ``rate_limit_rejections`` – Counter - - ``rate_limiter_active_tokens`` – Gauge + Returns a :class:`TaskMetrics` container of adapter objects; the + adapters delegate every measurement to the OTEL instruments held by + :class:`twinkle.server.telemetry.metrics.MetricsRegistry`. A separate + adapter instance is cached per deployment so that gauge-state tracking + (last value per attribute set) stays isolated. """ if deployment in _task_metrics_cache: return _task_metrics_cache[deployment] + reg = MetricsRegistry.get() metrics = TaskMetrics( - queue_depth=Gauge( - 'twinkle_task_queue_depth', - description='Current number of queued tasks.', - tag_keys=('deployment', ), - ), - tasks_total=Counter( - 'twinkle_tasks_total', - description='Total task completions.', - tag_keys=('deployment', 'task_type', 'status'), - ), - execution_seconds=Histogram( - 'twinkle_task_execution_seconds', - description='Pure task execution time in seconds.', - boundaries=_HISTOGRAM_BOUNDARIES, - tag_keys=('deployment', 'task_type'), - ), - queue_wait_seconds=Histogram( - 'twinkle_task_queue_wait_seconds', - description='Time from enqueue to execution start in seconds.', - boundaries=_HISTOGRAM_BOUNDARIES, - tag_keys=('deployment', 'task_type'), - ), - rate_limit_rejections=Counter( - 'twinkle_rate_limit_rejections_total', - description='Total rate-limit rejections.', - tag_keys=('deployment', ), - ), - rate_limiter_active_tokens=Gauge( - 'twinkle_rate_limiter_active_tokens', - description='Number of tokens tracked by the rate limiter.', - tag_keys=('deployment', ), - ), + queue_depth=_Gauge(reg.queue_depth), + tasks_total=_Counter(reg.tasks_total), + execution_seconds=_Histogram(reg.task_execution_seconds), + queue_wait_seconds=_Histogram(reg.task_wait_seconds), + rate_limit_rejections=_Counter(reg.rate_limit_rejections), + rate_limiter_active_tokens=_Gauge(reg.rate_limiter_active_tokens), ) _task_metrics_cache[deployment] = metrics return metrics - - -# --------------------------------------------------------------------------- -# D. Resource gauges (ServerState actor, updated every 15 s) -# --------------------------------------------------------------------------- - - -def get_resource_metrics() -> ResourceMetrics: - """Return (or create) global resource gauge metrics. - - Returns a :class:`ResourceMetrics` Pydantic model with: - - - ``active_sessions`` – Gauge - - ``active_models`` – Gauge - - ``active_sampling_sessions`` – Gauge - - ``active_futures`` – Gauge - """ - global _resource_metrics_cache - if _resource_metrics_cache is not None: - return _resource_metrics_cache - - metrics = ResourceMetrics( - active_sessions=Gauge( - 'twinkle_active_sessions', - description='Current active session count.', - ), - active_models=Gauge( - 'twinkle_active_models', - description='Current registered model count.', - ), - active_sampling_sessions=Gauge( - 'twinkle_active_sampling_sessions', - description='Current sampling session count.', - ), - active_futures=Gauge( - 'twinkle_active_futures', - description='Current future/request count.', - ), - ) - _resource_metrics_cache = metrics - return metrics diff --git a/src/twinkle/server/utils/state/base.py b/src/twinkle/server/utils/state/base.py deleted file mode 100644 index c7480ec78..000000000 --- a/src/twinkle/server/utils/state/base.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from __future__ import annotations - -import time -from abc import ABC, abstractmethod -from datetime import datetime -from pydantic import BaseModel -from typing import Generic, TypeVar - -T = TypeVar('T', bound=BaseModel) - - -class BaseManager(ABC, Generic[T]): - """ - Abstract base class for resource managers. - - Provides common CRUD operations and timestamp parsing. - Subclasses must implement `cleanup_expired`. - """ - - def __init__(self, expiration_timeout: float) -> None: - self._store: dict[str, T] = {} - self.expiration_timeout = expiration_timeout - - # ----- CRUD ----- - - def add(self, resource_id: str, record: T) -> None: - """Store a record under the given ID.""" - self._store[resource_id] = record - - def get(self, resource_id: str) -> T | None: - """Return the record for the given ID, or None.""" - return self._store.get(resource_id) - - def remove(self, resource_id: str) -> bool: - """Remove a record by ID. Returns True if it existed.""" - return self._store.pop(resource_id, None) is not None - - def count(self) -> int: - """Return the number of stored records.""" - return len(self._store) - - # ----- Cleanup ----- - - @abstractmethod - def cleanup_expired(self, cutoff_time: float) -> int: - """ - Remove all records older than cutoff_time. - - Args: - cutoff_time: Unix timestamp; records with activity before this are removed. - - Returns: - Number of records removed. - """ - - # ----- Helpers ----- - - def _parse_timestamp(self, timestamp_str: str) -> float: - """Parse an ISO-format timestamp string to a Unix timestamp. - - Falls back to the current time so that unparseable entries are - never accidentally kept alive forever. - """ - try: - return datetime.fromisoformat(timestamp_str).timestamp() - except (ValueError, AttributeError): - return time.time() diff --git a/src/twinkle/server/utils/state/config_manager.py b/src/twinkle/server/utils/state/config_manager.py deleted file mode 100644 index e1aa3bcea..000000000 --- a/src/twinkle/server/utils/state/config_manager.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from __future__ import annotations - -from typing import Any - - -class ConfigManager: - """ - Manages key-value configuration entries. - - Configuration entries have no expiry; they persist until explicitly removed - or cleared. This manager does not inherit from BaseManager because config - values are arbitrary Python objects rather than Pydantic models. - """ - - def __init__(self) -> None: - self._store: dict[str, Any] = {} - - # ----- CRUD ----- - - def add(self, key: str, value: Any) -> None: - """Add or overwrite a configuration value.""" - self._store[key] = value - - def add_or_get(self, key: str, value: Any) -> Any: - """Add a value if the key does not exist; otherwise return the existing value. - - Args: - key: Configuration key. - value: Value to store if the key is absent. - - Returns: - The existing or newly stored value. - """ - if key not in self._store: - self._store[key] = value - return self._store[key] - - def get(self, key: str) -> Any | None: - """Return the configuration value for key, or None.""" - return self._store.get(key) - - def pop(self, key: str) -> Any | None: - """Remove and return the configuration value for key, or None.""" - return self._store.pop(key, None) - - def clear(self) -> None: - """Remove all configuration entries.""" - self._store.clear() - - def count(self) -> int: - """Return the number of stored configuration entries.""" - return len(self._store) diff --git a/src/twinkle/server/utils/state/model_manager.py b/src/twinkle/server/utils/state/model_manager.py deleted file mode 100644 index 2d5345f7a..000000000 --- a/src/twinkle/server/utils/state/model_manager.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from __future__ import annotations - -from .base import BaseManager -from .models import ModelRecord - - -class ModelManager(BaseManager[ModelRecord]): - """ - Manages registered models. - - Expiry is based on `created_at`. A model is also considered expired if - its owning session has already been removed (cascade expiry). - - Enforces a per-token model limit across all model instances (server-global). - - Also tracks replica registrations so the router can query which replicas - still have capacity (i.e. their loaded-model count < max_loras). - """ - - def __init__(self, expiration_timeout: float, per_token_model_limit: int = 30) -> None: - super().__init__(expiration_timeout) - self._per_token_model_limit = per_token_model_limit - # token -> set of model_ids owned by that token - self._token_models: dict[str, set[str]] = {} - # replica_id -> set of model_ids currently loaded on that replica - self._replica_models: dict[str, set[str]] = {} - # replica_id -> max_loras limit declared at registration time - self._replica_max_loras: dict[str, int] = {} - - def get_capacity_info(self) -> dict[str, int]: - """Return global LoRA capacity across all registered replicas. - - Returns: - Dict containing 'max_loras', 'used_loras', and 'free_loras'. - """ - total_max_loras = sum(self._replica_max_loras.values()) - total_used_loras = sum(len(self._replica_models.get(rid, set())) for rid in self._replica_max_loras.keys()) - return { - 'max_loras': total_max_loras, - 'used_loras': total_used_loras, - 'free_loras': max(0, total_max_loras - total_used_loras), - } - - # ----- Replica Registration ----- - - def register_replica(self, replica_id: str, max_loras: int) -> None: - """Register a replica and its LoRA capacity. - - Args: - replica_id: Unique identifier for the replica. - max_loras: Maximum number of LoRA adapters the replica can hold. - """ - self._replica_max_loras[replica_id] = max_loras - self._replica_models.setdefault(replica_id, set()) - - def unregister_replica(self, replica_id: str) -> None: - """Remove a replica from the registry. - - Any model associations for this replica are also cleared. - - Args: - replica_id: Unique identifier for the replica to remove. - """ - self._replica_max_loras.pop(replica_id, None) - self._replica_models.pop(replica_id, None) - - def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: - """Return the subset of candidate replica IDs that still have capacity. - - A replica has capacity when its current loaded-model count is strictly - less than its declared ``max_loras``. Replicas that are not registered - (unknown to this manager) are included as-is (conservative fallback). - - Args: - candidate_ids: Replica IDs to evaluate. - - Returns: - Filtered list preserving the original order. - """ - available = [] - for rid in candidate_ids: - max_loras = self._replica_max_loras.get(rid) - if max_loras is None: - # Unknown replica – include conservatively - available.append(rid) - continue - current = len(self._replica_models.get(rid, set())) - if current < max_loras: - available.append(rid) - return available - - # ----- CRUD ----- - - def add(self, model_id: str, record: ModelRecord) -> None: - """Store a record under the given ID. - - Args: - model_id: Unique identifier for the model. - record: ModelRecord to store. - - Raises: - RuntimeError: If the token has reached per_token_model_limit. - """ - token = record.token - current_ids = self._token_models.get(token, set()) - if len(current_ids) >= self._per_token_model_limit: - raise RuntimeError(f'Model limit exceeded: ' - f'{len(current_ids)}/{self._per_token_model_limit} models') - self._token_models.setdefault(token, set()).add(model_id) - if record.replica_id is not None: - self._replica_models.setdefault(record.replica_id, set()).add(model_id) - self._store[model_id] = record - - def remove(self, model_id: str) -> bool: - """Remove a record by ID and clean up token and replica ownership. - - Returns: - True if the record existed and was removed, False otherwise. - """ - record = self._store.pop(model_id, None) - if record is None: - return False - self._cleanup_ownership(model_id, record) - return True - - # ----- Cleanup ----- - - def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None) -> int: - """Remove models that are older than cutoff_time, or whose owning - session has already been expired. - - Args: - cutoff_time: Unix timestamp threshold. - expired_session_ids: Optional list of session IDs that have just - been expired; any model belonging to one of these sessions will - also be removed regardless of its own age. - - Returns: - Number of models removed. - """ - session_set = set(expired_session_ids or []) - expired_ids = [] - - for model_id, record in self._store.items(): - # Cascade: owner session was expired - if record.session_id and record.session_id in session_set: - expired_ids.append(model_id) - continue - # Own age - created_at = self._parse_timestamp(record.created_at) - if created_at < cutoff_time: - expired_ids.append(model_id) - - for model_id in expired_ids: - record = self._store.pop(model_id) - self._cleanup_ownership(model_id, record) - - return len(expired_ids) - - # ----- Internal helpers ----- - - def _cleanup_ownership(self, model_id: str, record: ModelRecord) -> None: - """Remove token and replica ownership entries for a model record. - - Args: - model_id: The model ID being removed. - record: The associated ModelRecord. - """ - token = record.token - if token and token in self._token_models: - self._token_models[token].discard(model_id) - if not self._token_models[token]: - del self._token_models[token] - if record.replica_id and record.replica_id in self._replica_models: - self._replica_models[record.replica_id].discard(model_id) diff --git a/src/twinkle/server/utils/task_queue/config.py b/src/twinkle/server/utils/task_queue/config.py index a8b6437b6..57f24ba67 100644 --- a/src/twinkle/server/utils/task_queue/config.py +++ b/src/twinkle/server/utils/task_queue/config.py @@ -2,78 +2,49 @@ """ Task queue configuration. -Provides TaskQueueConfig for controlling rate limits, timeouts, -and queue behavior. +Provides TaskQueueConfig (Pydantic) for controlling rate limits, timeouts, +and queue behavior. Constraints are validated at construction time so an +invalid YAML/dict value is rejected before the deployment reaches a ready +state. """ from __future__ import annotations -from dataclasses import dataclass +from pydantic import BaseModel, ConfigDict, Field from typing import Any -@dataclass -class TaskQueueConfig: +class TaskQueueConfig(BaseModel): """Configuration for task queue and rate limiting. Attributes: - rps_limit: Maximum requests per second per user token. - tps_limit: Maximum input tokens per second per user token. - window_seconds: Time window for rate limiting calculations. + rps_limit: Maximum requests per second per user token. ``0`` disables. + tps_limit: Maximum input tokens per second per user token. ``0`` disables. + window_seconds: Sliding window for rate-limit calculations. Must be > 0. queue_timeout: Maximum time a task can wait in queue (seconds). execution_timeout: Maximum time a task can execute (seconds). 0 means no limit. enabled: Whether rate limiting is enabled. token_cleanup_multiplier: Multiplier for token cleanup threshold. token_cleanup_interval: How often to run cleanup task (seconds). - max_input_tokens: Maximum allowed input tokens per request (default 16000). + max_input_tokens: Maximum allowed input tokens per request. """ - rps_limit: float = 100.0 # 100 requests per second - tps_limit: float = 16000.0 # 16000 input tokens per second - window_seconds: float = 1.0 # 1 second sliding window - queue_timeout: float = 300.0 # 5 minutes queue timeout - execution_timeout: float = 120.0 # 120 seconds execution timeout (0 to disable) - enabled: bool = True # Rate limiting enabled by default - # Remove tokens after 10x window inactivity - token_cleanup_multiplier: float = 10.0 - token_cleanup_interval: float = 60.0 # Run cleanup every 60 seconds - max_input_tokens: int = 16000 # Maximum input tokens per request + + model_config = ConfigDict(extra='forbid') + + rps_limit: float = Field(default=100.0, ge=0) + tps_limit: float = Field(default=16000.0, ge=0) + window_seconds: float = Field(default=1.0, gt=0) + queue_timeout: float = Field(default=300.0, ge=0) + execution_timeout: float = Field(default=120.0, ge=0) + enabled: bool = True + token_cleanup_multiplier: float = Field(default=10.0, ge=0) + token_cleanup_interval: float = Field(default=60.0, ge=0) + max_input_tokens: int = Field(default=16000, ge=1) @classmethod def from_dict(cls, config_dict: dict[str, Any] | None = None) -> TaskQueueConfig: - """Create TaskQueueConfig from a dictionary. - - Args: - config_dict: Dictionary with configuration values. Supports keys: - - rps_limit: requests per second limit - - tps_limit: input tokens per second limit - - window_seconds: sliding window duration - - queue_timeout: queue timeout in seconds - - execution_timeout: task execution timeout in seconds (0 to disable) - - enabled: whether rate limiting is enabled - - token_cleanup_multiplier: multiplier for token cleanup threshold - - token_cleanup_interval: cleanup task interval in seconds - - max_input_tokens: maximum input tokens per request + """Validate ``config_dict`` (or ``{}``) into a ``TaskQueueConfig``. - Returns: - TaskQueueConfig instance with values from dict merged with defaults. + Equivalent to ``cls.model_validate(config_dict or {})`` — kept for + call-site compatibility and to make the entry point explicit. """ - config = cls() - if config_dict: - if 'rps_limit' in config_dict: - config.rps_limit = float(config_dict['rps_limit']) - if 'tps_limit' in config_dict: - config.tps_limit = float(config_dict['tps_limit']) - if 'window_seconds' in config_dict: - config.window_seconds = float(config_dict['window_seconds']) - if 'queue_timeout' in config_dict: - config.queue_timeout = float(config_dict['queue_timeout']) - if 'execution_timeout' in config_dict: - config.execution_timeout = float(config_dict['execution_timeout']) - if 'enabled' in config_dict: - config.enabled = bool(config_dict['enabled']) - if 'token_cleanup_multiplier' in config_dict: - config.token_cleanup_multiplier = float(config_dict['token_cleanup_multiplier']) - if 'token_cleanup_interval' in config_dict: - config.token_cleanup_interval = float(config_dict['token_cleanup_interval']) - if 'max_input_tokens' in config_dict: - config.max_input_tokens = int(config_dict['max_input_tokens']) - return config + return cls.model_validate(config_dict or {}) diff --git a/src/twinkle/server/utils/task_queue/mixin.py b/src/twinkle/server/utils/task_queue/mixin.py index a5ecbc7e8..1bec0843b 100644 --- a/src/twinkle/server/utils/task_queue/mixin.py +++ b/src/twinkle/server/utils/task_queue/mixin.py @@ -22,7 +22,7 @@ from .worker import ComputeWorker if TYPE_CHECKING: - from twinkle.server.utils.state import ServerStateProxy + from twinkle.server.state import ServerState logger = get_logger() @@ -43,11 +43,11 @@ class TaskQueueMixin: Requirements ------------ - Inheriting class must expose self.state: ServerStateProxy and call + Inheriting class must expose self.state: ServerState and call _init_task_queue() during __init__. """ - state: ServerStateProxy + state: ServerState def _init_task_queue(self, config: TaskQueueConfig | None = None, deployment_name: str = '') -> None: """Initialise the task queue, rate limiter, and compute worker.""" diff --git a/src/twinkle/server/utils/task_queue/rate_limiter.py b/src/twinkle/server/utils/task_queue/rate_limiter.py index 229428801..d3370371a 100644 --- a/src/twinkle/server/utils/task_queue/rate_limiter.py +++ b/src/twinkle/server/utils/task_queue/rate_limiter.py @@ -55,7 +55,8 @@ def __init__( will be removed. Default is 10.0 (10x the window). token_cleanup_interval: How often to run the cleanup task in seconds. Default is 60.0 (every minute). - active_tokens_gauge: Optional ray.util.metrics Gauge for tracking active token count. + active_tokens_gauge: Optional gauge adapter (see twinkle.server.utils.metrics) + for tracking the active token count. deployment_name: Deployment name for metrics labels. """ self.rps_limit = rps_limit diff --git a/src/twinkle/server/utils/task_queue/worker.py b/src/twinkle/server/utils/task_queue/worker.py index 77740cb72..a546172c1 100644 --- a/src/twinkle/server/utils/task_queue/worker.py +++ b/src/twinkle/server/utils/task_queue/worker.py @@ -14,13 +14,15 @@ from collections import deque from typing import TYPE_CHECKING, Any, Deque +from twinkle.server.telemetry.correlation import MODEL_ID, TOKEN_ID +from twinkle.server.telemetry.tracing import traced_operation from twinkle.utils.logger import get_logger from .config import TaskQueueConfig from .types import QueuedTask, QueueState, TaskStatus if TYPE_CHECKING: + from twinkle.server.state import ServerState from twinkle.server.utils.metrics import TaskMetrics - from twinkle.server.utils.state import ServerStateProxy logger = get_logger() @@ -39,7 +41,7 @@ class ComputeWorker: def __init__( self, - state: ServerStateProxy, + state: ServerState, config: TaskQueueConfig, task_metrics: TaskMetrics | None, deployment_name: str, @@ -198,14 +200,30 @@ async def _execute_task(self, task: QueuedTask, queue_key: str, q: asyncio.Queue exec_start = time.monotonic() task_status = 'completed' exec_time = 0.0 + # R10.2: one span per queued task execution; R10.3: a nested span tagged + # `.` (e.g. `model.forward`, `sampler.sample`) so + # the handler's primary op has its own span carrying token/model + # correlation. The nested span is started inside the try/except so + # span lifecycle stays correctly scoped on timeout / exceptions. + queue_attrs = { + TOKEN_ID: task.token, + MODEL_ID: task.model_id, + 'twinkle.task_type': task_type, + 'twinkle.deployment': self._deployment_name or 'unknown', + 'twinkle.queue_key': queue_key, + } + handler_attrs = {TOKEN_ID: task.token, MODEL_ID: task.model_id} + handler_span_name = f'{self._deployment_name or "deployment"}.{task_type}' try: - coro = task.coro_factory() - logger.debug(f'[ComputeWorker] Task {task.request_id} executing, ' - f'type={task_type}, queue_key={queue_key}') - if self._config.execution_timeout > 0: - result = await asyncio.wait_for(coro, timeout=self._config.execution_timeout) - else: - result = await coro + with traced_operation('task_queue.execute', attrs=queue_attrs): + logger.debug(f'[ComputeWorker] Task {task.request_id} executing, ' + f'type={task_type}, queue_key={queue_key}') + with traced_operation(handler_span_name, attrs=handler_attrs): + coro = task.coro_factory() + if self._config.execution_timeout > 0: + result = await asyncio.wait_for(coro, timeout=self._config.execution_timeout) + else: + result = await coro exec_time = time.monotonic() - exec_start logger.info(f'[ComputeWorker] Task {task.request_id} completed in {exec_time:.2f}s, type={task_type}') await self._state.store_future_status( diff --git a/tests/contract/__init__.py b/tests/contract/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/contract/client_api_baseline.json b/tests/contract/client_api_baseline.json new file mode 100644 index 000000000..54f0ec1bf --- /dev/null +++ b/tests/contract/client_api_baseline.json @@ -0,0 +1,7340 @@ +{ + "gateway": { + "paths": { + "/asample": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SampleRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Asample Asample Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/create_model": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateModelRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Create Model Create Model Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/create_sampling_session": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateSamplingSessionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateSamplingSessionResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/create_session": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__create_session_request__CreateSessionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__create_session_response__CreateSessionResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/forward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Forward Forward Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/forward_backward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardBackwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Forward Backward Forward Backward Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/get_info": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetInfoRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Get Info Get Info Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/get_server_capabilities": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__get_server_capabilities_response__GetServerCapabilitiesResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/healthz": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__health_response__HealthResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/load_weights": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/LoadWeightsRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Load Weights Load Weights Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/optim_step": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OptimStepRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Optim Step Optim Step Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/retrieve_future": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/FutureRetrieveRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Retrieve Future Retrieve Future Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/save_weights": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveWeightsRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Save Weights Save Weights Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/save_weights_for_sampler": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveWeightsForSamplerRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Save Weights For Sampler Save Weights For Sampler Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/session_heartbeat": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__session_heartbeat_request__SessionHeartbeatRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__session_heartbeat_response__SessionHeartbeatResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/telemetry": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TelemetrySendRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TelemetryResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/training_runs": { + "GET": { + "parameters": [ + { + "in": "query", + "name": "limit", + "required": false, + "schema": { + "default": 20, + "title": "Limit", + "type": "integer" + } + }, + { + "in": "query", + "name": "offset", + "required": false, + "schema": { + "default": 0, + "title": "Offset", + "type": "integer" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__training_runs_response__TrainingRunsResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/training_runs/{run_id}": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__training_run__TrainingRun" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/training_runs/{run_id}/checkpoints": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__checkpoints_list_response__CheckpointsListResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/training_runs/{run_id}/checkpoints/{checkpoint_id}": { + "DELETE": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + }, + { + "in": "path", + "name": "checkpoint_id", + "required": true, + "schema": { + "title": "Checkpoint Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Delete Run Checkpoint Training Runs Run Id Checkpoints Checkpoint Id Delete" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/training_runs/{run_id}/checkpoints/{checkpoint_id}/publish": { + "POST": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + }, + { + "in": "path", + "name": "checkpoint_id", + "required": true, + "schema": { + "title": "Checkpoint Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/capacity_info": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CapacityInfoResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/twinkle/checkpoint_path/{run_id}/{checkpoint_id}": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + }, + { + "in": "path", + "name": "checkpoint_id", + "required": true, + "schema": { + "title": "Checkpoint Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CheckpointPathResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/create_session": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__session__CreateSessionRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__session__CreateSessionResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/get_server_capabilities": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__server__GetServerCapabilitiesResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/twinkle/healthz": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__server__HealthResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/twinkle/session_heartbeat": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__session__SessionHeartbeatRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__session__SessionHeartbeatResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/status": { + "GET": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "additionalProperties": true, + "title": "Response Status Twinkle Status Get", + "type": "object" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/twinkle/training_runs": { + "GET": { + "parameters": [ + { + "in": "query", + "name": "limit", + "required": false, + "schema": { + "default": 20, + "title": "Limit", + "type": "integer" + } + }, + { + "in": "query", + "name": "offset", + "required": false, + "schema": { + "default": 0, + "title": "Offset", + "type": "integer" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__training__TrainingRunsResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/training_runs/{run_id}": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__training__TrainingRun" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/training_runs/{run_id}/checkpoints": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__training__CheckpointsListResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/training_runs/{run_id}/checkpoints/{checkpoint_id}": { + "DELETE": { + "parameters": [ + { + "in": "path", + "name": "run_id", + "required": true, + "schema": { + "title": "Run Id", + "type": "string" + } + }, + { + "in": "path", + "name": "checkpoint_id", + "required": true, + "schema": { + "title": "Checkpoint Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DeleteCheckpointResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/weights_info": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WeightsInfoRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__training__WeightsInfoResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/unload_model": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnloadModelRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Unload Model Unload Model Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/weights_info": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "additionalProperties": true, + "title": "Body", + "type": "object" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__weights_info_response__WeightsInfoResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + } + }, + "schemas": { + "AdamParams": { + "additionalProperties": false, + "properties": { + "beta1": { + "default": 0.9, + "title": "Beta1", + "type": "number" + }, + "beta2": { + "default": 0.95, + "title": "Beta2", + "type": "number" + }, + "eps": { + "default": 1e-12, + "title": "Eps", + "type": "number" + }, + "grad_clip_norm": { + "default": 0.0, + "title": "Grad Clip Norm", + "type": "number" + }, + "learning_rate": { + "default": 0.0001, + "title": "Learning Rate", + "type": "number" + }, + "weight_decay": { + "default": 0.0, + "title": "Weight Decay", + "type": "number" + } + }, + "title": "AdamParams", + "type": "object" + }, + "CapacityInfoResponse": { + "description": "Response body for the /capacity_info endpoint.", + "properties": { + "free_loras": { + "title": "Free Loras", + "type": "integer" + }, + "max_loras": { + "title": "Max Loras", + "type": "integer" + }, + "used_loras": { + "title": "Used Loras", + "type": "integer" + } + }, + "required": [ + "max_loras", + "used_loras", + "free_loras" + ], + "title": "CapacityInfoResponse", + "type": "object" + }, + "CheckpointPathResponse": { + "description": "Response body for the /checkpoint_path endpoint.", + "properties": { + "path": { + "title": "Path", + "type": "string" + }, + "twinkle_path": { + "title": "Twinkle Path", + "type": "string" + } + }, + "required": [ + "path", + "twinkle_path" + ], + "title": "CheckpointPathResponse", + "type": "object" + }, + "CreateModelRequest": { + "additionalProperties": false, + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "lora_config": { + "anyOf": [ + { + "$ref": "#/components/schemas/LoraConfig" + }, + { + "type": "null" + } + ] + }, + "model_seq_id": { + "title": "Model Seq Id", + "type": "integer" + }, + "session_id": { + "title": "Session Id", + "type": "string" + }, + "type": { + "const": "create_model", + "default": "create_model", + "title": "Type", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "session_id", + "model_seq_id", + "base_model" + ], + "title": "CreateModelRequest", + "type": "object" + }, + "CreateSamplingSessionRequest": { + "additionalProperties": false, + "properties": { + "base_model": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Base Model" + }, + "model_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Path" + }, + "sampling_session_seq_id": { + "title": "Sampling Session Seq Id", + "type": "integer" + }, + "session_id": { + "title": "Session Id", + "type": "string" + }, + "type": { + "const": "create_sampling_session", + "default": "create_sampling_session", + "title": "Type", + "type": "string" + } + }, + "required": [ + "session_id", + "sampling_session_seq_id" + ], + "title": "CreateSamplingSessionRequest", + "type": "object" + }, + "CreateSamplingSessionResponse": { + "properties": { + "sampling_session_id": { + "title": "Sampling Session Id", + "type": "string" + }, + "type": { + "const": "create_sampling_session", + "default": "create_sampling_session", + "title": "Type", + "type": "string" + } + }, + "required": [ + "sampling_session_id" + ], + "title": "CreateSamplingSessionResponse", + "type": "object" + }, + "Datum": { + "additionalProperties": false, + "properties": { + "loss_fn_inputs": { + "additionalProperties": { + "$ref": "#/components/schemas/TensorData" + }, + "title": "Loss Fn Inputs", + "type": "object" + }, + "model_input": { + "$ref": "#/components/schemas/ModelInput" + } + }, + "required": [ + "loss_fn_inputs", + "model_input" + ], + "title": "Datum", + "type": "object" + }, + "DeleteCheckpointResponse": { + "properties": { + "message": { + "title": "Message", + "type": "string" + }, + "success": { + "title": "Success", + "type": "boolean" + } + }, + "required": [ + "success", + "message" + ], + "title": "DeleteCheckpointResponse", + "type": "object" + }, + "EncodedTextChunk": { + "additionalProperties": false, + "properties": { + "tokens": { + "items": { + "type": "integer" + }, + "title": "Tokens", + "type": "array" + }, + "type": { + "const": "encoded_text", + "default": "encoded_text", + "title": "Type", + "type": "string" + } + }, + "required": [ + "tokens" + ], + "title": "EncodedTextChunk", + "type": "object" + }, + "ForwardBackwardInput": { + "additionalProperties": false, + "properties": { + "data": { + "items": { + "$ref": "#/components/schemas/Datum" + }, + "title": "Data", + "type": "array" + }, + "loss_fn": { + "enum": [ + "cross_entropy", + "importance_sampling", + "ppo", + "cispo", + "dro" + ], + "title": "Loss Fn", + "type": "string" + }, + "loss_fn_config": { + "anyOf": [ + { + "additionalProperties": { + "type": "number" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Loss Fn Config" + } + }, + "required": [ + "data", + "loss_fn" + ], + "title": "ForwardBackwardInput", + "type": "object" + }, + "ForwardBackwardRequest": { + "additionalProperties": false, + "properties": { + "forward_backward_input": { + "$ref": "#/components/schemas/ForwardBackwardInput" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + } + }, + "required": [ + "forward_backward_input", + "model_id" + ], + "title": "ForwardBackwardRequest", + "type": "object" + }, + "ForwardRequest": { + "additionalProperties": false, + "properties": { + "forward_input": { + "$ref": "#/components/schemas/ForwardBackwardInput" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + } + }, + "required": [ + "forward_input", + "model_id" + ], + "title": "ForwardRequest", + "type": "object" + }, + "FutureRetrieveRequest": { + "additionalProperties": false, + "properties": { + "allow_metadata_only": { + "default": false, + "title": "Allow Metadata Only", + "type": "boolean" + }, + "request_id": { + "title": "Request Id", + "type": "string" + } + }, + "required": [ + "request_id" + ], + "title": "FutureRetrieveRequest", + "type": "object" + }, + "GenericEvent": { + "properties": { + "event": { + "enum": [ + "SESSION_START", + "SESSION_END", + "UNHANDLED_EXCEPTION", + "GENERIC_EVENT" + ], + "title": "Event", + "type": "string" + }, + "event_data": { + "additionalProperties": true, + "default": {}, + "title": "Event Data", + "type": "object" + }, + "event_id": { + "title": "Event Id", + "type": "string" + }, + "event_name": { + "title": "Event Name", + "type": "string" + }, + "event_session_index": { + "title": "Event Session Index", + "type": "integer" + }, + "severity": { + "enum": [ + "DEBUG", + "INFO", + "WARNING", + "ERROR", + "CRITICAL" + ], + "title": "Severity", + "type": "string" + }, + "timestamp": { + "format": "date-time", + "title": "Timestamp", + "type": "string" + } + }, + "required": [ + "event", + "event_id", + "event_name", + "event_session_index", + "severity", + "timestamp" + ], + "title": "GenericEvent", + "type": "object" + }, + "GetInfoRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "type": { + "const": "get_info", + "default": "get_info", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "GetInfoRequest", + "type": "object" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "title": "Detail", + "type": "array" + } + }, + "title": "HTTPValidationError", + "type": "object" + }, + "ImageAssetPointerChunk": { + "additionalProperties": false, + "properties": { + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "location": { + "title": "Location", + "type": "string" + }, + "type": { + "const": "image_asset_pointer", + "default": "image_asset_pointer", + "title": "Type", + "type": "string" + } + }, + "required": [ + "format", + "location" + ], + "title": "ImageAssetPointerChunk", + "type": "object" + }, + "ImageChunk": { + "additionalProperties": false, + "properties": { + "data": { + "contentMediaType": "application/octet-stream", + "title": "Data", + "type": "string" + }, + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "type": { + "const": "image", + "default": "image", + "title": "Type", + "type": "string" + } + }, + "required": [ + "data", + "format" + ], + "title": "ImageChunk", + "type": "object" + }, + "LoadWeightsRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "optimizer": { + "title": "Optimizer", + "type": "boolean" + }, + "path": { + "title": "Path", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "type": { + "const": "load_weights", + "default": "load_weights", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id", + "path", + "optimizer" + ], + "title": "LoadWeightsRequest", + "type": "object" + }, + "LoraConfig": { + "additionalProperties": false, + "properties": { + "rank": { + "title": "Rank", + "type": "integer" + }, + "seed": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seed" + }, + "train_attn": { + "default": true, + "title": "Train Attn", + "type": "boolean" + }, + "train_mlp": { + "default": true, + "title": "Train Mlp", + "type": "boolean" + }, + "train_unembed": { + "default": true, + "title": "Train Unembed", + "type": "boolean" + } + }, + "required": [ + "rank" + ], + "title": "LoraConfig", + "type": "object" + }, + "ModelInput": { + "additionalProperties": false, + "properties": { + "chunks": { + "items": { + "anyOf": [ + { + "$ref": "#/components/schemas/EncodedTextChunk" + }, + { + "$ref": "#/components/schemas/ImageAssetPointerChunk" + }, + { + "$ref": "#/components/schemas/ImageChunk" + } + ] + }, + "title": "Chunks", + "type": "array" + } + }, + "required": [ + "chunks" + ], + "title": "ModelInput", + "type": "object" + }, + "OptimStepRequest": { + "additionalProperties": false, + "properties": { + "adam_params": { + "$ref": "#/components/schemas/AdamParams" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "type": { + "const": "optim_step", + "default": "optim_step", + "title": "Type", + "type": "string" + } + }, + "required": [ + "adam_params", + "model_id" + ], + "title": "OptimStepRequest", + "type": "object" + }, + "SampleRequest": { + "additionalProperties": false, + "properties": { + "base_model": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Base Model" + }, + "model_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Path" + }, + "num_samples": { + "default": 1, + "title": "Num Samples", + "type": "integer" + }, + "prompt": { + "$ref": "#/components/schemas/ModelInput" + }, + "prompt_logprobs": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Prompt Logprobs" + }, + "sampling_params": { + "$ref": "#/components/schemas/SamplingParams" + }, + "sampling_session_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Sampling Session Id" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "topk_prompt_logprobs": { + "default": 0, + "title": "Topk Prompt Logprobs", + "type": "integer" + }, + "type": { + "const": "sample", + "default": "sample", + "title": "Type", + "type": "string" + } + }, + "required": [ + "prompt", + "sampling_params" + ], + "title": "SampleRequest", + "type": "object" + }, + "SamplingParams": { + "properties": { + "max_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Max Tokens" + }, + "seed": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seed" + }, + "stop": { + "anyOf": [ + { + "type": "string" + }, + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Stop" + }, + "temperature": { + "default": 1, + "title": "Temperature", + "type": "number" + }, + "top_k": { + "default": -1, + "title": "Top K", + "type": "integer" + }, + "top_p": { + "default": 1, + "title": "Top P", + "type": "number" + } + }, + "title": "SamplingParams", + "type": "object" + }, + "SaveWeightsForSamplerRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Path" + }, + "sampling_session_seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Sampling Session Seq Id" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "ttl_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Ttl Seconds" + }, + "type": { + "const": "save_weights_for_sampler", + "default": "save_weights_for_sampler", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "SaveWeightsForSamplerRequest", + "type": "object" + }, + "SaveWeightsRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Path" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "ttl_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Ttl Seconds" + }, + "type": { + "const": "save_weights", + "default": "save_weights", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "SaveWeightsRequest", + "type": "object" + }, + "SessionEndEvent": { + "properties": { + "duration": { + "title": "Duration", + "type": "string" + }, + "event": { + "enum": [ + "SESSION_START", + "SESSION_END", + "UNHANDLED_EXCEPTION", + "GENERIC_EVENT" + ], + "title": "Event", + "type": "string" + }, + "event_id": { + "title": "Event Id", + "type": "string" + }, + "event_session_index": { + "title": "Event Session Index", + "type": "integer" + }, + "severity": { + "enum": [ + "DEBUG", + "INFO", + "WARNING", + "ERROR", + "CRITICAL" + ], + "title": "Severity", + "type": "string" + }, + "timestamp": { + "format": "date-time", + "title": "Timestamp", + "type": "string" + } + }, + "required": [ + "duration", + "event", + "event_id", + "event_session_index", + "severity", + "timestamp" + ], + "title": "SessionEndEvent", + "type": "object" + }, + "SessionStartEvent": { + "properties": { + "event": { + "enum": [ + "SESSION_START", + "SESSION_END", + "UNHANDLED_EXCEPTION", + "GENERIC_EVENT" + ], + "title": "Event", + "type": "string" + }, + "event_id": { + "title": "Event Id", + "type": "string" + }, + "event_session_index": { + "title": "Event Session Index", + "type": "integer" + }, + "severity": { + "enum": [ + "DEBUG", + "INFO", + "WARNING", + "ERROR", + "CRITICAL" + ], + "title": "Severity", + "type": "string" + }, + "timestamp": { + "format": "date-time", + "title": "Timestamp", + "type": "string" + } + }, + "required": [ + "event", + "event_id", + "event_session_index", + "severity", + "timestamp" + ], + "title": "SessionStartEvent", + "type": "object" + }, + "TelemetryResponse": { + "properties": { + "status": { + "const": "accepted", + "title": "Status", + "type": "string" + } + }, + "required": [ + "status" + ], + "title": "TelemetryResponse", + "type": "object" + }, + "TelemetrySendRequest": { + "additionalProperties": false, + "properties": { + "events": { + "items": { + "anyOf": [ + { + "$ref": "#/components/schemas/SessionStartEvent" + }, + { + "$ref": "#/components/schemas/SessionEndEvent" + }, + { + "$ref": "#/components/schemas/UnhandledExceptionEvent" + }, + { + "$ref": "#/components/schemas/GenericEvent" + } + ] + }, + "title": "Events", + "type": "array" + }, + "platform": { + "title": "Platform", + "type": "string" + }, + "sdk_version": { + "title": "Sdk Version", + "type": "string" + }, + "session_id": { + "title": "Session Id", + "type": "string" + } + }, + "required": [ + "events", + "platform", + "sdk_version", + "session_id" + ], + "title": "TelemetrySendRequest", + "type": "object" + }, + "TensorData": { + "additionalProperties": false, + "properties": { + "data": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "items": { + "type": "number" + }, + "type": "array" + } + ], + "title": "Data" + }, + "dtype": { + "enum": [ + "int64", + "float32" + ], + "title": "Dtype", + "type": "string" + }, + "shape": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Shape" + } + }, + "required": [ + "data", + "dtype" + ], + "title": "TensorData", + "type": "object" + }, + "UnhandledExceptionEvent": { + "properties": { + "error_message": { + "title": "Error Message", + "type": "string" + }, + "error_type": { + "title": "Error Type", + "type": "string" + }, + "event": { + "enum": [ + "SESSION_START", + "SESSION_END", + "UNHANDLED_EXCEPTION", + "GENERIC_EVENT" + ], + "title": "Event", + "type": "string" + }, + "event_id": { + "title": "Event Id", + "type": "string" + }, + "event_session_index": { + "title": "Event Session Index", + "type": "integer" + }, + "severity": { + "enum": [ + "DEBUG", + "INFO", + "WARNING", + "ERROR", + "CRITICAL" + ], + "title": "Severity", + "type": "string" + }, + "timestamp": { + "format": "date-time", + "title": "Timestamp", + "type": "string" + }, + "traceback": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Traceback" + } + }, + "required": [ + "error_message", + "error_type", + "event", + "event_id", + "event_session_index", + "severity", + "timestamp" + ], + "title": "UnhandledExceptionEvent", + "type": "object" + }, + "UnloadModelRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "type": { + "const": "unload_model", + "default": "unload_model", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "UnloadModelRequest", + "type": "object" + }, + "ValidationError": { + "properties": { + "ctx": { + "title": "Context", + "type": "object" + }, + "input": { + "title": "Input" + }, + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "title": "Location", + "type": "array" + }, + "msg": { + "title": "Message", + "type": "string" + }, + "type": { + "title": "Error Type", + "type": "string" + } + }, + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError", + "type": "object" + }, + "WeightsInfoRequest": { + "properties": { + "twinkle_path": { + "title": "Twinkle Path", + "type": "string" + } + }, + "required": [ + "twinkle_path" + ], + "title": "WeightsInfoRequest", + "type": "object" + }, + "tinker__types__checkpoint__Checkpoint": { + "properties": { + "checkpoint_id": { + "title": "Checkpoint Id", + "type": "string" + }, + "checkpoint_type": { + "enum": [ + "training", + "sampler" + ], + "title": "Checkpoint Type", + "type": "string" + }, + "expires_at": { + "anyOf": [ + { + "format": "date-time", + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Expires At" + }, + "public": { + "default": false, + "title": "Public", + "type": "boolean" + }, + "size_bytes": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Size Bytes" + }, + "time": { + "format": "date-time", + "title": "Time", + "type": "string" + }, + "tinker_path": { + "title": "Tinker Path", + "type": "string" + } + }, + "required": [ + "checkpoint_id", + "checkpoint_type", + "time", + "tinker_path" + ], + "title": "Checkpoint", + "type": "object" + }, + "tinker__types__checkpoints_list_response__CheckpointsListResponse": { + "properties": { + "checkpoints": { + "items": { + "$ref": "#/components/schemas/tinker__types__checkpoint__Checkpoint" + }, + "title": "Checkpoints", + "type": "array" + }, + "cursor": { + "anyOf": [ + { + "$ref": "#/components/schemas/tinker__types__cursor__Cursor" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "checkpoints" + ], + "title": "CheckpointsListResponse", + "type": "object" + }, + "tinker__types__create_session_request__CreateSessionRequest": { + "additionalProperties": false, + "properties": { + "sdk_version": { + "title": "Sdk Version", + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "title": "Tags", + "type": "array" + }, + "type": { + "const": "create_session", + "default": "create_session", + "title": "Type", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "tags", + "user_metadata", + "sdk_version" + ], + "title": "CreateSessionRequest", + "type": "object" + }, + "tinker__types__create_session_response__CreateSessionResponse": { + "properties": { + "error_message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error Message" + }, + "info_message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Info Message" + }, + "session_id": { + "title": "Session Id", + "type": "string" + }, + "type": { + "const": "create_session", + "default": "create_session", + "title": "Type", + "type": "string" + }, + "warning_message": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Warning Message" + } + }, + "required": [ + "session_id" + ], + "title": "CreateSessionResponse", + "type": "object" + }, + "tinker__types__cursor__Cursor": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "offset": { + "title": "Offset", + "type": "integer" + }, + "total_count": { + "title": "Total Count", + "type": "integer" + } + }, + "required": [ + "offset", + "limit", + "total_count" + ], + "title": "Cursor", + "type": "object" + }, + "tinker__types__get_server_capabilities_response__GetServerCapabilitiesResponse": { + "description": "Response containing the server's supported models and capabilities.", + "properties": { + "supported_models": { + "items": { + "$ref": "#/components/schemas/tinker__types__get_server_capabilities_response__SupportedModel" + }, + "title": "Supported Models", + "type": "array" + } + }, + "required": [ + "supported_models" + ], + "title": "GetServerCapabilitiesResponse", + "type": "object" + }, + "tinker__types__get_server_capabilities_response__SupportedModel": { + "description": "Information about a model supported by the server.", + "properties": { + "model_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Name" + } + }, + "title": "SupportedModel", + "type": "object" + }, + "tinker__types__health_response__HealthResponse": { + "properties": { + "status": { + "const": "ok", + "title": "Status", + "type": "string" + } + }, + "required": [ + "status" + ], + "title": "HealthResponse", + "type": "object" + }, + "tinker__types__session_heartbeat_request__SessionHeartbeatRequest": { + "additionalProperties": false, + "properties": { + "session_id": { + "title": "Session Id", + "type": "string" + }, + "type": { + "const": "session_heartbeat", + "default": "session_heartbeat", + "title": "Type", + "type": "string" + } + }, + "required": [ + "session_id" + ], + "title": "SessionHeartbeatRequest", + "type": "object" + }, + "tinker__types__session_heartbeat_response__SessionHeartbeatResponse": { + "properties": { + "type": { + "const": "session_heartbeat", + "default": "session_heartbeat", + "title": "Type", + "type": "string" + } + }, + "title": "SessionHeartbeatResponse", + "type": "object" + }, + "tinker__types__training_run__TrainingRun": { + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "corrupted": { + "default": false, + "title": "Corrupted", + "type": "boolean" + }, + "is_lora": { + "title": "Is Lora", + "type": "boolean" + }, + "last_checkpoint": { + "anyOf": [ + { + "$ref": "#/components/schemas/tinker__types__checkpoint__Checkpoint" + }, + { + "type": "null" + } + ] + }, + "last_request_time": { + "format": "date-time", + "title": "Last Request Time", + "type": "string" + }, + "last_sampler_checkpoint": { + "anyOf": [ + { + "$ref": "#/components/schemas/tinker__types__checkpoint__Checkpoint" + }, + { + "type": "null" + } + ] + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "model_owner": { + "title": "Model Owner", + "type": "string" + }, + "training_run_id": { + "title": "Training Run Id", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "training_run_id", + "base_model", + "model_owner", + "is_lora", + "last_request_time" + ], + "title": "TrainingRun", + "type": "object" + }, + "tinker__types__training_runs_response__TrainingRunsResponse": { + "properties": { + "cursor": { + "$ref": "#/components/schemas/tinker__types__cursor__Cursor" + }, + "training_runs": { + "items": { + "$ref": "#/components/schemas/tinker__types__training_run__TrainingRun" + }, + "title": "Training Runs", + "type": "array" + } + }, + "required": [ + "training_runs", + "cursor" + ], + "title": "TrainingRunsResponse", + "type": "object" + }, + "tinker__types__weights_info_response__WeightsInfoResponse": { + "description": "Minimal information for loading public checkpoints.", + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "is_lora": { + "title": "Is Lora", + "type": "boolean" + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "train_attn": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Attn" + }, + "train_mlp": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Mlp" + }, + "train_unembed": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Unembed" + } + }, + "required": [ + "base_model", + "is_lora" + ], + "title": "WeightsInfoResponse", + "type": "object" + }, + "twinkle_client__types__server__GetServerCapabilitiesResponse": { + "description": "Response body for the /get_server_capabilities endpoint.", + "properties": { + "supported_models": { + "items": { + "$ref": "#/components/schemas/twinkle_client__types__server__SupportedModel" + }, + "title": "Supported Models", + "type": "array" + } + }, + "required": [ + "supported_models" + ], + "title": "GetServerCapabilitiesResponse", + "type": "object" + }, + "twinkle_client__types__server__HealthResponse": { + "properties": { + "status": { + "title": "Status", + "type": "string" + } + }, + "required": [ + "status" + ], + "title": "HealthResponse", + "type": "object" + }, + "twinkle_client__types__server__SupportedModel": { + "description": "Information about a supported model.", + "properties": { + "model_name": { + "title": "Model Name", + "type": "string" + } + }, + "required": [ + "model_name" + ], + "title": "SupportedModel", + "type": "object" + }, + "twinkle_client__types__session__CreateSessionRequest": { + "description": "Request body for POST /twinkle/create_session.", + "properties": { + "metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Metadata" + } + }, + "title": "CreateSessionRequest", + "type": "object" + }, + "twinkle_client__types__session__CreateSessionResponse": { + "description": "Response body for POST /twinkle/create_session.", + "properties": { + "session_id": { + "title": "Session Id", + "type": "string" + } + }, + "required": [ + "session_id" + ], + "title": "CreateSessionResponse", + "type": "object" + }, + "twinkle_client__types__session__SessionHeartbeatRequest": { + "description": "Request body for POST /twinkle/session_heartbeat.", + "properties": { + "session_id": { + "title": "Session Id", + "type": "string" + } + }, + "required": [ + "session_id" + ], + "title": "SessionHeartbeatRequest", + "type": "object" + }, + "twinkle_client__types__session__SessionHeartbeatResponse": { + "description": "Response body for POST /twinkle/session_heartbeat.", + "properties": {}, + "title": "SessionHeartbeatResponse", + "type": "object" + }, + "twinkle_client__types__training__Checkpoint": { + "description": "Twinkle checkpoint model.", + "properties": { + "base_model": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Base Model" + }, + "checkpoint_id": { + "title": "Checkpoint Id", + "type": "string" + }, + "checkpoint_type": { + "title": "Checkpoint Type", + "type": "string" + }, + "is_lora": { + "default": false, + "title": "Is Lora", + "type": "boolean" + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "public": { + "default": false, + "title": "Public", + "type": "boolean" + }, + "size_bytes": { + "title": "Size Bytes", + "type": "integer" + }, + "time": { + "format": "date-time", + "title": "Time", + "type": "string" + }, + "train_attn": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Attn" + }, + "train_mlp": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Mlp" + }, + "train_unembed": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Train Unembed" + }, + "twinkle_path": { + "title": "Twinkle Path", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "checkpoint_id", + "checkpoint_type", + "time", + "size_bytes", + "twinkle_path" + ], + "title": "Checkpoint", + "type": "object" + }, + "twinkle_client__types__training__CheckpointsListResponse": { + "properties": { + "checkpoints": { + "items": { + "$ref": "#/components/schemas/twinkle_client__types__training__Checkpoint" + }, + "title": "Checkpoints", + "type": "array" + }, + "cursor": { + "anyOf": [ + { + "$ref": "#/components/schemas/twinkle_client__types__training__Cursor" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "checkpoints" + ], + "title": "CheckpointsListResponse", + "type": "object" + }, + "twinkle_client__types__training__Cursor": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "offset": { + "title": "Offset", + "type": "integer" + }, + "total_count": { + "title": "Total Count", + "type": "integer" + } + }, + "required": [ + "limit", + "offset", + "total_count" + ], + "title": "Cursor", + "type": "object" + }, + "twinkle_client__types__training__TrainingRun": { + "description": "Twinkle training run model.", + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "corrupted": { + "default": false, + "title": "Corrupted", + "type": "boolean" + }, + "is_lora": { + "default": false, + "title": "Is Lora", + "type": "boolean" + }, + "last_checkpoint": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Last Checkpoint" + }, + "last_request_time": { + "anyOf": [ + { + "format": "date-time", + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Last Request Time" + }, + "last_sampler_checkpoint": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Last Sampler Checkpoint" + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "model_owner": { + "title": "Model Owner", + "type": "string" + }, + "save_dir": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Save Dir" + }, + "training_run_id": { + "title": "Training Run Id", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "training_run_id", + "base_model", + "model_owner" + ], + "title": "TrainingRun", + "type": "object" + }, + "twinkle_client__types__training__TrainingRunsResponse": { + "properties": { + "cursor": { + "$ref": "#/components/schemas/twinkle_client__types__training__Cursor" + }, + "training_runs": { + "items": { + "$ref": "#/components/schemas/twinkle_client__types__training__TrainingRun" + }, + "title": "Training Runs", + "type": "array" + } + }, + "required": [ + "training_runs", + "cursor" + ], + "title": "TrainingRunsResponse", + "type": "object" + }, + "twinkle_client__types__training__WeightsInfoResponse": { + "description": "Twinkle weights info response.", + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "is_lora": { + "default": false, + "title": "Is Lora", + "type": "boolean" + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "model_owner": { + "title": "Model Owner", + "type": "string" + }, + "training_run_id": { + "title": "Training Run Id", + "type": "string" + } + }, + "required": [ + "training_run_id", + "base_model", + "model_owner" + ], + "title": "WeightsInfoResponse", + "type": "object" + } + } + }, + "model": { + "paths": { + "/tinker/create_model": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateModelRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/forward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__forward_request__ForwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/forward_backward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardBackwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/get_info": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetInfoRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetInfoResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/load_weights": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/LoadWeightsRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/optim_step": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OptimStepRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/save_weights": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveWeightsRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/save_weights_for_sampler": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveWeightsForSamplerRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/tinker/unload_model": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnloadModelRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/add_adapter_to_model": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddAdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddAdapterResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/add_metric": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddMetricRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/apply_patch": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ApplyPatchRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/backward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/calculate_loss": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CalculateLossResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/calculate_metric": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CalculateMetricRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CalculateMetricResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/clip_grad_and_step": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ClipGradAndStepRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/clip_grad_norm": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ClipGradNormResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/create": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/forward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__model__ForwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/forward_backward": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__model__ForwardRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardBackwardResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/forward_only": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardOnlyRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ForwardResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/get_state_dict": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetStateDictRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetStateDictResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/get_train_configs": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GetTrainConfigsResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/load": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/LoadRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/lr_step": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/resume_from_checkpoint": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ResumeFromCheckpointRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TrainingProgressResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/save": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SaveResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_loss": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetLossRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_lr_scheduler": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetLrSchedulerRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_optimizer": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetOptimizerRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_processor": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetProcessorRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_template": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetTemplateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/step": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/upload_status/{request_id}": { + "GET": { + "parameters": [ + { + "in": "path", + "name": "request_id", + "required": true, + "schema": { + "title": "Request Id", + "type": "string" + } + } + ], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UploadStatusResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/upload_to_hub": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UploadToHubRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UploadToHubResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/zero_grad": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + } + }, + "schemas": { + "AdamParams": { + "additionalProperties": false, + "properties": { + "beta1": { + "default": 0.9, + "title": "Beta1", + "type": "number" + }, + "beta2": { + "default": 0.95, + "title": "Beta2", + "type": "number" + }, + "eps": { + "default": 1e-12, + "title": "Eps", + "type": "number" + }, + "grad_clip_norm": { + "default": 0.0, + "title": "Grad Clip Norm", + "type": "number" + }, + "learning_rate": { + "default": 0.0001, + "title": "Learning Rate", + "type": "number" + }, + "weight_decay": { + "default": 0.0, + "title": "Weight Decay", + "type": "number" + } + }, + "title": "AdamParams", + "type": "object" + }, + "AdapterRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + } + }, + "required": [ + "adapter_name" + ], + "title": "AdapterRequest", + "type": "object" + }, + "AddAdapterRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "config": { + "title": "Config", + "type": "string" + }, + "save_dir": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Save Dir" + } + }, + "required": [ + "adapter_name", + "config" + ], + "title": "AddAdapterRequest", + "type": "object" + }, + "AddAdapterResponse": { + "description": "Response body for the /add_adapter_to_sampler endpoint.", + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "status": { + "default": "ok", + "title": "Status", + "type": "string" + } + }, + "required": [ + "adapter_name" + ], + "title": "AddAdapterResponse", + "type": "object" + }, + "AddMetricRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "is_training": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Training" + }, + "metric_cls": { + "title": "Metric Cls", + "type": "string" + } + }, + "required": [ + "metric_cls", + "adapter_name" + ], + "title": "AddMetricRequest", + "type": "object" + }, + "ApplyPatchRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "patch_cls": { + "title": "Patch Cls", + "type": "string" + } + }, + "required": [ + "patch_cls", + "adapter_name" + ], + "title": "ApplyPatchRequest", + "type": "object" + }, + "CalculateLossResponse": { + "description": "Response for /calculate_loss endpoint (returns float).", + "properties": { + "result": { + "title": "Result", + "type": "number" + } + }, + "required": [ + "result" + ], + "title": "CalculateLossResponse", + "type": "object" + }, + "CalculateMetricRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "is_training": { + "default": true, + "title": "Is Training", + "type": "boolean" + } + }, + "required": [ + "adapter_name" + ], + "title": "CalculateMetricRequest", + "type": "object" + }, + "CalculateMetricResponse": { + "description": "Response for /calculate_metric endpoint (returns Dict).", + "properties": { + "result": { + "additionalProperties": true, + "title": "Result", + "type": "object" + } + }, + "required": [ + "result" + ], + "title": "CalculateMetricResponse", + "type": "object" + }, + "ClipGradAndStepRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "max_grad_norm": { + "default": 1.0, + "title": "Max Grad Norm", + "type": "number" + }, + "norm_type": { + "default": 2, + "title": "Norm Type", + "type": "integer" + } + }, + "required": [ + "adapter_name" + ], + "title": "ClipGradAndStepRequest", + "type": "object" + }, + "ClipGradNormResponse": { + "description": "Response for /clip_grad_norm endpoint (returns float as str).", + "properties": { + "result": { + "title": "Result", + "type": "string" + } + }, + "required": [ + "result" + ], + "title": "ClipGradNormResponse", + "type": "object" + }, + "CreateModelRequest": { + "additionalProperties": false, + "properties": { + "base_model": { + "title": "Base Model", + "type": "string" + }, + "lora_config": { + "anyOf": [ + { + "$ref": "#/components/schemas/LoraConfig" + }, + { + "type": "null" + } + ] + }, + "model_seq_id": { + "title": "Model Seq Id", + "type": "integer" + }, + "session_id": { + "title": "Session Id", + "type": "string" + }, + "type": { + "const": "create_model", + "default": "create_model", + "title": "Type", + "type": "string" + }, + "user_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "User Metadata" + } + }, + "required": [ + "session_id", + "model_seq_id", + "base_model" + ], + "title": "CreateModelRequest", + "type": "object" + }, + "CreateRequest": { + "additionalProperties": true, + "properties": {}, + "title": "CreateRequest", + "type": "object" + }, + "CreateResponse": { + "description": "Response for /create endpoint.", + "properties": { + "status": { + "default": "ok", + "title": "Status", + "type": "string" + } + }, + "title": "CreateResponse", + "type": "object" + }, + "Datum": { + "additionalProperties": false, + "properties": { + "loss_fn_inputs": { + "additionalProperties": { + "$ref": "#/components/schemas/TensorData" + }, + "title": "Loss Fn Inputs", + "type": "object" + }, + "model_input": { + "$ref": "#/components/schemas/ModelInput" + } + }, + "required": [ + "loss_fn_inputs", + "model_input" + ], + "title": "Datum", + "type": "object" + }, + "EncodedTextChunk": { + "additionalProperties": false, + "properties": { + "tokens": { + "items": { + "type": "integer" + }, + "title": "Tokens", + "type": "array" + }, + "type": { + "const": "encoded_text", + "default": "encoded_text", + "title": "Type", + "type": "string" + } + }, + "required": [ + "tokens" + ], + "title": "EncodedTextChunk", + "type": "object" + }, + "ForwardBackwardInput": { + "additionalProperties": false, + "properties": { + "data": { + "items": { + "$ref": "#/components/schemas/Datum" + }, + "title": "Data", + "type": "array" + }, + "loss_fn": { + "enum": [ + "cross_entropy", + "importance_sampling", + "ppo", + "cispo", + "dro" + ], + "title": "Loss Fn", + "type": "string" + }, + "loss_fn_config": { + "anyOf": [ + { + "additionalProperties": { + "type": "number" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Loss Fn Config" + } + }, + "required": [ + "data", + "loss_fn" + ], + "title": "ForwardBackwardInput", + "type": "object" + }, + "ForwardBackwardRequest": { + "additionalProperties": false, + "properties": { + "forward_backward_input": { + "$ref": "#/components/schemas/ForwardBackwardInput" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + } + }, + "required": [ + "forward_backward_input", + "model_id" + ], + "title": "ForwardBackwardRequest", + "type": "object" + }, + "ForwardBackwardResponse": { + "description": "Response for /forward_backward endpoint (returns ModelOutput).", + "properties": { + "result": { + "title": "Result" + } + }, + "required": [ + "result" + ], + "title": "ForwardBackwardResponse", + "type": "object" + }, + "ForwardOnlyRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Adapter Name" + }, + "inputs": { + "title": "Inputs" + } + }, + "required": [ + "inputs" + ], + "title": "ForwardOnlyRequest", + "type": "object" + }, + "ForwardResponse": { + "description": "Response for /forward and /forward_only endpoints (returns ModelOutput).", + "properties": { + "result": { + "title": "Result" + } + }, + "required": [ + "result" + ], + "title": "ForwardResponse", + "type": "object" + }, + "GetInfoRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "type": { + "const": "get_info", + "default": "get_info", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "GetInfoRequest", + "type": "object" + }, + "GetInfoResponse": { + "description": "Response containing information about a training client's model.", + "properties": { + "is_lora": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Lora" + }, + "lora_rank": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Lora Rank" + }, + "model_data": { + "$ref": "#/components/schemas/ModelData" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "model_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Name" + }, + "type": { + "anyOf": [ + { + "const": "get_info", + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Type" + } + }, + "required": [ + "model_data", + "model_id" + ], + "title": "GetInfoResponse", + "type": "object" + }, + "GetStateDictRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + } + }, + "required": [ + "adapter_name" + ], + "title": "GetStateDictRequest", + "type": "object" + }, + "GetStateDictResponse": { + "description": "Response for /get_state_dict endpoint (returns Dict).", + "properties": { + "result": { + "additionalProperties": true, + "title": "Result", + "type": "object" + } + }, + "required": [ + "result" + ], + "title": "GetStateDictResponse", + "type": "object" + }, + "GetTrainConfigsResponse": { + "description": "Response for /get_train_configs endpoint (returns str).", + "properties": { + "result": { + "title": "Result", + "type": "string" + } + }, + "required": [ + "result" + ], + "title": "GetTrainConfigsResponse", + "type": "object" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "title": "Detail", + "type": "array" + } + }, + "title": "HTTPValidationError", + "type": "object" + }, + "ImageAssetPointerChunk": { + "additionalProperties": false, + "properties": { + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "location": { + "title": "Location", + "type": "string" + }, + "type": { + "const": "image_asset_pointer", + "default": "image_asset_pointer", + "title": "Type", + "type": "string" + } + }, + "required": [ + "format", + "location" + ], + "title": "ImageAssetPointerChunk", + "type": "object" + }, + "ImageChunk": { + "additionalProperties": false, + "properties": { + "data": { + "contentMediaType": "application/octet-stream", + "title": "Data", + "type": "string" + }, + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "type": { + "const": "image", + "default": "image", + "title": "Type", + "type": "string" + } + }, + "required": [ + "data", + "format" + ], + "title": "ImageChunk", + "type": "object" + }, + "LoadRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "load_optimizer": { + "default": false, + "title": "Load Optimizer", + "type": "boolean" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "adapter_name", + "name" + ], + "title": "LoadRequest", + "type": "object" + }, + "LoadWeightsRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "optimizer": { + "title": "Optimizer", + "type": "boolean" + }, + "path": { + "title": "Path", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "type": { + "const": "load_weights", + "default": "load_weights", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id", + "path", + "optimizer" + ], + "title": "LoadWeightsRequest", + "type": "object" + }, + "LoraConfig": { + "additionalProperties": false, + "properties": { + "rank": { + "title": "Rank", + "type": "integer" + }, + "seed": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seed" + }, + "train_attn": { + "default": true, + "title": "Train Attn", + "type": "boolean" + }, + "train_mlp": { + "default": true, + "title": "Train Mlp", + "type": "boolean" + }, + "train_unembed": { + "default": true, + "title": "Train Unembed", + "type": "boolean" + } + }, + "required": [ + "rank" + ], + "title": "LoraConfig", + "type": "object" + }, + "ModelData": { + "description": "Metadata about a model's architecture and configuration.", + "properties": { + "arch": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Arch" + }, + "model_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Name" + }, + "tokenizer_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Tokenizer Id" + } + }, + "title": "ModelData", + "type": "object" + }, + "ModelInput": { + "additionalProperties": false, + "properties": { + "chunks": { + "items": { + "anyOf": [ + { + "$ref": "#/components/schemas/EncodedTextChunk" + }, + { + "$ref": "#/components/schemas/ImageAssetPointerChunk" + }, + { + "$ref": "#/components/schemas/ImageChunk" + } + ] + }, + "title": "Chunks", + "type": "array" + } + }, + "required": [ + "chunks" + ], + "title": "ModelInput", + "type": "object" + }, + "OptimStepRequest": { + "additionalProperties": false, + "properties": { + "adam_params": { + "$ref": "#/components/schemas/AdamParams" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "type": { + "const": "optim_step", + "default": "optim_step", + "title": "Type", + "type": "string" + } + }, + "required": [ + "adam_params", + "model_id" + ], + "title": "OptimStepRequest", + "type": "object" + }, + "ResumeFromCheckpointRequest": { + "additionalProperties": true, + "description": "Request for /resume_from_checkpoint endpoint.", + "properties": { + "adapter_name": { + "default": "", + "title": "Adapter Name", + "type": "string" + }, + "name": { + "title": "Name", + "type": "string" + }, + "resume_only_model": { + "default": false, + "title": "Resume Only Model", + "type": "boolean" + } + }, + "required": [ + "name" + ], + "title": "ResumeFromCheckpointRequest", + "type": "object" + }, + "SaveRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "is_sampler": { + "default": false, + "title": "Is Sampler", + "type": "boolean" + }, + "name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Name" + }, + "save_optimizer": { + "default": false, + "title": "Save Optimizer", + "type": "boolean" + } + }, + "required": [ + "adapter_name" + ], + "title": "SaveRequest", + "type": "object" + }, + "SaveResponse": { + "description": "Response for /save endpoint (returns twinkle path + checkpoint dir).", + "properties": { + "checkpoint_dir": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Checkpoint Dir" + }, + "twinkle_path": { + "title": "Twinkle Path", + "type": "string" + } + }, + "required": [ + "twinkle_path" + ], + "title": "SaveResponse", + "type": "object" + }, + "SaveWeightsForSamplerRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Path" + }, + "sampling_session_seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Sampling Session Seq Id" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "ttl_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Ttl Seconds" + }, + "type": { + "const": "save_weights_for_sampler", + "default": "save_weights_for_sampler", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "SaveWeightsForSamplerRequest", + "type": "object" + }, + "SaveWeightsRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Path" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "ttl_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Ttl Seconds" + }, + "type": { + "const": "save_weights", + "default": "save_weights", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "SaveWeightsRequest", + "type": "object" + }, + "SetLossRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "loss_cls": { + "title": "Loss Cls", + "type": "string" + } + }, + "required": [ + "loss_cls", + "adapter_name" + ], + "title": "SetLossRequest", + "type": "object" + }, + "SetLrSchedulerRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "scheduler_cls": { + "title": "Scheduler Cls", + "type": "string" + } + }, + "required": [ + "scheduler_cls", + "adapter_name" + ], + "title": "SetLrSchedulerRequest", + "type": "object" + }, + "SetOptimizerRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "optimizer_cls": { + "title": "Optimizer Cls", + "type": "string" + } + }, + "required": [ + "optimizer_cls", + "adapter_name" + ], + "title": "SetOptimizerRequest", + "type": "object" + }, + "SetProcessorRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "processor_cls": { + "title": "Processor Cls", + "type": "string" + } + }, + "required": [ + "processor_cls", + "adapter_name" + ], + "title": "SetProcessorRequest", + "type": "object" + }, + "SetTemplateRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "template_cls": { + "title": "Template Cls", + "type": "string" + } + }, + "required": [ + "template_cls", + "adapter_name" + ], + "title": "SetTemplateRequest", + "type": "object" + }, + "TensorData": { + "additionalProperties": false, + "properties": { + "data": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "items": { + "type": "number" + }, + "type": "array" + } + ], + "title": "Data" + }, + "dtype": { + "enum": [ + "int64", + "float32" + ], + "title": "Dtype", + "type": "string" + }, + "shape": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Shape" + } + }, + "required": [ + "data", + "dtype" + ], + "title": "TensorData", + "type": "object" + }, + "TrainingProgressResponse": { + "description": "Response for /resume_from_checkpoint endpoint.", + "properties": { + "result": { + "additionalProperties": true, + "title": "Result", + "type": "object" + } + }, + "required": [ + "result" + ], + "title": "TrainingProgressResponse", + "type": "object" + }, + "UnloadModelRequest": { + "additionalProperties": false, + "properties": { + "model_id": { + "title": "Model Id", + "type": "string" + }, + "type": { + "const": "unload_model", + "default": "unload_model", + "title": "Type", + "type": "string" + } + }, + "required": [ + "model_id" + ], + "title": "UnloadModelRequest", + "type": "object" + }, + "UntypedAPIFuture": { + "properties": { + "model_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Id" + }, + "request_id": { + "title": "Request Id", + "type": "string" + } + }, + "required": [ + "request_id" + ], + "title": "UntypedAPIFuture", + "type": "object" + }, + "UploadStatusResponse": { + "description": "Response for /upload_status/{request_id} endpoint.", + "properties": { + "error": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error" + }, + "request_id": { + "title": "Request Id", + "type": "string" + }, + "status": { + "title": "Status", + "type": "string" + } + }, + "required": [ + "request_id", + "status" + ], + "title": "UploadStatusResponse", + "type": "object" + }, + "UploadToHubRequest": { + "additionalProperties": true, + "properties": { + "async_upload": { + "default": false, + "title": "Async Upload", + "type": "boolean" + }, + "checkpoint_dir": { + "anyOf": [ + { + "type": "string" + }, + { + "additionalProperties": true, + "type": "object" + } + ], + "title": "Checkpoint Dir" + }, + "hub_model_id": { + "title": "Hub Model Id", + "type": "string" + }, + "hub_token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Hub Token" + } + }, + "required": [ + "checkpoint_dir", + "hub_model_id" + ], + "title": "UploadToHubRequest", + "type": "object" + }, + "UploadToHubResponse": { + "description": "Response for /upload_to_hub endpoint.", + "properties": { + "request_id": { + "title": "Request Id", + "type": "string" + } + }, + "required": [ + "request_id" + ], + "title": "UploadToHubResponse", + "type": "object" + }, + "ValidationError": { + "properties": { + "ctx": { + "title": "Context", + "type": "object" + }, + "input": { + "title": "Input" + }, + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "title": "Location", + "type": "array" + }, + "msg": { + "title": "Message", + "type": "string" + }, + "type": { + "title": "Error Type", + "type": "string" + } + }, + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError", + "type": "object" + }, + "tinker__types__forward_request__ForwardRequest": { + "additionalProperties": false, + "properties": { + "forward_input": { + "$ref": "#/components/schemas/ForwardBackwardInput" + }, + "model_id": { + "title": "Model Id", + "type": "string" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + } + }, + "required": [ + "forward_input", + "model_id" + ], + "title": "ForwardRequest", + "type": "object" + }, + "twinkle_client__types__model__ForwardRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "inputs": { + "title": "Inputs" + } + }, + "required": [ + "inputs", + "adapter_name" + ], + "title": "ForwardRequest", + "type": "object" + } + } + }, + "processor": { + "paths": { + "/twinkle/call": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProcessorCallRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProcessorCallResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/create": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProcessorCreateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProcessorCreateResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + } + }, + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "title": "Detail", + "type": "array" + } + }, + "title": "HTTPValidationError", + "type": "object" + }, + "ProcessorCallRequest": { + "additionalProperties": true, + "properties": { + "function": { + "title": "Function", + "type": "string" + }, + "processor_id": { + "title": "Processor Id", + "type": "string" + } + }, + "required": [ + "processor_id", + "function" + ], + "title": "ProcessorCallRequest", + "type": "object" + }, + "ProcessorCallResponse": { + "description": "Response body for the /call endpoint.", + "properties": { + "result": { + "title": "Result" + } + }, + "required": [ + "result" + ], + "title": "ProcessorCallResponse", + "type": "object" + }, + "ProcessorCreateRequest": { + "additionalProperties": true, + "properties": { + "class_type": { + "title": "Class Type", + "type": "string" + }, + "processor_type": { + "title": "Processor Type", + "type": "string" + } + }, + "required": [ + "processor_type", + "class_type" + ], + "title": "ProcessorCreateRequest", + "type": "object" + }, + "ProcessorCreateResponse": { + "description": "Response body for the /create endpoint.", + "properties": { + "processor_id": { + "title": "Processor Id", + "type": "string" + } + }, + "required": [ + "processor_id" + ], + "title": "ProcessorCreateResponse", + "type": "object" + }, + "ValidationError": { + "properties": { + "ctx": { + "title": "Context", + "type": "object" + }, + "input": { + "title": "Input" + }, + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "title": "Location", + "type": "array" + }, + "msg": { + "title": "Message", + "type": "string" + }, + "type": { + "title": "Error Type", + "type": "string" + } + }, + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError", + "type": "object" + } + } + }, + "sampler": { + "paths": { + "/tinker/asample": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/tinker__types__sample_request__SampleRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UntypedAPIFuture" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/add_adapter_to_sampler": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddAdapterRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddAdapterResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/apply_patch": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ApplyPatchRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/create": { + "POST": { + "parameters": [], + "requestBody": null, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateResponse" + } + } + }, + "description": "Successful Response" + } + } + } + }, + "/twinkle/sample": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/twinkle_client__types__sampler__SampleRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SampleResponseModelList" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + }, + "/twinkle/set_template": { + "POST": { + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetTemplateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetTemplateResponse" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + } + } + } + }, + "schemas": { + "AddAdapterRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "config": { + "title": "Config", + "type": "string" + }, + "save_dir": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Save Dir" + } + }, + "required": [ + "adapter_name", + "config" + ], + "title": "AddAdapterRequest", + "type": "object" + }, + "AddAdapterResponse": { + "description": "Response body for the /add_adapter_to_sampler endpoint.", + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "status": { + "default": "ok", + "title": "Status", + "type": "string" + } + }, + "required": [ + "adapter_name" + ], + "title": "AddAdapterResponse", + "type": "object" + }, + "ApplyPatchRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "patch_cls": { + "title": "Patch Cls", + "type": "string" + } + }, + "required": [ + "patch_cls", + "adapter_name" + ], + "title": "ApplyPatchRequest", + "type": "object" + }, + "CreateResponse": { + "description": "Response for /create endpoint.", + "properties": { + "status": { + "default": "ok", + "title": "Status", + "type": "string" + } + }, + "title": "CreateResponse", + "type": "object" + }, + "EncodedTextChunk": { + "additionalProperties": false, + "properties": { + "tokens": { + "items": { + "type": "integer" + }, + "title": "Tokens", + "type": "array" + }, + "type": { + "const": "encoded_text", + "default": "encoded_text", + "title": "Type", + "type": "string" + } + }, + "required": [ + "tokens" + ], + "title": "EncodedTextChunk", + "type": "object" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "title": "Detail", + "type": "array" + } + }, + "title": "HTTPValidationError", + "type": "object" + }, + "ImageAssetPointerChunk": { + "additionalProperties": false, + "properties": { + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "location": { + "title": "Location", + "type": "string" + }, + "type": { + "const": "image_asset_pointer", + "default": "image_asset_pointer", + "title": "Type", + "type": "string" + } + }, + "required": [ + "format", + "location" + ], + "title": "ImageAssetPointerChunk", + "type": "object" + }, + "ImageChunk": { + "additionalProperties": false, + "properties": { + "data": { + "contentMediaType": "application/octet-stream", + "title": "Data", + "type": "string" + }, + "expected_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expected Tokens" + }, + "format": { + "enum": [ + "png", + "jpeg" + ], + "title": "Format", + "type": "string" + }, + "type": { + "const": "image", + "default": "image", + "title": "Type", + "type": "string" + } + }, + "required": [ + "data", + "format" + ], + "title": "ImageChunk", + "type": "object" + }, + "ModelInput": { + "additionalProperties": false, + "properties": { + "chunks": { + "items": { + "anyOf": [ + { + "$ref": "#/components/schemas/EncodedTextChunk" + }, + { + "$ref": "#/components/schemas/ImageAssetPointerChunk" + }, + { + "$ref": "#/components/schemas/ImageChunk" + } + ] + }, + "title": "Chunks", + "type": "array" + } + }, + "required": [ + "chunks" + ], + "title": "ModelInput", + "type": "object" + }, + "SampleResponseModel": { + "description": "Mirroring twinkle.data_format.SampleResponse.", + "properties": { + "prompt_logprobs": { + "anyOf": [ + { + "items": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ] + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Prompt Logprobs" + }, + "sequences": { + "description": "List of sampled sequences", + "items": { + "$ref": "#/components/schemas/SampledSequenceModel" + }, + "title": "Sequences", + "type": "array" + }, + "topk_prompt_logprobs": { + "anyOf": [ + { + "items": { + "anyOf": [ + { + "items": { + "maxItems": 2, + "minItems": 2, + "prefixItems": [ + { + "type": "integer" + }, + { + "type": "number" + } + ], + "type": "array" + }, + "type": "array" + }, + { + "type": "null" + } + ] + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Topk Prompt Logprobs" + } + }, + "required": [ + "sequences" + ], + "title": "SampleResponseModel", + "type": "object" + }, + "SampleResponseModelList": { + "description": "Response body for the /sample endpoint", + "properties": { + "samples": { + "description": "List of sample responses", + "items": { + "$ref": "#/components/schemas/SampleResponseModel" + }, + "title": "Samples", + "type": "array" + } + }, + "required": [ + "samples" + ], + "title": "SampleResponseModelList", + "type": "object" + }, + "SampledSequenceModel": { + "description": "A single sampled sequence, mirroring twinkle.data_format.SampledSequence.", + "properties": { + "decoded": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "description": "Decoded text of the sampled sequence", + "title": "Decoded" + }, + "logprobs": { + "anyOf": [ + { + "items": { + "anyOf": [ + { + "items": { + "maxItems": 2, + "minItems": 2, + "prefixItems": [ + { + "type": "integer" + }, + { + "type": "number" + } + ], + "type": "array" + }, + "type": "array" + }, + { + "type": "null" + } + ] + }, + "type": "array" + }, + { + "type": "null" + } + ], + "description": "Per-token log-probabilities", + "title": "Logprobs" + }, + "new_input_feature": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "description": "Updated InputFeature after sampling (input_ids, labels, etc.)", + "title": "New Input Feature" + }, + "stop_reason": { + "description": "Stop reason: 'length' or 'stop'", + "enum": [ + "length", + "stop" + ], + "title": "Stop Reason", + "type": "string" + }, + "tokens": { + "description": "Token IDs of the sampled sequence", + "items": { + "type": "integer" + }, + "title": "Tokens", + "type": "array" + } + }, + "required": [ + "stop_reason", + "tokens" + ], + "title": "SampledSequenceModel", + "type": "object" + }, + "SamplingParams": { + "properties": { + "max_tokens": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Max Tokens" + }, + "seed": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seed" + }, + "stop": { + "anyOf": [ + { + "type": "string" + }, + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Stop" + }, + "temperature": { + "default": 1, + "title": "Temperature", + "type": "number" + }, + "top_k": { + "default": -1, + "title": "Top K", + "type": "integer" + }, + "top_p": { + "default": 1, + "title": "Top P", + "type": "number" + } + }, + "title": "SamplingParams", + "type": "object" + }, + "SetTemplateRequest": { + "additionalProperties": true, + "properties": { + "adapter_name": { + "title": "Adapter Name", + "type": "string" + }, + "template_cls": { + "title": "Template Cls", + "type": "string" + } + }, + "required": [ + "template_cls", + "adapter_name" + ], + "title": "SetTemplateRequest", + "type": "object" + }, + "SetTemplateResponse": { + "description": "Response for /set_template endpoint.", + "properties": { + "status": { + "default": "ok", + "title": "Status", + "type": "string" + } + }, + "title": "SetTemplateResponse", + "type": "object" + }, + "UntypedAPIFuture": { + "properties": { + "model_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Id" + }, + "request_id": { + "title": "Request Id", + "type": "string" + } + }, + "required": [ + "request_id" + ], + "title": "UntypedAPIFuture", + "type": "object" + }, + "ValidationError": { + "properties": { + "ctx": { + "title": "Context", + "type": "object" + }, + "input": { + "title": "Input" + }, + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "title": "Location", + "type": "array" + }, + "msg": { + "title": "Message", + "type": "string" + }, + "type": { + "title": "Error Type", + "type": "string" + } + }, + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError", + "type": "object" + }, + "tinker__types__sample_request__SampleRequest": { + "additionalProperties": false, + "properties": { + "base_model": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Base Model" + }, + "model_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model Path" + }, + "num_samples": { + "default": 1, + "title": "Num Samples", + "type": "integer" + }, + "prompt": { + "$ref": "#/components/schemas/ModelInput" + }, + "prompt_logprobs": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Prompt Logprobs" + }, + "sampling_params": { + "$ref": "#/components/schemas/SamplingParams" + }, + "sampling_session_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Sampling Session Id" + }, + "seq_id": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Seq Id" + }, + "topk_prompt_logprobs": { + "default": 0, + "title": "Topk Prompt Logprobs", + "type": "integer" + }, + "type": { + "const": "sample", + "default": "sample", + "title": "Type", + "type": "string" + } + }, + "required": [ + "prompt", + "sampling_params" + ], + "title": "SampleRequest", + "type": "object" + }, + "twinkle_client__types__sampler__SampleRequest": { + "description": "Request body for the /sample endpoint.", + "properties": { + "adapter_name": { + "default": "", + "description": "Adapter name for LoRA inference", + "title": "Adapter Name", + "type": "string" + }, + "adapter_uri": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "description": "Adapter URI (twinkle:// path or local path) for LoRA inference", + "title": "Adapter Uri" + }, + "inputs": { + "description": "List of Trajectory or InputFeature dicts", + "title": "Inputs" + }, + "sampling_params": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "description": "Sampling parameters (max_tokens, temperature, num_samples, etc.)", + "title": "Sampling Params" + } + }, + "required": [ + "inputs" + ], + "title": "SampleRequest", + "type": "object" + } + } + } +} diff --git a/tests/contract/client_api_harness.py b/tests/contract/client_api_harness.py new file mode 100644 index 000000000..32d10c8b4 --- /dev/null +++ b/tests/contract/client_api_harness.py @@ -0,0 +1,147 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Client-API contract harness. + +Builds the four FastAPI apps used by the Ray Serve deployments (Gateway, Model, +Sampler, Processor) by registering their route-registration helpers against a +fresh FastAPI instance, then extracts the client-facing surface (route paths, +HTTP methods, and request/response schemas) as a stable JSON dict. + +Used to: +- snapshot the current surface into ``client_api_baseline.json`` before the + refactor begins, and +- assert post-refactor equality after each phase (cross-cutting freeze guard + for R20 / R18.1). + +Notes: +- The handler factories accept ``(app, self_fn)``; we pass a no-op ``self_fn`` + because route registration only inspects the app object — the closures are + never invoked here. +- We restrict the surface to the client-facing endpoints. The Tinker-public + surface on the Gateway is at ``/*`` (flat, by design), and the Twinkle + surface is at ``/twinkle/*`` everywhere. Internal ``/tinker/*`` routes + registered on Model and Sampler are also captured because the Gateway proxy + forwards Tinker compute requests to them — their request/response schemas + are part of the externally observed Tinker contract. +""" +from __future__ import annotations + +import json +from fastapi import FastAPI +from fastapi.openapi.utils import get_openapi +from pathlib import Path +from typing import Any, Callable + +# ----- App build helpers --------------------------------------------------- # + + +def _noop_self() -> None: + return None + + +def build_gateway_app() -> FastAPI: + from twinkle.server.gateway.tinker_gateway_handlers import _register_tinker_routes + from twinkle.server.gateway.twinkle_gateway_handlers import _register_twinkle_routes + + app = FastAPI() + _register_tinker_routes(app, _noop_self) + _register_twinkle_routes(app, _noop_self) + return app + + +def build_model_app() -> FastAPI: + from twinkle.server.model.tinker_handlers import _register_tinker_routes + from twinkle.server.model.twinkle_handlers import _register_twinkle_routes + + app = FastAPI() + _register_tinker_routes(app, _noop_self) + _register_twinkle_routes(app, _noop_self) + return app + + +def build_sampler_app() -> FastAPI: + from twinkle.server.sampler.tinker_handlers import _register_tinker_sampler_routes + from twinkle.server.sampler.twinkle_handlers import _register_twinkle_sampler_routes + + app = FastAPI() + _register_tinker_sampler_routes(app, _noop_self) + _register_twinkle_sampler_routes(app, _noop_self) + return app + + +def build_processor_app() -> FastAPI: + from twinkle.server.processor.twinkle_handlers import _register_processor_routes + + app = FastAPI() + _register_processor_routes(app, _noop_self) + return app + + +APP_BUILDERS: dict[str, Callable[[], FastAPI]] = { + 'gateway': build_gateway_app, + 'model': build_model_app, + 'sampler': build_sampler_app, + 'processor': build_processor_app, +} + +# ----- Surface extraction -------------------------------------------------- # + +_HTTP_METHODS = {'GET', 'POST', 'PUT', 'PATCH', 'DELETE'} + + +def _extract_app_surface(app: FastAPI) -> dict[str, Any]: + """Return the OpenAPI ``paths`` and ``components.schemas`` of ``app``. + + The output is a stable, JSON-serializable view restricted to standard HTTP + methods. Per-operation metadata is reduced to fields that affect the + client contract: ``requestBody``, ``responses``, and ``parameters``. + """ + spec = get_openapi( + title='contract', + version='0.0.0', + routes=app.routes, + ) + + paths: dict[str, dict[str, Any]] = {} + for path, ops in (spec.get('paths') or {}).items(): + clean_ops: dict[str, Any] = {} + for method, op in ops.items(): + if method.upper() not in _HTTP_METHODS: + continue + clean_ops[method.upper()] = { + 'parameters': op.get('parameters', []), + 'requestBody': op.get('requestBody'), + 'responses': op.get('responses', {}), + } + if clean_ops: + paths[path] = clean_ops + + components = (spec.get('components') or {}).get('schemas', {}) + return {'paths': paths, 'schemas': components} + + +def extract_full_surface() -> dict[str, Any]: + """Build all four apps and return a per-app contract surface dict.""" + surface: dict[str, Any] = {} + for name, builder in APP_BUILDERS.items(): + app = builder() + surface[name] = _extract_app_surface(app) + return surface + + +# ----- Baseline I/O -------------------------------------------------------- # + +BASELINE_PATH = Path(__file__).parent / 'client_api_baseline.json' + + +def write_baseline(path: Path | None = None) -> Path: + """Snapshot the current client-API surface to ``client_api_baseline.json``.""" + p = Path(path) if path is not None else BASELINE_PATH + surface = extract_full_surface() + p.write_text(json.dumps(surface, indent=2, sort_keys=True) + '\n') + return p + + +def load_baseline(path: Path | None = None) -> dict[str, Any]: + p = Path(path) if path is not None else BASELINE_PATH + return json.loads(p.read_text()) diff --git a/tests/contract/test_client_api_contract.py b/tests/contract/test_client_api_contract.py new file mode 100644 index 000000000..d321596cf --- /dev/null +++ b/tests/contract/test_client_api_contract.py @@ -0,0 +1,88 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Client-facing API contract regression test (R20.3, R20.4, R18.1). + +# Feature: server-config-observability-refactor, Property 28: Client-facing API contract invariance + +The refactor freezes the client-facing HTTP surface: the route paths, HTTP +methods, and request/response schemas of the Tinker (`/*`, `/tinker/*`) and +Twinkle (`/twinkle/*`) endpoints MUST be identical before and after each +phase. This test rebuilds the FastAPI apps from the current sources, extracts +their OpenAPI surface, and asserts equality with the committed baseline at +``tests/contract/client_api_baseline.json``. + +Updating the baseline is intentionally a manual step: run +``python -m tests.contract.update_baseline`` (or call +``client_api_harness.write_baseline()``) only when an API change has been +explicitly approved. +""" +from __future__ import annotations + +import json +import pytest + +from tests.contract.client_api_harness import APP_BUILDERS, BASELINE_PATH, extract_full_surface, load_baseline + + +def test_baseline_file_exists() -> None: + assert BASELINE_PATH.exists(), (f'Baseline {BASELINE_PATH} missing. Generate it with ' + '`python -m tests.contract.update_baseline` after confirming the ' + 'current client-facing surface is correct.') + + +@pytest.mark.parametrize('app_name', sorted(APP_BUILDERS.keys())) +def test_app_surface_matches_baseline(app_name: str) -> None: + """Per-app surface equals the snapshot — narrows failure scope per app.""" + baseline = load_baseline() + current = extract_full_surface() + + expected = baseline.get(app_name) + actual = current.get(app_name) + assert expected is not None, f'baseline is missing app {app_name!r}' + assert actual is not None, f'current surface is missing app {app_name!r}' + + if actual != expected: + diff = _surface_diff(expected, actual) + pytest.fail(f'Client-API surface for {app_name!r} drifted from the baseline.\n' + f'{diff}\n' + f'If the change is intentional, regenerate the baseline with ' + f'`python -m tests.contract.update_baseline`.') + + +def test_full_surface_matches_baseline() -> None: + """Whole-surface equality — the cross-cutting freeze guard.""" + baseline = load_baseline() + current = extract_full_surface() + assert current == baseline, ('Full client-API surface drifted from the baseline. ' + 'See per-app failures for details.') + + +def _surface_diff(expected: dict, actual: dict) -> str: + exp_paths = set((expected.get('paths') or {}).keys()) + act_paths = set((actual.get('paths') or {}).keys()) + added = sorted(act_paths - exp_paths) + removed = sorted(exp_paths - act_paths) + changed = [] + for p in sorted(exp_paths & act_paths): + if expected['paths'][p] != actual['paths'][p]: + exp_methods = set(expected['paths'][p].keys()) + act_methods = set(actual['paths'][p].keys()) + method_added = sorted(act_methods - exp_methods) + method_removed = sorted(exp_methods - act_methods) + method_diff = '' + if method_added or method_removed: + method_diff = f' methods +{method_added} -{method_removed}' + changed.append(f' {p}{method_diff}') + parts = [] + if added: + parts.append(f' added paths: {added}') + if removed: + parts.append(f' removed paths: {removed}') + if changed: + parts.append(' changed paths:\n' + '\n'.join(changed)) + return '\n'.join(parts) if parts else json.dumps( + { + 'expected_keys': sorted(expected.keys()), + 'actual_keys': sorted(actual.keys()) + }, + indent=2, + ) diff --git a/tests/contract/update_baseline.py b/tests/contract/update_baseline.py new file mode 100644 index 000000000..685938a7b --- /dev/null +++ b/tests/contract/update_baseline.py @@ -0,0 +1,22 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Regenerate the client-API contract baseline. + +Run with:: + + python -m tests.contract.update_baseline + +Only invoke after confirming that the current client-facing surface has been +intentionally changed and approved as part of this refactor. +""" +from __future__ import annotations + +from tests.contract.client_api_harness import write_baseline + + +def main() -> None: + path = write_baseline() + print(f'Wrote baseline: {path}') + + +if __name__ == '__main__': + main() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_mock_mode_startup.py b/tests/integration/test_mock_mode_startup.py new file mode 100644 index 000000000..320e95ab2 --- /dev/null +++ b/tests/integration/test_mock_mode_startup.py @@ -0,0 +1,128 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""End-to-end mock-mode startup + determinism integration test (R4). + +Launches the all-mock cookbook config inside the test process via Ray Serve, +then asserts: +- Both Model and Sampler deployments report HEALTHY within 30 seconds (R4.1). +- Repeated calls to the mock model and mock sampler over HTTP return + byte-identical responses for identical input (R4.4, R4.5). +- The launch path imports cleanly even when ``transformers`` / ``vllm`` / + ``megatron`` would not be available — the mock branches don't pull them. + +This test is heavier than the property suite (boots a full Ray Serve +cluster) and is gated behind ``TWINKLE_TEST_INTEGRATION=1`` so plain +``pytest`` runs stay fast. CI / local runs that opt-in pick it up. +""" +from __future__ import annotations + +import httpx +import os +import pytest +import time + +from tests.server.fixtures import MOCK_SERVER_CONFIG +from twinkle.server.config import ServerConfig + +pytestmark = pytest.mark.skipif( + os.environ.get('TWINKLE_TEST_INTEGRATION', '0') != '1', + reason='Set TWINKLE_TEST_INTEGRATION=1 to run the in-process Ray Serve smoke', +) + +READY_BUDGET_SECONDS = 30.0 + + +@pytest.fixture(scope='module') +def ray_cluster(): + """Start a local Ray cluster for the duration of the module.""" + import ray + from ray import serve + + ray.init(num_cpus=4, num_gpus=0, ignore_reinit_error=True, include_dashboard=False) + yield + try: + serve.shutdown() + except Exception: + pass + try: + ray.shutdown() + except Exception: + pass + + +def _wait_until_healthy(serve_module, timeout: float) -> dict: + """Poll ``serve.status()`` until every app is HEALTHY or timeout.""" + deadline = time.monotonic() + timeout + last = {} + while time.monotonic() < deadline: + status = serve_module.status() + last = {name: app.status for name, app in status.applications.items()} + if last and all(s == 'RUNNING' for s in last.values()): + return last + time.sleep(0.5) + return last + + +def _http(url: str, method: str = 'GET', json: dict | None = None) -> httpx.Response: + return httpx.request(method, url, json=json, timeout=10.0) + + +def test_mock_mode_reaches_ready_under_30s_and_is_deterministic(ray_cluster) -> None: + from ray import serve + + from twinkle.server.gateway import build_server_app + from twinkle.server.model import build_model_app + from twinkle.server.sampler import build_sampler_app + + cfg = ServerConfig.from_yaml(MOCK_SERVER_CONFIG) + + # Use a randomized port so concurrent runs / leftover processes don't collide. + port = 18000 + (os.getpid() % 1000) + host = '127.0.0.1' + serve.start(http_options={'host': host, 'port': port}) + + started = time.monotonic() + deploys: list[tuple[str, str]] = [] + builders = { + 'server': build_server_app, + 'model': build_model_app, + 'sampler': build_sampler_app, + } + for app_spec in cfg.applications: + builder = builders[app_spec.import_path] + args = app_spec.args.model_dump(mode='python', exclude_none=True) + if app_spec.import_path == 'server': + args.setdefault('http_options', cfg.http_options.model_dump()) + # Strip ray_actor_options runtime_env to keep the test light. + deploy_options: dict = {} + for raw in app_spec.deployments: + if isinstance(raw, dict): + deploy_options = { + k: v + for k, v in raw.items() if k not in ('name', 'ray_actor_options', 'autoscaling_config') + } + break + bound = builder(deploy_options=deploy_options, **args) + serve.run(bound, name=app_spec.name, route_prefix=app_spec.route_prefix) + deploys.append((app_spec.name, app_spec.route_prefix)) + + statuses = _wait_until_healthy(serve, READY_BUDGET_SECONDS) + elapsed = time.monotonic() - started + assert statuses, 'serve.status() returned no applications' + assert all(s == 'RUNNING' for s in statuses.values()), statuses + assert elapsed < READY_BUDGET_SECONDS, f'startup took {elapsed:.1f}s > {READY_BUDGET_SECONDS}s' + + # ---- Determinism: gateway /healthz must respond 200 ------------------- + base = f'http://{host}:{port}' + r = _http(f'{base}/api/v1/healthz') + assert r.status_code == 200, r.text + + # Mock model + sampler determinism via the gateway's exposed routes. + r1 = _http(f'{base}/api/v1/twinkle/healthz') + r2 = _http(f'{base}/api/v1/twinkle/healthz') + assert r1.status_code == 200 and r2.status_code == 200 + assert r1.text == r2.text, 'twinkle healthz responses differ' + + # The Model + Sampler primary endpoints don't expose a healthz, but Ray + # Serve only marks a deployment RUNNING after its FastAPI app finishes + # startup — so RUNNING ⇒ readiness response would have been 200 had there + # been one. R4.2 is therefore covered by the ``RUNNING`` assertion above. diff --git a/tests/server/__init__.py b/tests/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/cli/__init__.py b/tests/server/cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/cli/test_cli.py b/tests/server/cli/test_cli.py new file mode 100644 index 000000000..8dc83e95d --- /dev/null +++ b/tests/server/cli/test_cli.py @@ -0,0 +1,151 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Phase 3 — typer CLI + config-drift validation tests (R14, R15, R16). + +Properties covered: +- # Feature: server-config-observability-refactor, Property 29: Config-drift detection and first-run storage +""" +from __future__ import annotations + +import json +import pytest +import yaml +from pathlib import Path +from typer.testing import CliRunner +from unittest import mock + +from tests.server.fixtures import MOCK_SERVER_CONFIG +from twinkle.server.cli.app import app, main +from twinkle.server.config import ServerConfig +from twinkle.server.exceptions import ConfigMismatchError +from twinkle.server.state.backend.factory import PersistenceConfig +from twinkle.server.state.backend.memory_backend import MemoryBackend +from twinkle.server.state.config_signature import _SIGNATURE_KEY, compute_signature, validate_against_backend + +EXAMPLE = MOCK_SERVER_CONFIG +MOCK_CFG = MOCK_SERVER_CONFIG + +# ---------- 9.5 CLI subcommand existence + exit codes (R14.3, R14.4) ------ # + + +def test_subcommands_present() -> None: + runner = CliRunner() + out = runner.invoke(app, ['--help']) + assert out.exit_code == 0, out.output + for cmd in ('launch', 'check-config', 'print-config', 'clear'): + assert cmd in out.output + + +def test_check_config_exit_zero_on_valid() -> None: + runner = CliRunner() + res = runner.invoke(app, ['check-config', '--config', str(EXAMPLE)]) + assert res.exit_code == 0 + assert 'ok' in res.output + + +def test_check_config_nonzero_on_missing(tmp_path: Path) -> None: + runner = CliRunner() + p = tmp_path / 'nope.yaml' + res = runner.invoke(app, ['check-config', '--config', str(p)]) + assert res.exit_code != 0 + assert 'not found' in res.output.lower() + + +def test_check_config_nonzero_on_invalid(tmp_path: Path) -> None: + runner = CliRunner() + p = tmp_path / 'bad.yaml' + p.write_text('persistence: {mode: redis}\napplications: []\n') # missing redis_url + res = runner.invoke(app, ['check-config', '--config', str(p)]) + assert res.exit_code != 0 + assert 'invalid configuration' in res.output.lower() or 'redis_url' in res.output + + +# ---------- 9.5 print-config round-trip (R14.5) --------------------------- # + + +def test_print_config_round_trip(tmp_path: Path) -> None: + runner = CliRunner() + res = runner.invoke(app, ['print-config', '--config', str(EXAMPLE), '--format', 'json']) + assert res.exit_code == 0, res.output + payload = json.loads(res.output) + rebuilt = ServerConfig.model_validate(payload) + original = ServerConfig.from_yaml(EXAMPLE) + assert rebuilt == original + + +# ---------- 9.5 env-var override (R14.6) ---------------------------------- # + + +def test_env_var_overrides_when_flag_omitted(monkeypatch) -> None: + runner = CliRunner() + monkeypatch.setenv('TWINKLE_SERVER_CONFIG', str(EXAMPLE)) + res = runner.invoke(app, ['check-config']) + assert res.exit_code == 0 + + +# ---------- 9.5 launch validates drift BEFORE ray.init (R15.1) ------------ # + + +def test_launch_validates_drift_before_ray_init() -> None: + """Order check: ``validate_against_backend`` is called before + ``ServerLauncher`` is even imported (and thus before ray.init).""" + runner = CliRunner() + + def _abort_drift(*args, **kwargs): + raise ConfigMismatchError('drift sentinel') + + with mock.patch( + 'twinkle.server.state.config_signature.validate_against_backend', + side_effect=_abort_drift, + ): + # Should never reach the launcher import — patch it to a sentinel that + # would make the test fail loudly if reached. + with mock.patch('twinkle.server.launcher.ServerLauncher') as launcher_spy: + res = runner.invoke(app, ['launch', '--config', str(MOCK_CFG)]) + assert res.exit_code == 3, res.output + assert 'drift sentinel' in res.output + assert launcher_spy.call_count == 0 + + +# ---------- Property 29: drift detection + first-run storage (R15.2/4) ---- # + + +@pytest.mark.asyncio +async def test_property_29_first_run_stores_signature() -> None: + """First run with no stored signature stores it and returns silently (R15.4).""" + backend = MemoryBackend() + cfg_payload = {'persistence': {'mode': 'memory'}} + pcfg = PersistenceConfig(mode='memory') + # Patch create_backend to return our shared in-process backend so we can + # inspect the stored signature afterwards. + with mock.patch('twinkle.server.state.backend.factory.create_backend', return_value=backend): + await validate_against_backend(pcfg, cfg_payload) + assert await backend.get(_SIGNATURE_KEY) == compute_signature(cfg_payload) + # Second run with same payload is a no-op. + await validate_against_backend(pcfg, cfg_payload) + + +@pytest.mark.asyncio +async def test_property_29_drift_raises_with_diff_and_remediation() -> None: + backend = MemoryBackend() + pcfg = PersistenceConfig(mode='memory') + initial = {'persistence': {'mode': 'memory'}} + later = {'persistence': {'mode': 'file', 'file_path': '/tmp/x.json'}} + + with mock.patch('twinkle.server.state.backend.factory.create_backend', return_value=backend): + await validate_against_backend(pcfg, initial) + with pytest.raises(ConfigMismatchError) as exc: + await validate_against_backend(pcfg, later) + + msg = str(exc.value) + assert 'drifted' in msg.lower() or 'mismatch' in msg.lower() + assert 'Remediation' in msg + + +# ---------- 9.8 example config loads (R16.3) ------------------------------ # + + +def test_example_config_loads_via_server_config() -> None: + cfg = ServerConfig.from_yaml(EXAMPLE) + assert isinstance(cfg, ServerConfig) + assert any(a.import_path == 'model' for a in cfg.applications) + assert any(a.import_path == 'sampler' for a in cfg.applications) diff --git a/tests/server/cli/test_drift_integration.py b/tests/server/cli/test_drift_integration.py new file mode 100644 index 000000000..35127d02d --- /dev/null +++ b/tests/server/cli/test_drift_integration.py @@ -0,0 +1,216 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Phase 3 config-drift integration tests against Docker Redis (R15). + +Verifies the launch-time signature gate end-to-end: +1. First launch with persistence: redis stores the signature on a fresh DB. +2. Relaunch with the **same** persistence-relevant config returns clean. +3. Relaunch with a **changed** persistence-relevant config raises + ``ConfigMismatchError`` from ``validate_against_backend`` and the + ``launch`` CLI exits non-zero with the diff + remediation hint. +""" +from __future__ import annotations + +import asyncio +import os +import pytest +import re +import uuid +import yaml +from pathlib import Path +from typer.testing import CliRunner +from unittest import mock + +from twinkle.server.cli.app import app +from twinkle.server.config import ServerConfig +from twinkle.server.exceptions import ConfigMismatchError +from twinkle.server.state.backend.factory import PersistenceConfig, create_backend +from twinkle.server.state.backend.redis_backend import RedisBackend +from twinkle.server.state.config_signature import _SIGNATURE_KEY, compute_signature, validate_against_backend + +REDIS_URL = os.environ.get('TWINKLE_TEST_REDIS_URL', 'redis://localhost:6379/0') + + +def _can_reach_redis() -> bool: + + async def _check() -> bool: + backend = RedisBackend(REDIS_URL) + try: + return await backend.health_check() + except Exception: + return False + finally: + try: + await backend.close() + except Exception: + pass + + try: + return asyncio.run(_check()) + except Exception: + return False + + +pytestmark = pytest.mark.skipif( + not _can_reach_redis(), + reason=f'Redis at {REDIS_URL} unreachable', +) + + +@pytest.fixture +def fresh_prefix() -> str: + return f'twinkle-drift-{uuid.uuid4().hex[:8]}::' + + +@pytest.fixture +def write_config(tmp_path: Path): + """Build a YAML config file with a parametric persistence section.""" + + def _write(persistence: dict) -> Path: + payload = { + 'http_options': { + 'host': 'localhost', + 'port': 8000 + }, + 'telemetry': { + 'enabled': False + }, + 'persistence': + persistence, + 'applications': [{ + 'name': 'server', + 'route_prefix': '/api/v1', + 'import_path': 'server', + 'args': { + 'supported_models': ['mock'] + }, + }], + } + path = tmp_path / f'config-{uuid.uuid4().hex[:6]}.yaml' + path.write_text(yaml.safe_dump(payload)) + return path + + return _write + + +# ---------- direct validate_against_backend behaviour --------------------- # + + +@pytest.mark.asyncio +async def test_first_run_stores_signature_then_match_passes(fresh_prefix: str) -> None: + pcfg = PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=fresh_prefix) + payload = {'persistence': pcfg.model_dump(mode='json')} + + # Cleanly fresh — first run stores. + await validate_against_backend(pcfg, payload) + backend = create_backend(pcfg) + try: + assert await backend.get(_SIGNATURE_KEY) == compute_signature(payload) + # Same payload — second run should pass without error. + await validate_against_backend(pcfg, payload) + finally: + for k in await backend.keys('*'): + await backend.delete(k) + await backend.close() + + +@pytest.mark.asyncio +async def test_drift_raises_with_diff_and_remediation(fresh_prefix: str) -> None: + pcfg = PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=fresh_prefix) + initial = {'persistence': pcfg.model_dump(mode='json')} + drifted = {'persistence': {**pcfg.model_dump(mode='json'), 'redis_url': REDIS_URL + '?changed=1'}} + + backend = create_backend(pcfg) + try: + await validate_against_backend(pcfg, initial) + with pytest.raises(ConfigMismatchError) as exc: + await validate_against_backend(pcfg, drifted) + msg = str(exc.value) + assert 'drifted' in msg.lower() or 'mismatch' in msg.lower() + assert 'Remediation' in msg + assert 'redis_url' in msg + finally: + for k in await backend.keys('*'): + await backend.delete(k) + await backend.close() + + +# ---------- CLI launch path: drift exit code = 3 -------------------------- # + + +def test_cli_launch_drift_exit_nonzero_with_diff(fresh_prefix: str, write_config) -> None: + """``launch`` calls ``validate_against_backend`` BEFORE the heavy + ServerLauncher import, so we can stub ServerLauncher to a sentinel and + still observe the drift error.""" + runner = CliRunner() + + cfg_a = write_config({'mode': 'redis', 'redis_url': REDIS_URL, 'key_prefix': fresh_prefix}) + cfg_b = write_config({ + 'mode': 'redis', + 'redis_url': REDIS_URL, + 'key_prefix': fresh_prefix, + # ``file_path`` doesn't apply to redis mode but its serialized + # presence (or any non-default field) flips the signature. + 'file_path': '/tmp/intentional-drift.json', + }) + + # Make the first launch a no-op after drift validation by stubbing the + # launcher; we only care about the signature side effects on Redis. + with mock.patch('twinkle.server.launcher.ServerLauncher') as launcher_spy: + launcher_spy.return_value.launch = mock.MagicMock(return_value=None) + + first = runner.invoke(app, ['launch', '--config', str(cfg_a)]) + assert first.exit_code == 0, first.output + assert launcher_spy.call_count == 1 + + # Second launch with drifted persistence config — never reaches the launcher. + launcher_spy.reset_mock() + second = runner.invoke(app, ['launch', '--config', str(cfg_b)]) + + assert second.exit_code == 3, second.output + assert launcher_spy.call_count == 0 + assert re.search(r'drifted|mismatch', second.output, re.IGNORECASE) + assert 'Remediation' in second.output + + # `clear persistence` clears the namespace so a follow-up launch with the + # drifted config can succeed (this is the documented remediation). + cleared = runner.invoke(app, ['clear', 'persistence', '--config', str(cfg_b)]) + assert cleared.exit_code == 0, cleared.output + + with mock.patch('twinkle.server.launcher.ServerLauncher') as launcher_spy_2: + launcher_spy_2.return_value.launch = mock.MagicMock(return_value=None) + post_clear = runner.invoke(app, ['launch', '--config', str(cfg_b)]) + assert post_clear.exit_code == 0, post_clear.output + + +def test_cli_launch_first_run_succeeds_then_match(fresh_prefix: str, write_config) -> None: + runner = CliRunner() + cfg = write_config({'mode': 'redis', 'redis_url': REDIS_URL, 'key_prefix': fresh_prefix}) + + with mock.patch('twinkle.server.launcher.ServerLauncher') as launcher_spy: + launcher_spy.return_value.launch = mock.MagicMock(return_value=None) + + first = runner.invoke(app, ['launch', '--config', str(cfg)]) + second = runner.invoke(app, ['launch', '--config', str(cfg)]) + + assert first.exit_code == 0 + assert second.exit_code == 0 + + +# ---------- check-config doesn't touch the backend ------------------------ # + + +def test_check_config_does_not_touch_redis(fresh_prefix: str, write_config) -> None: + """check-config only validates the YAML; no signature is stored.""" + cfg = write_config({'mode': 'redis', 'redis_url': REDIS_URL, 'key_prefix': fresh_prefix}) + runner = CliRunner() + res = runner.invoke(app, ['check-config', '--config', str(cfg)]) + assert res.exit_code == 0 + + async def _read_signature() -> object: + backend = create_backend(PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=fresh_prefix)) + try: + return await backend.get(_SIGNATURE_KEY) + finally: + await backend.close() + + assert asyncio.run(_read_signature()) is None diff --git a/tests/server/config/__init__.py b/tests/server/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/config/test_server_config.py b/tests/server/config/test_server_config.py new file mode 100644 index 000000000..a7a621eeb --- /dev/null +++ b/tests/server/config/test_server_config.py @@ -0,0 +1,252 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Property + unit tests for the typed ``ServerConfig`` (R6, R7, R8). + +Properties covered: +- # Feature: server-config-observability-refactor, + Property 12: Valid configuration yields a fully validated instance +- # Feature: server-config-observability-refactor, + Property 13: Any constraint violation is rejected with the offending field named +- # Feature: server-config-observability-refactor, + Property 14: Configuration round-trip fidelity +- # Feature: server-config-observability-refactor, + Property 15: Legacy / unknown field names are rejected +""" +from __future__ import annotations + +import pytest +import yaml +from hypothesis import given, settings +from hypothesis import strategies as st +from pathlib import Path +from pydantic import ValidationError + +from twinkle.server.config import ApplicationSpec, ServerConfig +from twinkle.server.exceptions import ConfigParseError +from twinkle.server.launcher import ServerLauncher + +# ---------- minimal valid config strategy ---------------------------------- # + +_PERSISTENCE_VARIANTS = st.one_of( + st.fixed_dictionaries({'mode': st.just('memory')}), + st.fixed_dictionaries({ + 'mode': st.just('file'), + 'file_path': st.just('/tmp/state.json') + }), + st.fixed_dictionaries({ + 'mode': st.just('redis'), + 'redis_url': st.just('redis://localhost:6379/0') + }), +) + +_MODEL_APP = st.fixed_dictionaries({ + 'name': + st.just('m'), + 'route_prefix': + st.just('/api/v1/m'), + 'import_path': + st.just('model'), + 'args': + st.fixed_dictionaries({ + 'model_id': st.just('model-id'), + 'device_group': st.just({ + 'name': 'g', + 'ranks': 1, + 'device_type': 'CPU' + }), + 'device_mesh': st.just({ + 'device_type': 'CPU', + 'dp_size': 1 + }), + 'backend': st.sampled_from(['mock', 'transformers', 'megatron']), + }), +}) + +_VALID_CONFIG = st.fixed_dictionaries({ + 'persistence': _PERSISTENCE_VARIANTS, + 'applications': st.lists(_MODEL_APP, min_size=0, max_size=3), +}) + +# ---------- Property 12: valid → fully validated (R6.2, R7.3) -------------- # + + +@settings(max_examples=100) +@given(payload=_VALID_CONFIG) +def test_property_12_valid_payload_yields_full_instance(payload: dict) -> None: + cfg = ServerConfig.model_validate(payload) + assert isinstance(cfg, ServerConfig) + assert all(isinstance(a, ApplicationSpec) for a in cfg.applications) + # Nested sections instantiated and validated. + assert cfg.persistence.mode == payload['persistence']['mode'] + assert cfg.task_queue.rps_limit >= 0 + + +# ---------- Property 13: violation → field-named error (R6.3, R7.1, R7.2) -- # + + +def test_property_13_redis_mode_missing_url() -> None: + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate({'persistence': {'mode': 'redis'}}) + msg = str(exc.value) + assert 'persistence.redis_url' in msg or 'redis_url' in msg + + +def test_property_13_file_mode_missing_path() -> None: + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate({'persistence': {'mode': 'file'}}) + msg = str(exc.value) + assert 'persistence.file_path' in msg or 'file_path' in msg + + +@settings(max_examples=100) +@given(bad_backend=st.text(min_size=1, max_size=8).filter(lambda s: s not in ('mock', 'transformers', 'megatron'))) +def test_property_13_bad_backend_names_field(bad_backend: str) -> None: + payload = { + 'applications': [{ + 'name': 'm', + 'import_path': 'model', + 'args': { + 'model_id': 'x', + 'device_group': {}, + 'device_mesh': {}, + 'backend': bad_backend, + }, + }] + } + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate(payload) + errors = exc.value.errors() + assert any('backend' in err['loc'] for err in errors) + + +@settings(max_examples=100) +@given(bad_max_input_tokens=st.integers(max_value=0, min_value=-1000)) +def test_property_13_nested_field_constraint_violation_named(bad_max_input_tokens: int) -> None: + """Nested-section constraints (here ``task_queue.max_input_tokens``) are + enforced together with cross-field ones (R7.3) and the offending path is + visible in the error.""" + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate({'task_queue': {'max_input_tokens': bad_max_input_tokens}}) + errors = exc.value.errors() + assert any('max_input_tokens' in err['loc'] for err in errors) + + +# ---------- Property 14: round-trip fidelity (R6.7) ------------------------ # + + +@settings(max_examples=100) +@given(payload=_VALID_CONFIG) +def test_property_14_round_trip_fidelity(payload: dict) -> None: + cfg = ServerConfig.model_validate(payload) + dumped = cfg.to_yaml_dict() + re_loaded = ServerConfig.model_validate(dumped) + assert re_loaded == cfg + assert re_loaded.model_dump() == cfg.model_dump() + + +# ---------- Property 15: legacy/unknown rejected (R8.1, R8.2) -------------- # + + +@pytest.mark.parametrize( + 'legacy_field', + ['telemetry_config', 'persistence_config'], +) +def test_property_15_legacy_field_rejected(legacy_field: str) -> None: + payload = {legacy_field: {}} + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate(payload) + errors = exc.value.errors() + assert any(err['type'] == 'extra_forbidden' for err in errors) + assert any(legacy_field in err['loc'] for err in errors) + + +@settings(max_examples=100) +@given(unknown=st.text(min_size=1, max_size=20).filter(lambda s: not s.startswith('_'))) +def test_property_15_unknown_field_rejected(unknown: str) -> None: + known = { + 'ray_namespace', + 'proxy_location', + 'http_options', + 'telemetry', + 'persistence', + 'task_queue', + 'applications', + } + if unknown in known: + return + with pytest.raises(ValidationError): + ServerConfig.model_validate({unknown: 'x'}) + + +@pytest.mark.parametrize('section', ['telemetry', 'persistence']) +def test_property_15_unknown_nested_field_rejected(section: str) -> None: + """Nested config sections also reject unknown keys (defends against typos + inside ``telemetry: {...}`` / ``persistence: {...}``).""" + payload = {section: {'unknown_typo': 1}} + with pytest.raises(ValidationError) as exc: + ServerConfig.model_validate(payload) + assert any('unknown_typo' in err['loc'] for err in exc.value.errors()) + + +# ---------- 3.11: from_yaml error paths + launcher dict rejection ---------- # + + +def test_from_yaml_missing_path(tmp_path: Path) -> None: + p = tmp_path / 'does_not_exist.yaml' + with pytest.raises(FileNotFoundError) as exc: + ServerConfig.from_yaml(p) + assert str(p) in str(exc.value) + + +def test_from_yaml_malformed_yaml(tmp_path: Path) -> None: + p = tmp_path / 'bad.yaml' + p.write_text('this is: not: valid: yaml: : :\n - [unbalanced\n') + with pytest.raises(ConfigParseError): + ServerConfig.from_yaml(p) + + +def test_from_yaml_top_level_must_be_mapping(tmp_path: Path) -> None: + p = tmp_path / 'list.yaml' + p.write_text('- a\n- b\n') + with pytest.raises(ConfigParseError): + ServerConfig.from_yaml(p) + + +def test_from_yaml_valid_minimal(tmp_path: Path) -> None: + p = tmp_path / 'mini.yaml' + yaml.safe_dump( + { + 'persistence': { + 'mode': 'memory' + }, + 'applications': [] + }, + p.open('w'), + ) + cfg = ServerConfig.from_yaml(p) + assert cfg.persistence.mode == 'memory' + assert cfg.applications == [] + + +def test_launcher_rejects_raw_dict() -> None: + with pytest.raises(TypeError) as exc: + ServerLauncher(config={'applications': []}) + assert 'ServerConfig' in str(exc.value) + + +def test_launcher_accepts_typed_config() -> None: + cfg = ServerConfig() + launcher = ServerLauncher(config=cfg) + assert launcher.config is cfg + + +def test_cookbook_examples_load() -> None: + """Migrated cookbook configs all parse with the new field names.""" + here = Path(__file__).resolve().parents[3] + examples = [ + here / 'cookbook' / 'client' / 'server' / 'transformer' / 'server_config.yaml', + here / 'cookbook' / 'client' / 'server' / 'megatron' / 'server_config.yaml', + here / 'cookbook' / 'client' / 'server' / 'megatron' / 'server_config_4b.yaml', + ] + for p in examples: + cfg = ServerConfig.from_yaml(p) + assert isinstance(cfg, ServerConfig), p diff --git a/tests/server/fixtures/__init__.py b/tests/server/fixtures/__init__.py new file mode 100644 index 000000000..bad4af55d --- /dev/null +++ b/tests/server/fixtures/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Shared on-disk fixtures for server tests.""" +from __future__ import annotations + +from pathlib import Path + +FIXTURES_DIR = Path(__file__).resolve().parent + +# All-mock CPU-only server config used by CLI tests and the in-process e2e +# Ray Serve startup test. The YAML is the abstraction — callers pass this +# path to ServerConfig.from_yaml / CLI --config flags. +MOCK_SERVER_CONFIG = FIXTURES_DIR / 'server_config_mock.yaml' diff --git a/tests/server/fixtures/server_config_mock.yaml b/tests/server/fixtures/server_config_mock.yaml new file mode 100644 index 000000000..8370a2448 --- /dev/null +++ b/tests/server/fixtures/server_config_mock.yaml @@ -0,0 +1,86 @@ +# Test-only Twinkle Server config — CPU-only mock backends. +# +# This file is the single fixture consumed by CLI tests and the in-process +# Ray Serve e2e test. Mock model + mock sampler need no GPU / torch / vllm / +# megatron, so the launch path can run on any CI host. Not for production. + +proxy_location: EveryNode + +http_options: + host: 127.0.0.1 + port: 8000 + +persistence: + mode: memory + +applications: + + - name: server + route_prefix: /api/v1 + import_path: server + args: + server_config: + per_token_model_limit: 3 + supported_models: + - mock-model + deployments: + - name: GatewayServer + max_ongoing_requests: 50 + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + + - name: models-mock + route_prefix: /api/v1/model/mock + import_path: model + args: + backend: mock + model_id: mock-model + nproc_per_node: 1 + device_group: + name: model + ranks: 1 + device_type: cpu + device_mesh: + device_type: cpu + dp_size: 1 + queue_config: + rps_limit: 100 + tps_limit: 100000 + deployments: + - name: ModelManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + + - name: sampler-mock + route_prefix: /api/v1/sampler/mock + import_path: sampler + args: + sampler_type: mock + model_id: mock-model + nproc_per_node: 1 + device_group: + name: sampler + ranks: 1 + device_type: cpu + device_mesh: + device_type: cpu + dp_size: 1 + queue_config: + rps_limit: 100 + tps_limit: 100000 + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 diff --git a/tests/server/model/__init__.py b/tests/server/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/model/test_mock_model.py b/tests/server/model/test_mock_model.py new file mode 100644 index 000000000..5b2d800ee --- /dev/null +++ b/tests/server/model/test_mock_model.py @@ -0,0 +1,177 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Property + unit tests for the numpy-only mock model backend (R1, R3, R4). + +Properties covered: +- # Feature: server-config-observability-refactor, Property 1: Mock model interface conformance +- # Feature: server-config-observability-refactor, Property 2: Mock model forward determinism and shape +- # Feature: server-config-observability-refactor, Property 3: Mock model adapter add/remove round-trip +- # Feature: server-config-observability-refactor, Property 4: Mock model remove-absent raises and preserves record +- # Feature: server-config-observability-refactor, Property 10: Model backend dispatch +""" +from __future__ import annotations + +import pytest +import sys +from hypothesis import given, settings +from hypothesis import strategies as st + +from twinkle.server.exceptions import ConfigError +from twinkle.server.model.app import _MODEL_BACKENDS, _dispatch_model_backend, _validate_model_backend +from twinkle.server.model.backends.mock_model import TwinkleCompatMockModel + +# ---------- Property 1: interface conformance (R1.1, R1.4) ---------------- # + +_REQUIRED_METHODS = ( + 'tinker_forward_only', + 'tinker_forward_backward', + 'tinker_step', + 'tinker_calculate_metric', + 'tinker_load', + 'forward_only', + 'forward_backward', + 'forward', + 'calculate_loss', + 'backward', + 'step', + 'zero_grad', + 'lr_step', + 'clip_grad_norm', + 'set_loss', + 'set_optimizer', + 'set_lr_scheduler', + 'set_template', + 'set_processor', + 'add_metric', + 'apply_patch', + 'save', + 'load', + 'resume_from_checkpoint', + 'get_state_dict', + 'get_train_configs', + 'add_adapter', + 'add_adapter_to_model', + 'remove_adapter', + 'has_adapter', +) + + +@pytest.mark.parametrize('method_name', _REQUIRED_METHODS) +def test_property_1_required_method_present(method_name: str) -> None: + m = TwinkleCompatMockModel('mid') + assert callable(getattr(m, method_name)), method_name + + +def test_property_1_constructor_does_not_raise() -> None: + TwinkleCompatMockModel('mid') + + +# ---------- Property 2: forward determinism + shape (R1.3, R4.4) ---------- # + + +@settings(max_examples=100) +@given( + seq_lens=st.lists(st.integers(min_value=1, max_value=12), min_size=1, max_size=5), + seed=st.integers(min_value=0, max_value=99), +) +def test_property_2_forward_only_deterministic_and_shaped(seq_lens: list, seed: int) -> None: + inputs = [{'tokens': list(range(n))} for n in seq_lens] + a = TwinkleCompatMockModel('mid', seed=seed) + b = TwinkleCompatMockModel('mid', seed=seed) + out_a = a.forward_only(inputs=inputs) + out_b = b.forward_only(inputs=inputs) + assert out_a == out_b + assert len(out_a) == len(inputs) + for record, n in zip(out_a, seq_lens): + assert len(record['logprobs']) == n + assert len(record['elementwise_loss']) == n + + +@settings(max_examples=100) +@given(seq_lens=st.lists(st.integers(min_value=1, max_value=8), min_size=1, max_size=4)) +def test_property_2_tinker_forward_backward_loss_is_finite(seq_lens: list) -> None: + m = TwinkleCompatMockModel('mid') + inputs = [{'tokens': list(range(n))} for n in seq_lens] + result, loss = m.tinker_forward_backward(inputs=inputs, adapter_name='a', loss_fn='cross_entropy') + assert isinstance(loss, float) + assert 0.0 <= loss <= 1.0 + assert len(result) == len(inputs) + + +# ---------- Property 3: adapter round-trip (R1.5, R1.6) ------------------- # + + +@settings(max_examples=100) +@given( + name=st.text( + min_size=1, max_size=12, alphabet=st.characters(whitelist_categories=('L', 'N'), whitelist_characters='_-'))) +def test_property_3_adapter_add_remove_round_trip(name: str) -> None: + m = TwinkleCompatMockModel('mid') + assert not m.has_adapter(name) + m.add_adapter(name, rank=4) + assert m.has_adapter(name) + m.remove_adapter(name) + assert not m.has_adapter(name) + + +# ---------- Property 4: remove-absent raises + preserves (R1.7) ----------- # + + +@settings(max_examples=100) +@given(name=st.text(min_size=1, max_size=12)) +def test_property_4_remove_absent_raises(name: str) -> None: + m = TwinkleCompatMockModel('mid') + pre = dict(m._adapters) + with pytest.raises(KeyError): + m.remove_adapter(name) + assert m._adapters == pre + + +# ---------- Property 10: Model backend dispatch (R3.1-3.3, R3.7, R3.9) ---- # + + +def test_property_10_mock_dispatch_returns_mock_model() -> None: + m = _dispatch_model_backend(_validate_model_backend('mock'), {'model_id': 'mid'}) + assert isinstance(m, TwinkleCompatMockModel) + + +@settings(max_examples=100) +@given(bad=st.text(min_size=1, max_size=10).filter(lambda s: s not in _MODEL_BACKENDS)) +def test_property_10_invalid_backend_raises_config_error(bad: str) -> None: + """Validation runs BEFORE any backend import / instantiation (R3.9).""" + with pytest.raises(ConfigError) as exc: + _validate_model_backend(bad) + assert exc.value.field == 'backend' + assert exc.value.value == bad + assert set(exc.value.allowed) == set(_MODEL_BACKENDS) + + +@pytest.mark.parametrize('value', [None, '']) +def test_property_10_absent_or_empty_backend_raises(value) -> None: + with pytest.raises(ConfigError) as exc: + _validate_model_backend(value) + assert exc.value.field == 'backend' + + +# ---------- Import isolation (R1.2) --------------------------------------- # + + +def test_import_isolation_no_torch_required() -> None: + """Block torch/transformers/vllm/megatron and re-import the mock module.""" + blocked = ('torch', 'transformers', 'vllm', 'megatron') + saved = {m: sys.modules.pop(m, None) for m in blocked} + saved_mock = sys.modules.pop('twinkle.server.model.backends.mock_model', None) + try: + for m in blocked: + sys.modules[m] = None # type: ignore[assignment] + import importlib + + mod = importlib.import_module('twinkle.server.model.backends.mock_model') + assert mod.TwinkleCompatMockModel is not None + finally: + for m in blocked: + if saved[m] is None: + sys.modules.pop(m, None) + else: + sys.modules[m] = saved[m] + if saved_mock is not None: + sys.modules['twinkle.server.model.backends.mock_model'] = saved_mock diff --git a/tests/server/sampler/__init__.py b/tests/server/sampler/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/sampler/test_mock_sampler.py b/tests/server/sampler/test_mock_sampler.py new file mode 100644 index 000000000..269a768a9 --- /dev/null +++ b/tests/server/sampler/test_mock_sampler.py @@ -0,0 +1,144 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Property + unit tests for the numpy-only mock sampler (R2, R3, R4). + +Properties covered: +- # Feature: server-config-observability-refactor, Property 5: Mock sampler interface conformance +- # Feature: server-config-observability-refactor, Property 6: Mock sampler output length and logprob count +- # Feature: server-config-observability-refactor, Property 7: Mock sampler determinism +- # Feature: server-config-observability-refactor, Property 8: Mock sampler rejects invalid max tokens +- # Feature: server-config-observability-refactor, Property 9: Mock sampler adapter record update +- # Feature: server-config-observability-refactor, Property 11: Sampler backend dispatch +""" +from __future__ import annotations + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st +from pathlib import Path + +from twinkle.data_format import InputFeature, SamplingParams +from twinkle.server.exceptions import ConfigError +from twinkle.server.sampler.app import _SAMPLER_TYPES, _dispatch_sampler_backend, _validate_sampler_type +from twinkle.server.sampler.backends.mock_sampler import MockSampler + +# ---------- Property 5: interface conformance (R2.1) ---------------------- # + +_REQUIRED_METHODS = ('sample', 'apply_patch', 'add_adapter_to_sampler', 'has_adapter') + + +@pytest.mark.parametrize('method', _REQUIRED_METHODS) +def test_property_5_required_method_present(method: str) -> None: + s = MockSampler('mid') + assert callable(getattr(s, method)) + + +# ---------- Property 6: output length + logprob count (R2.3, R2.4) -------- # + + +@settings(max_examples=100) +@given( + max_tokens=st.integers(min_value=1, max_value=20), + num_samples=st.integers(min_value=1, max_value=4), +) +def test_property_6_output_length_and_logprob_count(max_tokens: int, num_samples: int) -> None: + s = MockSampler('mid') + inp = InputFeature(input_ids=[1, 2, 3]) + responses = s.sample(inp, SamplingParams(max_tokens=max_tokens), adapter_name='a', num_samples=num_samples) + assert len(responses) == 1 + seqs = responses[0].sequences + assert len(seqs) == num_samples + for seq in seqs: + assert len(seq.tokens) == max_tokens + assert len(seq.logprobs) == max_tokens + + +# ---------- Property 7: determinism (R2.5, R4.5) -------------------------- # + + +@settings(max_examples=100) +@given( + max_tokens=st.integers(min_value=1, max_value=10), + num_samples=st.integers(min_value=1, max_value=3), + adapter=st.sampled_from(['', 'a', 'lora-1']), +) +def test_property_7_determinism(max_tokens: int, num_samples: int, adapter: str) -> None: + s = MockSampler('mid', seed=42) + inp = InputFeature(input_ids=[1, 2, 3]) + r1 = s.sample(inp, SamplingParams(max_tokens=max_tokens), adapter_name=adapter, num_samples=num_samples) + r2 = s.sample(inp, SamplingParams(max_tokens=max_tokens), adapter_name=adapter, num_samples=num_samples) + assert r1 == r2 + + +# ---------- Property 8: invalid max_tokens rejected (R2.6) ---------------- # + + +@settings(max_examples=50) +@given(bad=st.integers(max_value=0, min_value=-1000)) +def test_property_8_max_tokens_lt_1_raises(bad: int) -> None: + s = MockSampler('mid') + inp = InputFeature(input_ids=[1]) + with pytest.raises(ValueError) as exc: + s.sample(inp, SamplingParams(max_tokens=bad)) + assert 'max_tokens' in str(exc.value) + + +def test_property_8_no_sampling_params_raises() -> None: + s = MockSampler('mid') + inp = InputFeature(input_ids=[1]) + with pytest.raises(ValueError): + s.sample(inp, sampling_params=None) + + +# ---------- Property 9: adapter record update (R2.7) ---------------------- # + + +@settings(max_examples=100) +@given( + name=st.text( + min_size=1, max_size=12, alphabet=st.characters(whitelist_categories=('L', 'N'), whitelist_characters='_-'))) +def test_property_9_add_adapter_to_sampler(name: str) -> None: + s = MockSampler('mid') + assert not s.has_adapter(name) + s.add_adapter_to_sampler(name, {'rank': 4}) + assert s.has_adapter(name) + assert s._adapters[name] == {'rank': 4} + + +# ---------- Property 11: Sampler dispatch (R3.4-3.6, R3.10) --------------- # + + +def test_property_11_mock_dispatch_returns_mock_sampler() -> None: + s = _dispatch_sampler_backend(_validate_sampler_type('mock'), {'model_id': 'mid'}) + assert isinstance(s, MockSampler) + + +@settings(max_examples=100) +@given(bad=st.text(min_size=1, max_size=10).filter(lambda s: s not in _SAMPLER_TYPES)) +def test_property_11_invalid_sampler_type_raises_config_error(bad: str) -> None: + """Validation runs BEFORE any sampler import / instantiation (R3.10).""" + with pytest.raises(ConfigError) as exc: + _validate_sampler_type(bad) + assert exc.value.field == 'sampler_type' + assert exc.value.value == bad + assert set(exc.value.allowed) == set(_SAMPLER_TYPES) + + +@pytest.mark.parametrize('value', [None, '']) +def test_property_11_absent_or_empty_sampler_type_raises(value) -> None: + with pytest.raises(ConfigError) as exc: + _validate_sampler_type(value) + assert exc.value.field == 'sampler_type' + + +# ---------- No direct vllm import (R2.2) ---------------------------------- # + + +def test_mock_sampler_module_does_not_directly_import_vllm() -> None: + """Static check: ``mock_sampler.py`` must not import ``vllm`` directly.""" + src = Path( + __file__).resolve().parents[3] / 'src' / 'twinkle' / 'server' / 'sampler' / 'backends' / 'mock_sampler.py' + text = src.read_text() + for forbidden in ('import vllm', 'from vllm'): + assert forbidden not in text, f'mock_sampler.py contains {forbidden!r}' + + diff --git a/tests/server/state/__init__.py b/tests/server/state/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/state/test_config_signature.py b/tests/server/state/test_config_signature.py new file mode 100644 index 000000000..4bee4833d --- /dev/null +++ b/tests/server/state/test_config_signature.py @@ -0,0 +1,142 @@ +"""Tests for config signature validation.""" +from __future__ import annotations + +import pytest + +from twinkle.server.state.backend.memory_backend import MemoryBackend +from twinkle.server.state.config_signature import SignatureMismatchPolicy, compute_signature, validate_config_signature + +# ---- compute_signature ---- + + +def test_compute_signature_deterministic(): + """Same input should produce same output.""" + config = {'model': 'qwen', 'batch_size': 8} + sig1 = compute_signature(config) + sig2 = compute_signature(config) + assert sig1 == sig2 + + +def test_compute_signature_different_inputs(): + """Different inputs should produce different outputs.""" + config_a = {'model': 'qwen', 'batch_size': 8} + config_b = {'model': 'llama', 'batch_size': 8} + assert compute_signature(config_a) != compute_signature(config_b) + + +def test_compute_signature_key_order_independent(): + """Key order should not affect the signature (sort_keys=True).""" + config_a = {'b': 2, 'a': 1} + config_b = {'a': 1, 'b': 2} + assert compute_signature(config_a) == compute_signature(config_b) + + +def test_compute_signature_is_hex_string(): + """Signature should be a valid hex SHA256 string.""" + sig = compute_signature({'key': 'value'}) + assert len(sig) == 64 # SHA256 hex = 64 chars + assert all(c in '0123456789abcdef' for c in sig) + + +# ---- validate_config_signature ---- + + +@pytest.mark.asyncio +async def test_first_run_stores_signature(): + """First run with no stored sig should store it and return True.""" + backend = MemoryBackend() + config = {'model': 'test'} + result = await validate_config_signature(backend, config) + assert result is True + # Signature should be stored + stored = await backend.get('_meta::config_signature') + assert stored == compute_signature(config) + + +@pytest.mark.asyncio +async def test_same_config_passes(): + """Same config on second run should pass validation.""" + backend = MemoryBackend() + config = {'model': 'test', 'lr': 0.001} + # First run + await validate_config_signature(backend, config) + # Second run same config + result = await validate_config_signature(backend, config) + assert result is True + + +@pytest.mark.asyncio +async def test_different_config_warn_policy(): + """Different config with WARN policy should return False and update sig.""" + backend = MemoryBackend() + config_v1 = {'model': 'v1'} + config_v2 = {'model': 'v2'} + + await validate_config_signature(backend, config_v1) + result = await validate_config_signature(backend, config_v2, policy=SignatureMismatchPolicy.WARN) + assert result is False + # Signature should be updated to v2 + stored = await backend.get('_meta::config_signature') + assert stored == compute_signature(config_v2) + + +@pytest.mark.asyncio +async def test_different_config_clear_policy(): + """CLEAR policy should clear non-meta data, preserve _meta, return False.""" + backend = MemoryBackend() + config_v1 = {'model': 'v1'} + config_v2 = {'model': 'v2'} + + # Store initial config + await validate_config_signature(backend, config_v1) + # Add some user data + await backend.set('session::abc', {'data': 123}) + await backend.set('model::xyz', {'data': 456}) + await backend.set('_meta::other', 'keep_this') + + result = await validate_config_signature(backend, config_v2, policy=SignatureMismatchPolicy.CLEAR) + assert result is False + + # User data should be cleared + assert await backend.get('session::abc') is None + assert await backend.get('model::xyz') is None + + # _meta keys should be preserved + assert await backend.get('_meta::other') == 'keep_this' + # Signature should be updated + stored = await backend.get('_meta::config_signature') + assert stored == compute_signature(config_v2) + + +@pytest.mark.asyncio +async def test_different_config_abort_policy(): + """ABORT policy should raise ConfigMismatchError.""" + from twinkle.server.exceptions import ConfigMismatchError + + backend = MemoryBackend() + config_v1 = {'model': 'v1'} + config_v2 = {'model': 'v2'} + + await validate_config_signature(backend, config_v1) + + with pytest.raises(ConfigMismatchError): + await validate_config_signature(backend, config_v2, policy=SignatureMismatchPolicy.ABORT) + + +@pytest.mark.asyncio +async def test_abort_policy_does_not_update_signature(): + """ABORT policy should NOT update the stored signature.""" + from twinkle.server.exceptions import ConfigMismatchError + + backend = MemoryBackend() + config_v1 = {'model': 'v1'} + config_v2 = {'model': 'v2'} + + await validate_config_signature(backend, config_v1) + + with pytest.raises(ConfigMismatchError): + await validate_config_signature(backend, config_v2, policy=SignatureMismatchPolicy.ABORT) + + # Signature should still be v1 + stored = await backend.get('_meta::config_signature') + assert stored == compute_signature(config_v1) diff --git a/tests/server/state/test_de_actor.py b/tests/server/state/test_de_actor.py new file mode 100644 index 000000000..2770ce358 --- /dev/null +++ b/tests/server/state/test_de_actor.py @@ -0,0 +1,172 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Phase 0d — De-Actor ServerState tests (R19). + +Covers: +- # Feature: server-config-observability-refactor, Property 25: State operation equivalence under direct-backend access +- de-Actor wiring: ``get_server_state`` returns a direct-bound ``ServerState`` + and never creates a detached Ray Actor (R19.1, R19.2). +- in-process MemoryBackend works without Redis (R19.6). +""" +from __future__ import annotations + +import pytest +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st +from unittest import mock + +from twinkle.server.state import (PersistenceConfig, ReplicaRegistry, ServerState, get_server_state, + reset_server_state_cache) +from twinkle.server.state.backend.memory_backend import MemoryBackend + +# ---------- 4.6: de-Actor wiring + in-process persistence ------------------ # + + +def _ray_attr_used(obj_path: str) -> bool: + """Return True if ``obj_path`` (e.g. ``ray.remote``) is referenced in source.""" + from pathlib import Path + src = Path(__file__).resolve().parents[3] / 'src' / 'twinkle' / 'server' / 'state' / 'server_state.py' + return obj_path in src.read_text() + + +def test_no_detached_actor_in_source() -> None: + """The state module must not call ``ray.remote(...)`` or use ``lifetime='detached'``. + + Static check: searching the file is enough — the dynamic check below also + confirms ``ray.remote`` is never invoked when ``get_server_state`` runs. + """ + assert not _ray_attr_used('ray.remote('), ('state/server_state.py still references ray.remote(...) — ' + 'detached actor must not be created (R19.1).') + assert not _ray_attr_used("lifetime='detached'"), ("state/server_state.py still uses lifetime='detached' (R19.1).") + + +def test_get_server_state_does_not_call_ray_remote() -> None: + reset_server_state_cache() + import ray + + with mock.patch.object(ray, 'remote') as remote_spy, \ + mock.patch.object(ray, 'get_actor', side_effect=ValueError) as get_actor_spy: + state = get_server_state(actor_name='unit', backend=MemoryBackend()) + assert isinstance(state, ServerState) + assert remote_spy.call_count == 0, 'ray.remote was called — detached actor created' + # ray.get_actor may not be called at all under direct-backend access. + # Either way, the contract is that no remote actor is built. + _ = get_actor_spy + + +def test_get_server_state_caches_per_process() -> None: + reset_server_state_cache() + a = get_server_state(actor_name='cache-a', backend=MemoryBackend()) + b = get_server_state(actor_name='cache-a') + assert a is b + + +def test_get_server_state_separate_keys_yield_separate_instances() -> None: + reset_server_state_cache() + a = get_server_state(actor_name='k1', backend=MemoryBackend()) + b = get_server_state(actor_name='k2', backend=MemoryBackend()) + assert a is not b + + +def test_in_process_persistence_no_redis_required() -> None: + """``PersistenceConfig`` defaults to memory mode and ``ServerState`` works + without an external Redis (R19.6).""" + reset_server_state_cache() + cfg = PersistenceConfig() # mode == 'memory' + state = get_server_state(actor_name='no-redis', persistence_config=cfg) + assert isinstance(state, ServerState) + + +# ---------- 4.5: state-operation equivalence under direct-backend ---------- # + +_OP_STRATEGY = st.lists( + st.one_of( + # ('register_replica', replica_id, max_loras) + st.tuples( + st.just('register_replica'), + st.sampled_from(['r1', 'r2', 'r3']), + st.integers(min_value=1, max_value=4), + ), + # ('add_model', model_id, token, replica_id) + st.tuples( + st.just('add_model'), + st.text(min_size=1, max_size=4, alphabet='abcdefg'), + st.sampled_from(['t1', 't2']), + st.sampled_from(['r1', 'r2', None]), + ), + # ('config_set', key, value) + st.tuples( + st.just('config_set'), + st.sampled_from(['k1', 'k2', 'k3']), + st.integers(min_value=0, max_value=99), + ), + ), + min_size=0, + max_size=12, +) + + +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture, HealthCheck.too_slow]) +@given(ops=_OP_STRATEGY) +@pytest.mark.asyncio +async def test_property_25_state_operation_equivalence(ops: list[tuple]) -> None: + """Two ``ServerState`` instances driven by the same op stream agree. + + Two instances bound to one shared backend must agree on every read after + the same sequence of writes — this is the equivalence the actor used to + enforce, now provided by the shared backend itself (R19.3). + """ + backend = MemoryBackend() + a = ServerState(backend=backend) + b = ServerState(backend=backend) + seen_models: set[str] = set() + for op in ops: + kind = op[0] + if kind == 'register_replica': + _, rid, mx = op + await a.register_replica(rid, mx) + elif kind == 'add_model': + _, mid, token, rid = op + if mid in seen_models: + continue + seen_models.add(mid) + await a.register_model({'base_model': 'x'}, token=token, model_id=mid, replica_id=rid) + elif kind == 'config_set': + _, k, v = op + await b.add_config(k, v) + + # Both instances see the same persisted view. + assert await a.get_capacity_info() == await b.get_capacity_info() + for k in ('k1', 'k2', 'k3'): + assert await a.get_config(k) == await b.get_config(k) + + +# ---------- ReplicaRegistry direct ---------------------------------------- # + + +@pytest.mark.asyncio +async def test_replica_registry_round_trip() -> None: + backend = MemoryBackend() + reg = ReplicaRegistry(backend) + await reg.register('r1', 4) + await reg.register('r2', 7) + assert await reg.get_max_loras('r1') == 4 + assert await reg.get_max_loras('r2') == 7 + assert await reg.get_max_loras('unknown') is None + all_ = await reg.get_all() + assert all_ == {'r1': 4, 'r2': 7} + await reg.unregister('r1') + assert await reg.get_max_loras('r1') is None + + +@pytest.mark.asyncio +async def test_cross_instance_visibility_in_process() -> None: + """Two ``ServerState`` instances on one shared MemoryBackend see the same writes (R19.4 in-process).""" + backend = MemoryBackend() + a = ServerState(backend=backend) + b = ServerState(backend=backend) + await a.register_replica('r1', 3) + await a.register_model({'base_model': 'x'}, token='t1', model_id='m1', replica_id='r1') + info_b = await b.get_capacity_info() + assert info_b == {'max_loras': 3, 'used_loras': 1, 'free_loras': 2} + avail_b = await b.get_available_replica_ids(['r1']) + assert avail_b == ['r1'] diff --git a/tests/server/state/test_factory.py b/tests/server/state/test_factory.py new file mode 100644 index 000000000..9b0ea051d --- /dev/null +++ b/tests/server/state/test_factory.py @@ -0,0 +1,88 @@ +"""Tests for backend factory - create_backend function.""" +from __future__ import annotations + +import os +import pytest +import tempfile + +from twinkle.server.state.backend.factory import PersistenceConfig, create_backend +from twinkle.server.state.backend.file_backend import FileBackend +from twinkle.server.state.backend.memory_backend import MemoryBackend + +# ---- Memory Mode ---- + + +def test_create_backend_none_returns_memory(): + """Passing None should return MemoryBackend (default mode).""" + backend = create_backend(None) + assert isinstance(backend, MemoryBackend) + + +def test_create_backend_memory_mode(): + """Explicit memory mode should return MemoryBackend.""" + config = PersistenceConfig(mode='memory') + backend = create_backend(config) + assert isinstance(backend, MemoryBackend) + + +# ---- File Mode ---- + + +def test_create_backend_file_mode(): + """File mode with file_path should return FileBackend.""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + path = f.name + try: + os.unlink(path) # Let FileBackend create it + config = PersistenceConfig(mode='file', file_path=path) + backend = create_backend(config) + assert isinstance(backend, FileBackend) + finally: + if os.path.exists(path): + os.unlink(path) + + +def test_create_backend_file_mode_missing_path(): + """File mode without file_path should raise ValueError.""" + config = PersistenceConfig(mode='file') + with pytest.raises(ValueError, match='file_path'): + create_backend(config) + + +# ---- Redis Mode ---- + + +def test_create_backend_redis_mode(): + """Redis mode with redis_url should return RedisBackend (if redis available).""" + try: + import redis # noqa: F401 + except ImportError: + pytest.skip('redis package not available') + + from unittest.mock import MagicMock, patch + + from twinkle.server.state.backend.redis_backend import RedisBackend + + with patch('redis.asyncio.from_url', return_value=MagicMock()): + config = PersistenceConfig(mode='redis', redis_url='redis://localhost:6379') + backend = create_backend(config) + assert isinstance(backend, RedisBackend) + + +def test_create_backend_redis_mode_missing_url(): + """Redis mode without redis_url should raise ValueError.""" + config = PersistenceConfig(mode='redis') + with pytest.raises(ValueError, match='redis_url'): + create_backend(config) + + +# ---- PersistenceConfig Defaults ---- + + +def test_persistence_config_defaults(): + """PersistenceConfig should have sensible defaults.""" + config = PersistenceConfig() + assert config.mode == 'memory' + assert config.file_path is None + assert config.redis_url is None + assert config.key_prefix == '' diff --git a/tests/server/state/test_file_backend.py b/tests/server/state/test_file_backend.py new file mode 100644 index 000000000..32c393926 --- /dev/null +++ b/tests/server/state/test_file_backend.py @@ -0,0 +1,219 @@ +"""Tests for FileBackend - JSON file-based state backend.""" +from __future__ import annotations + +import asyncio +import json +import os +import pytest +import tempfile +import time + +from twinkle.server.state.backend.file_backend import FileBackend + + +@pytest.fixture +def tmp_file(): + """Provide a temporary file path, deleted before use so FileBackend creates fresh.""" + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + path = f.name + os.unlink(path) + yield path + if os.path.exists(path): + os.unlink(path) + + +# ---- Basic CRUD ---- + + +@pytest.mark.asyncio +async def test_set_and_get(tmp_file): + backend = FileBackend(tmp_file) + await backend.set('key1', {'hello': 'world'}) + result = await backend.get('key1') + assert result == {'hello': 'world'} + + +@pytest.mark.asyncio +async def test_get_nonexistent_key(tmp_file): + backend = FileBackend(tmp_file) + result = await backend.get('nonexistent') + assert result is None + + +@pytest.mark.asyncio +async def test_delete(tmp_file): + backend = FileBackend(tmp_file) + await backend.set('key1', 'value1') + await backend.delete('key1') + result = await backend.get('key1') + assert result is None + + +@pytest.mark.asyncio +async def test_delete_nonexistent_key(tmp_file): + """Deleting a key that doesn't exist should not raise.""" + backend = FileBackend(tmp_file) + await backend.delete('nonexistent') # Should not raise + + +@pytest.mark.asyncio +async def test_exists(tmp_file): + backend = FileBackend(tmp_file) + await backend.set('key1', 'value1') + assert await backend.exists('key1') is True + assert await backend.exists('key2') is False + + +# ---- TTL Expiry ---- + + +@pytest.mark.asyncio +async def test_ttl_expiry(tmp_file): + backend = FileBackend(tmp_file) + await backend.set('ephemeral', 'data', ttl=1) + # Immediately should exist + assert await backend.get('ephemeral') == 'data' + assert await backend.exists('ephemeral') is True + # Wait for expiry + time.sleep(1.1) + assert await backend.get('ephemeral') is None + assert await backend.exists('ephemeral') is False + + +@pytest.mark.asyncio +async def test_ttl_none_means_no_expiry(tmp_file): + backend = FileBackend(tmp_file) + await backend.set('permanent', 'data', ttl=None) + time.sleep(0.1) + assert await backend.get('permanent') == 'data' + + +# ---- Keys Pattern Matching ---- + + +@pytest.mark.asyncio +async def test_keys_wildcard(tmp_file): + backend = FileBackend(tmp_file) + await backend.set('session::abc', 's1') + await backend.set('session::def', 's2') + await backend.set('model::xyz', 'm1') + + session_keys = await backend.keys('session::*') + assert sorted(session_keys) == ['session::abc', 'session::def'] + + model_keys = await backend.keys('model::*') + assert model_keys == ['model::xyz'] + + all_keys = await backend.keys('*') + assert len(all_keys) == 3 + + +@pytest.mark.asyncio +async def test_keys_excludes_expired(tmp_file): + backend = FileBackend(tmp_file) + await backend.set('alive', 'yes') + await backend.set('dying', 'soon', ttl=1) + time.sleep(1.1) + keys = await backend.keys('*') + assert keys == ['alive'] + + +# ---- Count ---- + + +@pytest.mark.asyncio +async def test_count(tmp_file): + backend = FileBackend(tmp_file) + await backend.set('a::1', 'v') + await backend.set('a::2', 'v') + await backend.set('b::1', 'v') + assert await backend.count('a::*') == 2 + assert await backend.count('b::*') == 1 + assert await backend.count('*') == 3 + + +# ---- set_nx ---- + + +@pytest.mark.asyncio +async def test_set_nx_new_key(tmp_file): + backend = FileBackend(tmp_file) + result = await backend.set_nx('new_key', 'value') + assert result is True + assert await backend.get('new_key') == 'value' + + +@pytest.mark.asyncio +async def test_set_nx_existing_key(tmp_file): + backend = FileBackend(tmp_file) + await backend.set('existing', 'original') + result = await backend.set_nx('existing', 'new_value') + assert result is False + # Value should not change + assert await backend.get('existing') == 'original' + + +@pytest.mark.asyncio +async def test_set_nx_expired_key(tmp_file): + backend = FileBackend(tmp_file) + await backend.set('expired_key', 'old', ttl=1) + time.sleep(1.1) + # Key is expired, set_nx should succeed + result = await backend.set_nx('expired_key', 'new_value') + assert result is True + assert await backend.get('expired_key') == 'new_value' + + +# ---- Health Check ---- + + +@pytest.mark.asyncio +async def test_health_check(tmp_file): + backend = FileBackend(tmp_file) + assert await backend.health_check() is True + + +# ---- Auto-create File ---- + + +@pytest.mark.asyncio +async def test_auto_create_file(): + """FileBackend should create the file if it doesn't exist.""" + with tempfile.TemporaryDirectory() as tmp_dir: + path = os.path.join(tmp_dir, 'subdir', 'state.json') + FileBackend(path) + assert os.path.exists(path) + # File should be valid JSON + with open(path) as f: + data = json.load(f) + assert data == {} + + +# ---- Atomic Write Integrity ---- + + +@pytest.mark.asyncio +async def test_atomic_write_integrity(tmp_file): + """After write, reading from file should give consistent data.""" + backend = FileBackend(tmp_file) + await backend.set('k1', {'nested': [1, 2, 3]}) + await backend.set('k2', 'simple_string') + + # Read raw file to verify structure + with open(tmp_file, encoding='utf-8') as f: + raw = json.load(f) + assert 'k1' in raw + assert raw['k1']['value'] == {'nested': [1, 2, 3]} + assert 'k2' in raw + assert raw['k2']['value'] == 'simple_string' + + +# ---- Overwrite Value ---- + + +@pytest.mark.asyncio +async def test_overwrite_value(tmp_file): + backend = FileBackend(tmp_file) + await backend.set('key', 'v1') + await backend.set('key', 'v2') + assert await backend.get('key') == 'v2' diff --git a/tests/server/state/test_managers.py b/tests/server/state/test_managers.py new file mode 100644 index 000000000..ec51dd324 --- /dev/null +++ b/tests/server/state/test_managers.py @@ -0,0 +1,370 @@ +"""Tests for state managers using MemoryBackend as integration backend.""" +from __future__ import annotations + +import pytest +import time +from datetime import datetime, timezone + +from twinkle.server.state.backend.memory_backend import MemoryBackend +from twinkle.server.state.future_manager import FutureManager +from twinkle.server.state.model_manager import ModelManager +from twinkle.server.state.models import FutureRecord, ModelRecord, SamplingSessionRecord, SessionRecord +from twinkle.server.state.sampling_manager import SamplingSessionManager +from twinkle.server.state.session_manager import SessionManager + +# ============================================================ +# SessionManager Tests +# ============================================================ + + +class TestSessionManager: + + @pytest.fixture + def backend(self): + return MemoryBackend() + + @pytest.fixture + def manager(self, backend): + return SessionManager(backend=backend, expiration_timeout=300.0) + + @pytest.mark.asyncio + async def test_add_and_get(self, manager): + record = SessionRecord(tags=['test'], sdk_version='1.0') + await manager.add('sess1', record) + result = await manager.get('sess1') + assert result is not None + assert result.tags == ['test'] + assert result.sdk_version == '1.0' + + @pytest.mark.asyncio + async def test_get_nonexistent(self, manager): + result = await manager.get('nonexistent') + assert result is None + + @pytest.mark.asyncio + async def test_remove(self, manager): + record = SessionRecord() + await manager.add('sess1', record) + removed = await manager.remove('sess1') + assert removed is True + assert await manager.get('sess1') is None + + @pytest.mark.asyncio + async def test_remove_nonexistent(self, manager): + removed = await manager.remove('nonexistent') + assert removed is False + + @pytest.mark.asyncio + async def test_count(self, manager): + await manager.add('s1', SessionRecord()) + await manager.add('s2', SessionRecord()) + assert await manager.count() == 2 + + @pytest.mark.asyncio + async def test_touch_updates_heartbeat(self, manager): + record = SessionRecord(last_heartbeat=1000.0) + await manager.add('sess1', record) + before = time.time() + result = await manager.touch('sess1') + after = time.time() + assert result is True + updated = await manager.get('sess1') + assert before <= updated.last_heartbeat <= after + + @pytest.mark.asyncio + async def test_touch_nonexistent(self, manager): + result = await manager.touch('nonexistent') + assert result is False + + @pytest.mark.asyncio + async def test_get_last_heartbeat(self, manager): + record = SessionRecord(last_heartbeat=12345.0) + await manager.add('sess1', record) + hb = await manager.get_last_heartbeat('sess1') + assert hb == 12345.0 + + @pytest.mark.asyncio + async def test_cleanup_expired(self, manager): + now = time.time() + # Old session + old_record = SessionRecord(last_heartbeat=now - 1000) + await manager.add('old_sess', old_record) + # Recent session + new_record = SessionRecord(last_heartbeat=now) + await manager.add('new_sess', new_record) + + cutoff = now - 500 + removed_count = await manager.cleanup_expired(cutoff) + assert removed_count == 1 + assert await manager.get('old_sess') is None + assert await manager.get('new_sess') is not None + + @pytest.mark.asyncio + async def test_cleanup_expired_uses_created_at_fallback(self, manager): + """When last_heartbeat is 0, should use created_at for expiry check.""" + old_time = datetime(2020, 1, 1, tzinfo=timezone.utc).isoformat() + record = SessionRecord(last_heartbeat=0.0, created_at=old_time) + await manager.add('old_sess', record) + + cutoff = time.time() - 100 + removed_count = await manager.cleanup_expired(cutoff) + assert removed_count == 1 + + +# ============================================================ +# ModelManager Tests +# ============================================================ + + +class TestModelManager: + + @pytest.fixture + def backend(self): + return MemoryBackend() + + @pytest.fixture + def manager(self, backend): + return ModelManager(backend=backend, expiration_timeout=300.0, per_token_model_limit=3) + + @pytest.mark.asyncio + async def test_add_and_get(self, manager): + record = ModelRecord(token='tok1', session_id='sess1', base_model='qwen') + await manager.add('model1', record) + result = await manager.get('model1') + assert result is not None + assert result.token == 'tok1' + assert result.base_model == 'qwen' + + @pytest.mark.asyncio + async def test_remove(self, manager): + record = ModelRecord(token='tok1') + await manager.add('model1', record) + removed = await manager.remove('model1') + assert removed is True + assert await manager.get('model1') is None + + @pytest.mark.asyncio + async def test_token_limit_enforced(self, manager): + """Adding more models than per_token_model_limit should raise RuntimeError.""" + for i in range(3): + await manager.add(f'm{i}', ModelRecord(token='tok1')) + + with pytest.raises(RuntimeError, match='Model limit exceeded'): + await manager.add('m3', ModelRecord(token='tok1')) + + @pytest.mark.asyncio + async def test_token_limit_per_token(self, manager): + """Limit is per-token, different tokens have separate limits.""" + for i in range(3): + await manager.add(f'a{i}', ModelRecord(token='tokenA')) + # Different token should work + await manager.add('b0', ModelRecord(token='tokenB')) + assert await manager.get('b0') is not None + + @pytest.mark.asyncio + async def test_replica_registration(self, manager): + await manager.register_replica('replica1', max_loras=5) + info = await manager.get_capacity_info() + assert info['max_loras'] == 5 + assert info['used_loras'] == 0 + assert info['free_loras'] == 5 + + @pytest.mark.asyncio + async def test_capacity_info_after_add(self, manager): + await manager.register_replica('r1', max_loras=3) + record = ModelRecord(token='tok1', replica_id='r1') + await manager.add('m1', record) + info = await manager.get_capacity_info() + assert info['used_loras'] == 1 + assert info['free_loras'] == 2 + + @pytest.mark.asyncio + async def test_indexes_derived_from_backend(self, manager): + """Per-token / per-replica counts are derived from the backend.""" + record1 = ModelRecord(token='tok1', replica_id='r1') + record2 = ModelRecord(token='tok1', replica_id='r2') + await manager.add('m1', record1) + await manager.add('m2', record2) + + await manager.register_replica('r1', max_loras=5) + await manager.register_replica('r2', max_loras=5) + + # Backend-derived availability reflects all persisted records. + avail = await manager.get_available_replica_ids(['r1', 'r2']) + assert avail == ['r1', 'r2'] + + # Per-token count enforces the limit using the persisted records. + count = await manager._count_models_for_token('tok1') + assert count == 2 + + @pytest.mark.asyncio + async def test_cascade_cleanup_by_session(self, manager): + """Models owned by expired sessions should be cleaned up.""" + now = time.time() + record = ModelRecord( + token='tok1', + session_id='expired_sess', + created_at=datetime.now(timezone.utc).isoformat(), + ) + await manager.add('m1', record) + + # Cleanup with cascade + removed = await manager.cleanup_expired( + cutoff_time=now - 10000, # cutoff is old, so age-based wouldn't trigger + expired_session_ids=['expired_sess'], + ) + assert removed == 1 + assert await manager.get('m1') is None + + @pytest.mark.asyncio + async def test_get_available_replica_ids(self, manager): + await manager.register_replica('r1', max_loras=2) + await manager.register_replica('r2', max_loras=1) + # Fill r2 + await manager.add('m1', ModelRecord(token='t', replica_id='r2')) + + available = await manager.get_available_replica_ids(['r1', 'r2', 'r3_unknown']) + # r1 has capacity, r2 is full, r3 unknown (conservative include) + assert 'r1' in available + assert 'r2' not in available + assert 'r3_unknown' in available + + +# ============================================================ +# SamplingSessionManager Tests +# ============================================================ + + +class TestSamplingSessionManager: + + @pytest.fixture + def backend(self): + return MemoryBackend() + + @pytest.fixture + def manager(self, backend): + return SamplingSessionManager(backend=backend, expiration_timeout=300.0) + + @pytest.mark.asyncio + async def test_add_and_get(self, manager): + record = SamplingSessionRecord(session_id='sess1', base_model='qwen') + await manager.add('samp1', record) + result = await manager.get('samp1') + assert result is not None + assert result.session_id == 'sess1' + assert result.base_model == 'qwen' + + @pytest.mark.asyncio + async def test_cleanup_expired_by_age(self, manager): + old_time = datetime(2020, 1, 1, tzinfo=timezone.utc).isoformat() + record = SamplingSessionRecord(session_id='sess1', created_at=old_time) + await manager.add('samp_old', record) + + # Recent + record2 = SamplingSessionRecord(session_id='sess2') + await manager.add('samp_new', record2) + + cutoff = time.time() - 100 + removed = await manager.cleanup_expired(cutoff) + assert removed == 1 + assert await manager.get('samp_old') is None + assert await manager.get('samp_new') is not None + + @pytest.mark.asyncio + async def test_cleanup_expired_cascade(self, manager): + """Sampling sessions should be cleaned when their parent session expires.""" + record = SamplingSessionRecord(session_id='expired_sess') + await manager.add('samp1', record) + + removed = await manager.cleanup_expired( + cutoff_time=0.0, # Won't catch by age + expired_session_ids=['expired_sess'], + ) + assert removed == 1 + assert await manager.get('samp1') is None + + +# ============================================================ +# FutureManager Tests +# ============================================================ + + +class TestFutureManager: + + @pytest.fixture + def backend(self): + return MemoryBackend() + + @pytest.fixture + def manager(self, backend): + return FutureManager(backend=backend, expiration_timeout=300.0) + + @pytest.mark.asyncio + async def test_store_status_creates_new(self, manager): + await manager.store_status( + request_id='req1', + status='pending', + model_id='model1', + ) + result = await manager.get('req1') + assert result is not None + assert result.status == 'pending' + assert result.model_id == 'model1' + + @pytest.mark.asyncio + async def test_store_status_updates_existing(self, manager): + await manager.store_status(request_id='req1', status='pending', model_id='m1') + await manager.store_status(request_id='req1', status='completed', model_id='m1', result={'output': 'done'}) + result = await manager.get('req1') + assert result.status == 'completed' + assert result.result == {'output': 'done'} + + @pytest.mark.asyncio + async def test_store_status_with_pydantic_result(self, manager): + """Pydantic models should be serialized via model_dump.""" + from pydantic import BaseModel + + class MockResult(BaseModel): + score: float = 0.95 + + await manager.store_status(request_id='req1', status='completed', model_id='m1', result=MockResult()) + result = await manager.get('req1') + assert result.result == {'score': 0.95} + + @pytest.mark.asyncio + async def test_store_status_preserves_reason(self, manager): + await manager.store_status(request_id='req1', status='rate_limited', model_id=None, reason='Too many requests') + result = await manager.get('req1') + assert result.reason == 'Too many requests' + + @pytest.mark.asyncio + async def test_cleanup_expired(self, manager): + old_time = datetime(2020, 1, 1, tzinfo=timezone.utc).isoformat() + old_record = FutureRecord(status='completed', created_at=old_time, updated_at=old_time) + await manager.add('old_req', old_record) + + new_record = FutureRecord(status='pending') + await manager.add('new_req', new_record) + + cutoff = time.time() - 100 + removed = await manager.cleanup_expired(cutoff) + assert removed == 1 + assert await manager.get('old_req') is None + assert await manager.get('new_req') is not None + + @pytest.mark.asyncio + async def test_get_nonexistent(self, manager): + result = await manager.get('nonexistent') + assert result is None + + @pytest.mark.asyncio + async def test_store_status_queue_state(self, manager): + await manager.store_status( + request_id='req1', + status='queued', + model_id='m1', + queue_state='paused_rate_limit', + queue_state_reason='Rate limit hit') + result = await manager.get('req1') + assert result.queue_state == 'paused_rate_limit' + assert result.queue_state_reason == 'Rate limit hit' diff --git a/tests/server/state/test_redis_backend.py b/tests/server/state/test_redis_backend.py new file mode 100644 index 000000000..813b60708 --- /dev/null +++ b/tests/server/state/test_redis_backend.py @@ -0,0 +1,207 @@ +"""Tests for RedisBackend - using mocks since no real Redis is available.""" +from __future__ import annotations + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +# Skip entire module if redis package not available +redis = pytest.importorskip('redis') + +from twinkle.server.state.backend.redis_backend import RedisBackend # noqa: E402 + + +@pytest.fixture +def mock_redis_client(): + """Create a mock redis.asyncio client.""" + client = AsyncMock() + client.set = AsyncMock(return_value=True) + client.get = AsyncMock(return_value=None) + client.delete = AsyncMock(return_value=1) + client.exists = AsyncMock(return_value=1) + client.keys = AsyncMock(return_value=[]) + client.ping = AsyncMock(return_value=True) + client.aclose = AsyncMock() + return client + + +@pytest.fixture +def backend_no_prefix(mock_redis_client): + """RedisBackend with no key prefix.""" + with patch('redis.asyncio.from_url', return_value=mock_redis_client): + backend = RedisBackend('redis://localhost:6379') + return backend + + +@pytest.fixture +def backend_with_prefix(mock_redis_client): + """RedisBackend with key prefix.""" + with patch('redis.asyncio.from_url', return_value=mock_redis_client): + backend = RedisBackend('redis://localhost:6379', key_prefix='twinkle:') + return backend + + +# ---- SET ---- + + +@pytest.mark.asyncio +async def test_set_without_ttl(backend_no_prefix, mock_redis_client): + await backend_no_prefix.set('mykey', {'data': 123}) + mock_redis_client.set.assert_called_once_with('mykey', json.dumps({'data': 123})) + + +@pytest.mark.asyncio +async def test_set_with_ttl(backend_no_prefix, mock_redis_client): + await backend_no_prefix.set('mykey', 'value', ttl=60) + mock_redis_client.set.assert_called_once_with('mykey', json.dumps('value'), ex=60) + + +@pytest.mark.asyncio +async def test_set_with_prefix(backend_with_prefix, mock_redis_client): + await backend_with_prefix.set('mykey', 'val') + mock_redis_client.set.assert_called_once_with('twinkle:mykey', json.dumps('val')) + + +# ---- GET ---- + + +@pytest.mark.asyncio +async def test_get_existing_key(backend_no_prefix, mock_redis_client): + mock_redis_client.get.return_value = json.dumps({'hello': 'world'}) + result = await backend_no_prefix.get('mykey') + mock_redis_client.get.assert_called_once_with('mykey') + assert result == {'hello': 'world'} + + +@pytest.mark.asyncio +async def test_get_nonexistent_key(backend_no_prefix, mock_redis_client): + mock_redis_client.get.return_value = None + result = await backend_no_prefix.get('missing') + assert result is None + + +@pytest.mark.asyncio +async def test_get_with_prefix(backend_with_prefix, mock_redis_client): + mock_redis_client.get.return_value = json.dumps('data') + result = await backend_with_prefix.get('mykey') + mock_redis_client.get.assert_called_once_with('twinkle:mykey') + assert result == 'data' + + +# ---- DELETE ---- + + +@pytest.mark.asyncio +async def test_delete(backend_no_prefix, mock_redis_client): + await backend_no_prefix.delete('mykey') + mock_redis_client.delete.assert_called_once_with('mykey') + + +@pytest.mark.asyncio +async def test_delete_with_prefix(backend_with_prefix, mock_redis_client): + await backend_with_prefix.delete('mykey') + mock_redis_client.delete.assert_called_once_with('twinkle:mykey') + + +# ---- EXISTS ---- + + +@pytest.mark.asyncio +async def test_exists_true(backend_no_prefix, mock_redis_client): + mock_redis_client.exists.return_value = 1 + result = await backend_no_prefix.exists('mykey') + assert result is True + mock_redis_client.exists.assert_called_once_with('mykey') + + +@pytest.mark.asyncio +async def test_exists_false(backend_no_prefix, mock_redis_client): + mock_redis_client.exists.return_value = 0 + result = await backend_no_prefix.exists('mykey') + assert result is False + + +# ---- KEYS ---- + + +@pytest.mark.asyncio +async def test_keys_pattern(backend_no_prefix, mock_redis_client): + mock_redis_client.keys.return_value = ['session::a', 'session::b'] + result = await backend_no_prefix.keys('session::*') + mock_redis_client.keys.assert_called_once_with('session::*') + assert result == ['session::a', 'session::b'] + + +@pytest.mark.asyncio +async def test_keys_with_prefix(backend_with_prefix, mock_redis_client): + mock_redis_client.keys.return_value = ['twinkle:session::a', 'twinkle:session::b'] + result = await backend_with_prefix.keys('session::*') + mock_redis_client.keys.assert_called_once_with('twinkle:session::*') + # Result should have prefix stripped + assert result == ['session::a', 'session::b'] + + +# ---- COUNT ---- + + +@pytest.mark.asyncio +async def test_count(backend_no_prefix, mock_redis_client): + mock_redis_client.keys.return_value = ['a', 'b', 'c'] + result = await backend_no_prefix.count('*') + assert result == 3 + + +# ---- SET_NX ---- + + +@pytest.mark.asyncio +async def test_set_nx_success(backend_no_prefix, mock_redis_client): + mock_redis_client.set.return_value = True # nx succeeded + result = await backend_no_prefix.set_nx('newkey', {'value': 1}) + mock_redis_client.set.assert_called_once_with('newkey', json.dumps({'value': 1}), nx=True) + assert result is True + + +@pytest.mark.asyncio +async def test_set_nx_failure(backend_no_prefix, mock_redis_client): + mock_redis_client.set.return_value = None # nx failed (key exists) + result = await backend_no_prefix.set_nx('existing', 'val') + assert result is False + + +# ---- HEALTH CHECK ---- + + +@pytest.mark.asyncio +async def test_health_check_healthy(backend_no_prefix, mock_redis_client): + mock_redis_client.ping.return_value = True + result = await backend_no_prefix.health_check() + assert result is True + + +@pytest.mark.asyncio +async def test_health_check_unhealthy(backend_no_prefix, mock_redis_client): + mock_redis_client.ping.side_effect = ConnectionError('offline') + result = await backend_no_prefix.health_check() + assert result is False + + +# ---- CLOSE ---- + + +@pytest.mark.asyncio +async def test_close(backend_no_prefix, mock_redis_client): + await backend_no_prefix.close() + mock_redis_client.aclose.assert_called_once() + + +# ---- JSON Serialization ---- + + +@pytest.mark.asyncio +async def test_json_serialization_complex_types(backend_no_prefix, mock_redis_client): + """Values should be JSON-serialized before storage.""" + complex_value = {'list': [1, 2, 3], 'nested': {'key': 'val'}, 'null': None} + await backend_no_prefix.set('complex', complex_value) + expected_json = json.dumps(complex_value) + mock_redis_client.set.assert_called_once_with('complex', expected_json) diff --git a/tests/server/state/test_redis_integration.py b/tests/server/state/test_redis_integration.py new file mode 100644 index 000000000..11f2def2f --- /dev/null +++ b/tests/server/state/test_redis_integration.py @@ -0,0 +1,192 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Phase 0d Redis integration tests (R19.4, R19.5). + +Properties covered: +- # Feature: server-config-observability-refactor, Property 26: Cross-worker write visibility +- # Feature: server-config-observability-refactor, Property 27: Concurrent-write consistency + +Both run against a real Redis instance reached at ``REDIS_URL`` (default +``redis://localhost:6379/0``). When the URL is unreachable the whole module +is skipped — these tests are explicitly Docker-dependent and must run +against the local stack rather than the in-process mock. +""" +from __future__ import annotations + +import asyncio +import os +import pytest +import uuid + +from twinkle.server.state import ServerState +from twinkle.server.state.backend.factory import PersistenceConfig, create_backend +from twinkle.server.state.backend.redis_backend import RedisBackend + +REDIS_URL = os.environ.get('TWINKLE_TEST_REDIS_URL', 'redis://localhost:6379/0') + + +def _can_reach_redis() -> bool: + + async def _check() -> bool: + backend = RedisBackend(REDIS_URL) + try: + return await backend.health_check() + except Exception: + return False + finally: + try: + await backend.close() + except Exception: + pass + + try: + return asyncio.run(_check()) + except Exception: + return False + + +pytestmark = pytest.mark.skipif( + not _can_reach_redis(), + reason=f'Redis at {REDIS_URL} unreachable — start docker compose / `docker run -p 6379:6379 redis`', +) + + +@pytest.fixture +def isolation_prefix() -> str: + """Fresh key namespace per test so parallel runs don't collide.""" + return f'twinkle-test-{uuid.uuid4().hex[:8]}::' + + +@pytest.fixture +async def shared_backend(isolation_prefix: str): + backend = create_backend(PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=isolation_prefix)) + yield backend + # Tear down everything we wrote. + try: + keys = await backend.keys('*') + for k in keys: + await backend.delete(k) + finally: + await backend.close() + + +@pytest.fixture +def make_state(isolation_prefix: str): + """Factory for fresh ``ServerState`` instances over the same shared key prefix. + + Each call returns a NEW ``RedisBackend`` (separate connection pool) so the + tests genuinely exercise cross-instance behaviour rather than two views + of the same client. + """ + created: list[ServerState] = [] + + def _make() -> ServerState: + backend = create_backend(PersistenceConfig(mode='redis', redis_url=REDIS_URL, key_prefix=isolation_prefix)) + state = ServerState(backend=backend) + created.append(state) + return state + + yield _make + + async def _cleanup() -> None: + for s in created: + try: + await s._backend.close() + except Exception: + pass + + asyncio.run(_cleanup()) + + +# ---------- Property 26: cross-worker visibility (R19.4) ------------------ # + + +@pytest.mark.asyncio +async def test_property_26_replica_write_via_a_visible_via_b(make_state) -> None: + """One worker registers a replica; a second worker on the same shared + backend sees the same capacity / availability view.""" + a = make_state() + b = make_state() + rid = f'r-{uuid.uuid4().hex[:6]}' + await a.register_replica(rid, max_loras=4) + + cap = await b.get_capacity_info() + assert cap['max_loras'] >= 4 + assert rid in await b.get_available_replica_ids([rid]) + + +@pytest.mark.asyncio +async def test_property_26_model_write_visible(make_state) -> None: + a = make_state() + b = make_state() + rid = f'r-{uuid.uuid4().hex[:6]}' + await a.register_replica(rid, max_loras=2) + mid = await a.register_model({'base_model': 'mock'}, token='tok-A', model_id='mid-A', replica_id=rid) + + meta = await b.get_model_metadata(mid) + assert meta is not None + assert meta['token'] == 'tok-A' + assert meta['replica_id'] == rid + + +@pytest.mark.asyncio +async def test_property_26_session_and_config(make_state) -> None: + a = make_state() + b = make_state() + sid = await a.create_session({'session_id': f'sess-{uuid.uuid4().hex[:6]}'}) + assert await b.get_session_last_heartbeat(sid) is not None + + await a.add_config('feature_flag', {'value': 42}) + assert await b.get_config('feature_flag') == {'value': 42} + + +# ---------- Property 27: concurrent-write consistency (R19.5) ------------- # + + +@pytest.mark.asyncio +async def test_property_27_concurrent_config_writes_no_torn_records(make_state) -> None: + """Many concurrent writes of distinct keys complete and every record + equals one of the writes (no torn / partial value).""" + a = make_state() + b = make_state() + n = 40 + payload = {f'k-{i}': {'idx': i, 'note': 'x' * 32} for i in range(n)} + + async def writer(state: ServerState, items: dict) -> None: + await asyncio.gather(*(state.add_config(k, v) for k, v in items.items())) + + half = list(payload.items())[:n // 2] + other = list(payload.items())[n // 2:] + await asyncio.gather(writer(a, dict(half)), writer(b, dict(other))) + + # Every key must read back equal to its expected payload from either side. + for k, v in payload.items(): + assert await a.get_config(k) == v, k + assert await b.get_config(k) == v, k + + +@pytest.mark.asyncio +async def test_property_27_concurrent_same_key_lands_one_of_committed(make_state) -> None: + """Two writers race on the same key — final value equals one of the + writes; no torn record (R19.5).""" + a = make_state() + b = make_state() + write_a = {'who': 'a', 'payload': list(range(8))} + write_b = {'who': 'b', 'payload': list(range(8, 16))} + + await asyncio.gather(a.add_config('contended', write_a), b.add_config('contended', write_b)) + final = await a.get_config('contended') + assert final in (write_a, write_b) + + +@pytest.mark.asyncio +async def test_property_27_concurrent_replica_registration(make_state) -> None: + a = make_state() + b = make_state() + rid = f'r-{uuid.uuid4().hex[:6]}' + await asyncio.gather( + a.register_replica(rid, 4), + b.register_replica(rid, 4), + ) + cap = await a.get_capacity_info() + # Capacity row stores ``max_loras``; both writers wrote 4, no torn write. + assert cap['max_loras'] == 4 diff --git a/tests/server/telemetry/__init__.py b/tests/server/telemetry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/telemetry/conftest.py b/tests/server/telemetry/conftest.py new file mode 100644 index 000000000..5a9276bc8 --- /dev/null +++ b/tests/server/telemetry/conftest.py @@ -0,0 +1,37 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Shared OTEL setup for telemetry tests. + +OTel's global tracer provider is one-shot per process — the second +``trace.set_tracer_provider(...)`` call no-ops with a warning. So multiple +test modules that each tried to register their own provider would silently +share whichever one ran first, and tests that read spans from the wrong +exporter would fail. This fixture installs one provider + one in-memory +exporter for the entire telemetry test package. +""" +from __future__ import annotations + +import pytest + + +def _otel_available() -> bool: + try: + from opentelemetry import trace # noqa: F401 + except Exception: + return False + return True + + +@pytest.fixture(scope='session') +def in_memory_span_exporter(): + if not _otel_available(): + pytest.skip('OTEL SDK not installed in test env') + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + trace.set_tracer_provider(provider) + return exporter diff --git a/tests/server/telemetry/test_context_carrier.py b/tests/server/telemetry/test_context_carrier.py new file mode 100644 index 000000000..36b35b67a --- /dev/null +++ b/tests/server/telemetry/test_context_carrier.py @@ -0,0 +1,83 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Trace context carrier round-trip tests (R13). + +# Feature: server-config-observability-refactor, Property 24: Trace context carrier round-trip +""" +from __future__ import annotations + +import pytest +from unittest import mock + +from twinkle.server.telemetry import context_carrier +from twinkle.server.telemetry.context_carrier import activate_carrier, make_carrier + + +def _otel_available() -> bool: + try: + from opentelemetry import trace # noqa: F401 + except Exception: + return False + return True + + +# ---------- NoOp path (R13.4 / R18.3) ------------------------------------- # + + +def test_make_carrier_returns_empty_dict_when_otel_absent() -> None: + with mock.patch.object(context_carrier, '_OTEL_AVAILABLE', False): + assert make_carrier() == {} + + +def test_activate_carrier_with_none_is_safe_noop() -> None: + with activate_carrier(None): + pass + with activate_carrier({}): + pass + + +def test_activate_carrier_when_otel_absent_is_noop() -> None: + with mock.patch.object(context_carrier, '_OTEL_AVAILABLE', False): + with activate_carrier({'traceparent': 'whatever'}): + pass + + +# ---------- Property 24: round-trip (R13.1, R13.2) ------------------------ # + + +def test_property_24_carrier_round_trip(in_memory_span_exporter) -> None: + """Active context → make_carrier → activate_carrier → child span shares + the same trace id (R13.1, R13.2).""" + from opentelemetry import trace + + in_memory_span_exporter.clear() + tracer = trace.get_tracer('twinkle.test') + + carrier: dict[str, object] = {} + parent_trace_id: int | None = None + with tracer.start_as_current_span('caller') as parent: + parent_trace_id = parent.get_span_context().trace_id + carrier = make_carrier() + + # Sanity — the carrier carries something OTEL recognizes (traceparent). + assert carrier and any(k.lower() == 'traceparent' for k in carrier.keys()) + + # On the receiving side, activating the carrier and starting a span + # should yield a span whose trace id equals the parent's. + with activate_carrier(carrier): + with tracer.start_as_current_span('callee') as child: + child_trace_id = child.get_span_context().trace_id + + assert parent_trace_id == child_trace_id + + +def test_property_24_empty_carrier_starts_fresh_trace(in_memory_span_exporter) -> None: + """An empty / None carrier means: start a new trace (R13.4).""" + from opentelemetry import trace + + in_memory_span_exporter.clear() + tracer = trace.get_tracer('twinkle.test') + + with activate_carrier(None): + with tracer.start_as_current_span('orphan') as span: + tid = span.get_span_context().trace_id + assert tid != 0 diff --git a/tests/server/telemetry/test_tracing_and_correlation.py b/tests/server/telemetry/test_tracing_and_correlation.py new file mode 100644 index 000000000..e4e0bf826 --- /dev/null +++ b/tests/server/telemetry/test_tracing_and_correlation.py @@ -0,0 +1,263 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Property + unit tests for the business-layer tracing helper, correlation +keys, and ``ResourceMetricsCollector`` (R10, R11, R12, R18). + +Properties covered: +- # Feature: server-config-observability-refactor, Property 19: Business-layer span lifecycle +- # Feature: server-config-observability-refactor, Property 20: Span exception handling +- # Feature: server-config-observability-refactor, Property 21: Tracing graceful-degradation equivalence +- # Feature: server-config-observability-refactor, Property 22: Correlation attribute attachment +- # Feature: server-config-observability-refactor, Property 23: Correlation prefix invariant +""" +from __future__ import annotations + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st +from unittest import mock + +from twinkle.server.telemetry import correlation +from twinkle.server.telemetry.correlation import CORRELATION_KEYS, PREFIX, set_correlation_attrs +from twinkle.server.telemetry.tracing import _NoopSpan, traced_operation + +# ---------- Property 23: prefix invariant (R11.3) ------------------------- # + + +@pytest.mark.parametrize('key', CORRELATION_KEYS) +def test_property_23_prefix_invariant(key: str) -> None: + assert key.startswith(PREFIX), key + + +def test_property_23_helper_constants_complete() -> None: + expected = { + 'twinkle.session_id', + 'twinkle.model_id', + 'twinkle.replica_id', + 'twinkle.token_id', + 'twinkle.sampling_session_id', + 'twinkle.base_model', + } + assert set(CORRELATION_KEYS) == expected + + +# ---------- Property 22: attachment of present-only values (R11.1, R11.2) - # + + +class _RecordingSpan: + + def __init__(self) -> None: + self.attrs: dict[str, object] = {} + + def set_attribute(self, key: str, value: object) -> None: + self.attrs[key] = value + + +@settings(max_examples=100) +@given( + payload=st.fixed_dictionaries( + {}, + optional={ + correlation.SESSION_ID: st.one_of(st.none(), st.text(min_size=1, max_size=8)), + correlation.MODEL_ID: st.one_of(st.none(), st.text(min_size=1, max_size=8)), + correlation.REPLICA_ID: st.one_of(st.none(), st.text(min_size=1, max_size=8)), + correlation.TOKEN_ID: st.one_of(st.none(), st.text(min_size=1, max_size=8)), + }, + )) +def test_property_22_set_correlation_attrs_skips_none(payload: dict) -> None: + span = _RecordingSpan() + set_correlation_attrs(span, payload) + expected = {k: v for k, v in payload.items() if v is not None} + assert span.attrs == expected + + +def test_property_22_noop_span_safe() -> None: + """``set_correlation_attrs`` is a no-op on a NoOp span (no SDK installed).""" + span = _NoopSpan() + set_correlation_attrs(span, {correlation.SESSION_ID: 's1'}) + # NoOp span has no recording surface — passing None / empty mapping is also safe. + set_correlation_attrs(None, {correlation.SESSION_ID: 's1'}) + set_correlation_attrs(span, None) + + +# ---------- Property 21: NoOp degradation equivalence (R10.5, R18.3) ------ # + + +def test_property_21_noop_yields_same_result_as_active() -> None: + """When OTEL is absent, ``traced_operation`` runs the body and returns the + body's result identically to when OTEL is active.""" + with mock.patch('twinkle.server.telemetry.tracing._OTEL_AVAILABLE', False): + with traced_operation('op') as span: + assert isinstance(span, _NoopSpan) + result = sum(range(5)) + assert result == 10 + + +def test_property_21_noop_propagates_exceptions() -> None: + """NoOp path still re-raises the original exception unchanged.""" + with mock.patch('twinkle.server.telemetry.tracing._OTEL_AVAILABLE', False): + with pytest.raises(RuntimeError, match='boom'): + with traced_operation('op'): + raise RuntimeError('boom') + + +# ---------- Property 19/20: span lifecycle + exception handling ----------- # + + +def _otel_available() -> bool: + try: + from opentelemetry import trace as _otel_trace # noqa: F401 + except Exception: + return False + return True + + +def test_property_19_span_lifecycle(in_memory_span_exporter) -> None: + """When OTEL is present, a span is started before and ended after the block.""" + in_memory_span_exporter.clear() + with mock.patch('twinkle.server.telemetry.tracing._OTEL_AVAILABLE', True): + with traced_operation('op.under.test', attrs={correlation.SESSION_ID: 's1'}): + pass + + spans = in_memory_span_exporter.get_finished_spans() + matches = [s for s in spans if s.name == 'op.under.test'] + assert matches + assert matches[-1].attributes.get(correlation.SESSION_ID) == 's1' + + +def test_property_20_exception_recorded_and_reraised(in_memory_span_exporter) -> None: + """Exception inside the block is recorded on the span and re-raised.""" + in_memory_span_exporter.clear() + with mock.patch('twinkle.server.telemetry.tracing._OTEL_AVAILABLE', True): + with pytest.raises(ValueError, match='boom'): + with traced_operation('op.exc'): + raise ValueError('boom') + + spans = [s for s in in_memory_span_exporter.get_finished_spans() if s.name == 'op.exc'] + assert spans, 'span was not exported' + span = spans[-1] + assert span.status.status_code.name == 'ERROR' + assert any('exception' in evt.name.lower() for evt in span.events) + + +# ---------- ResourceMetricsCollector wiring (R12.1, R12.2, R12.3, R18.4) -- # + + +def test_resource_metrics_collector_does_not_raise_without_pynvml() -> None: + """Collector starts cleanly even when pynvml/GPU is absent (R12.3).""" + from twinkle.server.telemetry import resource_metrics + + with mock.patch.object(resource_metrics, '_PYNVML_AVAILABLE', False): + collector = resource_metrics.ResourceMetricsCollector() + collector.maybe_start() # must not raise + # No GPU gauges registered when pynvml is absent. + assert all(not g.startswith('twinkle.gpu.') for g in collector.registered_gauges) + + +def test_resource_metrics_collector_registers_named_gauges_when_psutil_present() -> None: + from twinkle.server.telemetry import resource_metrics + + if not resource_metrics._PSUTIL_AVAILABLE: + pytest.skip('psutil not installed in test env') + collector = resource_metrics.ResourceMetricsCollector() + collector.maybe_start() + # System CPU + system memory + process memory always present when psutil is. + expected = { + 'twinkle.system.cpu.utilization', + 'twinkle.system.memory.usage_bytes', + 'twinkle.process.memory.usage_bytes', + } + assert expected.issubset(set(collector.registered_gauges)) + + +def test_resource_metrics_collector_idempotent() -> None: + from twinkle.server.telemetry import resource_metrics + + collector = resource_metrics.ResourceMetricsCollector() + collector.maybe_start() + pre = list(collector.registered_gauges) + collector.maybe_start() + assert collector.registered_gauges == pre + + +def test_worker_init_starts_collector() -> None: + """``ensure_telemetry_initialized`` calls the resource collector even when + telemetry is disabled — the collector silently records to a NoOp meter + in that case (R12.2).""" + from twinkle.server.telemetry import resource_metrics, worker_init + + # Force the worker_init guard to re-run and clear the global collector + # so we observe a fresh ``maybe_start`` call. + worker_init._worker_initialized = False + resource_metrics.reset_collector_for_tests() + + sentinel = mock.MagicMock() + sentinel.maybe_start = mock.MagicMock() + + with mock.patch.object(resource_metrics, 'get_collector', return_value=sentinel) as get_spy: + worker_init.ensure_telemetry_initialized() + + assert get_spy.call_count >= 1 + assert sentinel.maybe_start.call_count == 1 + + +def test_init_telemetry_attaches_handler_to_twinkle_logger() -> None: + """``init_telemetry`` must attach the OTLP ``LoggingHandler`` to BOTH + the root logger AND the ``twinkle`` logger. + + The ``twinkle.utils.logger`` module configures the ``twinkle`` namespace + with ``propagate=False`` and its own StreamHandler — so an OTLP handler + bound only to root would never see any ``twinkle.*`` log records, and + the entire server's logs would be invisible in Loki / OTLP backends. + """ + import logging + from opentelemetry import _logs as _otel_logs + from opentelemetry import metrics, trace + from opentelemetry.sdk._logs import LoggingHandler + from opentelemetry.util._once import Once + + from twinkle.server.telemetry import provider + + # Reset all OTel global guards so init_telemetry runs cleanly. + trace._TRACER_PROVIDER_SET_ONCE = Once() + trace._TRACER_PROVIDER = None + metrics._METER_PROVIDER_SET_ONCE = Once() + metrics._METER_PROVIDER = None + if hasattr(_otel_logs, '_LOGGER_PROVIDER_SET_ONCE'): + _otel_logs._LOGGER_PROVIDER_SET_ONCE = Once() + _otel_logs._LOGGER_PROVIDER = None + provider._initialized = False + + # Clear stale handlers that might be attached from prior tests. + for name in ('', 'twinkle'): + for h in list(logging.getLogger(name).handlers): + if isinstance(h, LoggingHandler): + logging.getLogger(name).removeHandler(h) + + try: + provider.init_telemetry( + provider.TelemetryConfig( + enabled=True, + debug=True, # debug=True → console exporter, no real OTLP needed + service_name='twinkle-server-test', + )) + root_handlers = [h for h in logging.getLogger().handlers if isinstance(h, LoggingHandler)] + twinkle_handlers = [h for h in logging.getLogger('twinkle').handlers if isinstance(h, LoggingHandler)] + assert len(root_handlers) == 1, root_handlers + assert len(twinkle_handlers) == 1, twinkle_handlers + assert root_handlers[0] is twinkle_handlers[0], ('root and twinkle should share the same handler instance') + finally: + provider.shutdown_telemetry() + # shutdown should detach from both + assert all(not isinstance(h, LoggingHandler) for name in ('', 'twinkle') + for h in logging.getLogger(name).handlers) + + +def test_pyproject_declares_telemetry_extras() -> None: + """``pyproject.toml`` declares ``psutil`` and ``pynvml`` as telemetry extras (R12.4).""" + from pathlib import Path + + repo_root = Path(__file__).resolve().parents[3] + text = (repo_root / 'pyproject.toml').read_text() + assert 'telemetry =' in text + assert 'psutil' in text + assert 'pynvml' in text diff --git a/tests/server/utils/__init__.py b/tests/server/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/utils/task_queue/__init__.py b/tests/server/utils/task_queue/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/utils/task_queue/test_config.py b/tests/server/utils/task_queue/test_config.py new file mode 100644 index 000000000..305210caf --- /dev/null +++ b/tests/server/utils/task_queue/test_config.py @@ -0,0 +1,172 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Property + unit tests for the Pydantic ``TaskQueueConfig`` (R9). + +Covers: +- # Feature: server-config-observability-refactor, Property 16: TaskQueueConfig constraint enforcement +- # Feature: server-config-observability-refactor, Property 17: from_dict equivalence +- # Feature: server-config-observability-refactor, Property 18: TaskQueueConfig defaulting +- Unit checks that the call sites in model/sampler/processor apps still construct the config + through ``TaskQueueConfig.from_dict``. +""" +from __future__ import annotations + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st +from pydantic import ValidationError + +from twinkle.server.utils.task_queue.config import TaskQueueConfig + +# ---------- defaults snapshot used by Property 18 (R9.8) ------------------- # + +DEFAULTS = { + 'rps_limit': 100.0, + 'tps_limit': 16000.0, + 'window_seconds': 1.0, + 'queue_timeout': 300.0, + 'token_cleanup_interval': 60.0, + 'max_input_tokens': 16000, +} + +# ---------- Property 16: constraint enforcement (R9.2-9.5, 9.7) ------------ # + +_CONSTRAINED_GE0_FLOATS = ['rps_limit', 'tps_limit', 'queue_timeout', 'token_cleanup_interval'] + + +@settings(max_examples=100) +@given( + field=st.sampled_from(_CONSTRAINED_GE0_FLOATS), + bad_value=st.floats(max_value=-1e-6, min_value=-1e6, allow_nan=False, allow_infinity=False), +) +def test_property_16_ge0_floats_reject_negative(field: str, bad_value: float) -> None: + """Non-negative float fields reject any negative input.""" + with pytest.raises(ValidationError) as exc: + TaskQueueConfig(**{field: bad_value}) + assert any(field in err['loc'] for err in exc.value.errors()) + + +@settings(max_examples=100) +@given(bad_value=st.floats(max_value=0.0, min_value=-1e6, allow_nan=False, allow_infinity=False)) +def test_property_16_window_seconds_rejects_zero_and_negative(bad_value: float) -> None: + """``window_seconds`` must be strictly > 0.""" + with pytest.raises(ValidationError) as exc: + TaskQueueConfig(window_seconds=bad_value) + assert any('window_seconds' in err['loc'] for err in exc.value.errors()) + + +@settings(max_examples=100) +@given(bad_value=st.integers(max_value=0, min_value=-1_000_000)) +def test_property_16_max_input_tokens_rejects_lt_1(bad_value: int) -> None: + """``max_input_tokens`` must be an integer ≥ 1.""" + with pytest.raises(ValidationError) as exc: + TaskQueueConfig(max_input_tokens=bad_value) + assert any('max_input_tokens' in err['loc'] for err in exc.value.errors()) + + +@settings(max_examples=100) +@given( + rps=st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + tps=st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + win=st.floats(min_value=1e-6, max_value=1e6, allow_nan=False, allow_infinity=False), + qt=st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + cleanup=st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + mit=st.integers(min_value=1, max_value=10_000_000), +) +def test_property_16_valid_values_accepted(rps: float, tps: float, win: float, qt: float, cleanup: float, + mit: int) -> None: + """Any value satisfying the constraints constructs successfully.""" + cfg = TaskQueueConfig( + rps_limit=rps, + tps_limit=tps, + window_seconds=win, + queue_timeout=qt, + token_cleanup_interval=cleanup, + max_input_tokens=mit, + ) + assert cfg.rps_limit == rps + assert cfg.window_seconds == win + assert cfg.max_input_tokens == mit + + +# ---------- Property 17: from_dict equivalence (R9.6) ---------------------- # + +_INPUT_DICT_STRATEGY = st.fixed_dictionaries( + {}, + optional={ + 'rps_limit': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'tps_limit': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'window_seconds': st.floats(min_value=1e-6, max_value=1e6, allow_nan=False, allow_infinity=False), + 'queue_timeout': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'token_cleanup_interval': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'max_input_tokens': st.integers(min_value=1, max_value=10_000_000), + 'enabled': st.booleans(), + 'execution_timeout': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + 'token_cleanup_multiplier': st.floats(min_value=0.0, max_value=1e6, allow_nan=False, allow_infinity=False), + }, +) + + +@settings(max_examples=100) +@given(payload=_INPUT_DICT_STRATEGY) +def test_property_17_from_dict_equivalence(payload: dict) -> None: + """``from_dict`` returns the same instance as ``model_validate``.""" + via_factory = TaskQueueConfig.from_dict(payload) + via_validate = TaskQueueConfig.model_validate(payload) + assert via_factory.model_dump() == via_validate.model_dump() + + +# ---------- Property 18: defaulting (R9.8) --------------------------------- # + + +def test_property_18_from_dict_with_no_argument() -> None: + """``from_dict()`` with no argument returns the documented defaults.""" + cfg = TaskQueueConfig.from_dict() + for field, value in DEFAULTS.items(): + assert getattr(cfg, field) == value, field + + +def test_property_18_from_dict_with_none() -> None: + cfg = TaskQueueConfig.from_dict(None) + for field, value in DEFAULTS.items(): + assert getattr(cfg, field) == value, field + + +def test_property_18_from_dict_with_empty_dict() -> None: + cfg = TaskQueueConfig.from_dict({}) + for field, value in DEFAULTS.items(): + assert getattr(cfg, field) == value, field + + +@settings(max_examples=100) +@given(present=st.sets(st.sampled_from(sorted(DEFAULTS.keys())), max_size=len(DEFAULTS))) +def test_property_18_omitted_fields_take_defaults(present: set) -> None: + """Fields absent from the dict adopt their documented defaults.""" + payload = {f: DEFAULTS[f] for f in present} + cfg = TaskQueueConfig.from_dict(payload) + for field, value in DEFAULTS.items(): + assert getattr(cfg, field) == value, field + + +# ---------- Unit: extra=forbid + call-site usage --------------------------- # + + +def test_extra_field_rejected() -> None: + """``extra='forbid'`` rejects unknown keys (defends R8.2 scoped to this model).""" + with pytest.raises(ValidationError): + TaskQueueConfig.from_dict({'unknown_field': 1}) + + +def test_call_site_imports_resolve() -> None: + """``TaskQueueConfig.from_dict`` is what the apps call — keep that import alive. + + We don't instantiate the deployments (those need Ray Serve runtime); we just + confirm the call sites import the same name and the factory still produces + valid configs for the dicts they pass. + """ + from twinkle.server.utils.task_queue import TaskQueueConfig as Exported + + assert Exported is TaskQueueConfig + # Mimic the queue_config dicts shipped in cookbook/client/server YAMLs. + cfg = TaskQueueConfig.from_dict({'rps_limit': 100, 'tps_limit': 100000}) + assert cfg.rps_limit == 100.0 + assert cfg.tps_limit == 100000.0