From 0df937554cefa8fbe55aa5d647c0825e2133182f Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Thu, 25 Jun 2026 19:55:51 +0545 Subject: [PATCH 1/3] Agent Harness MVP --- mkdocs.yml | 1 + mkdocs/docs/guides/endpoint-harness.md | 183 ++++++++ skills/dstack/SKILL.md | 9 +- src/dstack/_internal/cli/commands/endpoint.py | 147 ++++++ src/dstack/_internal/cli/main.py | 2 + .../cli/services/configurators/run.py | 4 + src/dstack/_internal/harness/__init__.py | 20 + src/dstack/_internal/harness/deployer.py | 437 ++++++++++++++++++ src/dstack/_internal/harness/generator.py | 191 ++++++++ src/dstack/_internal/harness/llm.py | 173 +++++++ src/dstack/_internal/harness/models.py | 59 +++ src/dstack/_internal/harness/skill.py | 31 ++ .../_internal/cli/commands/test_endpoint.py | 21 + src/tests/_internal/harness/test_deployer.py | 302 ++++++++++++ src/tests/_internal/harness/test_generator.py | 253 ++++++++++ src/tests/_internal/harness/test_llm.py | 93 ++++ 16 files changed, 1925 insertions(+), 1 deletion(-) create mode 100644 mkdocs/docs/guides/endpoint-harness.md create mode 100644 src/dstack/_internal/cli/commands/endpoint.py create mode 100644 src/dstack/_internal/harness/__init__.py create mode 100644 src/dstack/_internal/harness/deployer.py create mode 100644 src/dstack/_internal/harness/generator.py create mode 100644 src/dstack/_internal/harness/llm.py create mode 100644 src/dstack/_internal/harness/models.py create mode 100644 src/dstack/_internal/harness/skill.py create mode 100644 src/tests/_internal/cli/commands/test_endpoint.py create mode 100644 src/tests/_internal/harness/test_deployer.py create mode 100644 src/tests/_internal/harness/test_generator.py create mode 100644 src/tests/_internal/harness/test_llm.py diff --git a/mkdocs.yml b/mkdocs.yml index 2fc74935aa..894318ce24 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -305,6 +305,7 @@ nav: - Exports: docs/concepts/exports.md - Guides: - CLI & API: docs/guides/cli-api.md + - Endpoint harness: docs/guides/endpoint-harness.md - Server deployment: docs/guides/server-deployment.md - Troubleshooting: docs/guides/troubleshooting.md - More: diff --git a/mkdocs/docs/guides/endpoint-harness.md b/mkdocs/docs/guides/endpoint-harness.md new file mode 100644 index 0000000000..a7e3b6df2e --- /dev/null +++ b/mkdocs/docs/guides/endpoint-harness.md @@ -0,0 +1,183 @@ +--- +title: Endpoint harness +description: Deploy inference endpoints with dstack endpoint create and the agent harness +--- + +# Endpoint harness + +The endpoint harness powers `dstack endpoint create`. +It uses an LLM to generate a [`type: service`](../concepts/services.md) configuration, +then deploys it through the same code path as [`dstack apply`](../reference/cli/dstack/apply.md). + +You describe what to deploy (model, GPU, backends, and other profile options). The harness: + +1. Asks an LLM to produce a service YAML (including container `commands`) +2. Validates and saves the configuration +3. Submits the run via dstack +4. Monitors logs and, on failure, may ask the LLM to fix the config and redeploy + +The harness does **not** pick cloud offers or provision instances. dstack's scheduler +does that after submission, the same way it does for a hand-written service config. + +??? info "Prerequisites" + - [dstack server and CLI](../installation.md) configured for your project + - At least one [fleet](../concepts/fleets.md) + - `DSTACK_HARNESS_API_KEY` set (see [LLM configuration](#llm-configuration)) + - [`skills/dstack/SKILL.md`](https://github.com/dstackai/dstack/blob/master/skills/dstack/SKILL.md) + in the project, or pass `--skill-path` + +## Quick start + +
+ +```shell +$ export DSTACK_HARNESS_API_KEY=sk-ant-... +$ export DSTACK_HARNESS_MODEL=claude-sonnet-4-8 +$ dstack endpoint create \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --gpu 24GB \ + --max-attempts 3 \ + -y +``` + +
+ +`DSTACK_HARNESS_MODEL` is optional. If unset, the harness defaults to `claude-sonnet-4-6` +for Anthropic. + +!!! note "`--max-attempts`" + Controls how many times the harness tries to deploy the endpoint. If the container + fails to start, it stops the run, asks the LLM to fix the configuration from the + error logs, and redeploys. Default is `3`. Set `--max-attempts 1` for a single + attempt with no retries. + +The command accepts the same resource and profile flags as [`dstack apply`](../reference/cli/dstack/apply.md) +for services (`--gpu`, `--cpu`, `--memory`, `--disk`, `--backend`, `--region`, `--fleet`, +`--max-price`, `--spot-policy`, and others). Run `dstack endpoint create --help` for the full list. + +## How it works + +```mermaid +flowchart TD + A[dstack endpoint create] --> B[Build EndpointCreateParams from CLI] + B --> C["LLM: generate service YAML"] + C --> D[Validate with parse_apply_configuration] + D --> E[Apply CLI overrides via ServiceConfigurator] + E --> F["Save to .dstack-harness-configs/"] + F --> I[ServiceConfigurator.apply_configuration] + I --> M[Monitor container logs] + M --> N{Ready?} + N -->|yes| O[Print service URLs] + N -->|failed| P[Stop run] + P --> Q["LLM: fix YAML from error logs"] + Q --> R{attempts left?} + R -->|yes| I + R -->|no| S[Give up] +``` + +Orchestration is **programmatic** (Python via `ServiceConfigurator`), not LLM-generated +`dstack` shell commands. The LLM only authors the service configuration and container +`commands` that run on the GPU instance. + + +## Relationship to `skills/dstack/SKILL.md` + +On every LLM call, the harness loads `skills/dstack/SKILL.md` and appends it to the system +prompt. + +## Prompts Send to LLM + +### Call 1: Generate configuration + +Fixed prefix: + +``` +You generate dstack service configuration files for model inference endpoints. + +Rules: +- Output a single valid YAML document for `type: service` +- Do not wrap the YAML in markdown unless you also include the YAML body in a fenced block +- Use only documented dstack service fields +- Put secret values only as env var names in `env`, never inline values +- Include `model`, `port`, `commands`, and `resources.gpu` when possible +- Prefer `python: "3.12"` unless the user requests a custom image +- User-provided CLI options in the request are mandatory: use the exact GPU, backends, + regions, fleets, CPU, memory, disk, and other resource/profile values given +- Do not substitute different resource sizes or backends than those specified by the user +- Do not invent unsupported CLI flags or YAML properties + +Reference skill: + + +``` + + +CLI options: + +``` +Generate a dstack service configuration for an inference endpoint. +The user passed these CLI options. You MUST use them exactly in the YAML. Do not substitute different GPU memory, backends, regions, fleets, or other resource/profile values. +{ + "model": "meta-llama/Llama-3.1-8B-Instruct", + "name": "meta-llama-3-1-8b-instruct", + "gpu": "24GB" +} + +Return only the YAML configuration. +``` + + +### Call 2: Fix configuration + +Fixed prefix: + +``` +You fix dstack service configurations that failed to start on the GPU instance. + +You are given the previous configuration and the container error logs. Return a +corrected single YAML document for `type: service`. + +Rules: +- Change as little as possible to address the specific error in the logs +- Keep `model`, `name`, and `resources` unless the error requires changing them +- For vLLM KV-cache / out-of-memory errors, prefer adding serve flags such as + `--max-model-len` or `--gpu-memory-utilization` rather than changing the GPU +- Keep secret values as env var names only, never inline values +- Use only documented dstack fields and valid serving CLI flags +- Do not invent unsupported CLI flags or YAML properties + +Reference skill: + + +``` + + + +Error logs: + +```` +The following dstack service configuration failed to start: +```yaml +type: service +name: meta-llama-3-1-8b-instruct +model: meta-llama/Llama-3.1-8B-Instruct +python: "3.12" +port: 8000 +commands: + - | + pip install vllm + vllm serve meta-llama/Llama-3.1-8B-Instruct \ + --host 0.0.0.0 \ + --port 8000 +resources: + gpu: 24GB:1 +``` + +Container error logs (tail): +``` +... +torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate ... +``` + +Return only the corrected YAML configuration. +```` diff --git a/skills/dstack/SKILL.md b/skills/dstack/SKILL.md index 4df1843b75..b434ea6c86 100644 --- a/skills/dstack/SKILL.md +++ b/skills/dstack/SKILL.md @@ -237,10 +237,17 @@ port: 8000 model: meta-llama/Meta-Llama-3.1-8B-Instruct resources: - gpu: 80GB + gpu: 24GB disk: 200GB ``` +**GPU sizing rule** + +If `--gpu` is not provided: +- 7B/8B models -> `gpu: 24GB` +- 13B/14B models -> `gpu: 40GB` or `48GB` +- 30B+ models -> `gpu: 80GB` + **Service endpoints:** - Without gateway: `/proxy/services///` - With gateway: `https://./` diff --git a/src/dstack/_internal/cli/commands/endpoint.py b/src/dstack/_internal/cli/commands/endpoint.py new file mode 100644 index 0000000000..758a3cdac0 --- /dev/null +++ b/src/dstack/_internal/cli/commands/endpoint.py @@ -0,0 +1,147 @@ +import argparse +import shlex +from typing import cast + +from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.configurators.run import ServiceConfigurator +from dstack._internal.cli.utils.common import console +from dstack._internal.core.errors import CLIError +from dstack._internal.core.models.configurations import TaskConfiguration +from dstack._internal.harness import ( + EndpointCreateParams, + deploy_service_configuration, + deploy_service_with_self_healing, +) +from dstack._internal.harness.generator import ( + generate_service_configuration, + save_service_configuration, +) + + +class EndpointCommand(APIBaseCommand): + NAME = "endpoint" + DESCRIPTION = "Manage inference endpoints" + ACCEPT_EXTRA_ARGS = True + + def _register(self): + super()._register() + self._parser.set_defaults(subfunc=self._print_help) + subparsers = self._parser.add_subparsers(dest="action") + + create_parser = subparsers.add_parser( + "create", + help="Create an inference endpoint", + formatter_class=self._parser.formatter_class, + ) + create_parser.add_argument( + "--model", + required=True, + metavar="NAME", + help="The model to deploy", + ) + create_parser.add_argument( + "--skill-path", + metavar="PATH", + help="Path to [code]skills/dstack/SKILL.md[/]. Defaults to project skill file", + ) + create_parser.add_argument( + "--dry-run", + action="store_true", + help="Generate and save the configuration without deploying", + ) + create_parser.add_argument( + "-y", + "--yes", + help="Do not ask for confirmation", + action="store_true", + ) + create_parser.add_argument( + "-d", + "--detach", + help="Exit immediately after submitting instead of streaming container logs", + action="store_true", + ) + create_parser.add_argument( + "-v", + "--verbose", + help="Show all plan properties including those with default values", + action="store_true", + ) + create_parser.add_argument( + "--force", + help="Force apply when no changes detected", + action="store_true", + ) + create_parser.add_argument( + "--max-attempts", + type=int, + default=3, + metavar="N", + help=( + "Max deploy attempts. On container failure, the harness stops the run," + " asks the model to fix the configuration from the error logs, and redeploys." + " Set to 1 to disable self-healing" + ), + ) + ServiceConfigurator.register_args(create_parser) + create_parser.set_defaults(subfunc=self._create) + + def _command(self, args: argparse.Namespace): + super()._command(args) + args.subfunc(args) + + def _print_help(self, args: argparse.Namespace): + self._parser.print_help() + + def _create(self, args: argparse.Namespace): + configurator_parser = ServiceConfigurator.get_parser() + _, unknown_args = configurator_parser.parse_known_args(args.extra_args) + if unknown_args: + raise CLIError(f"Unrecognized arguments: {shlex.join(unknown_args)}") + + params = EndpointCreateParams.from_namespace(args, model=args.model) + + with console.status("Generating service configuration..."): + configuration = generate_service_configuration( + params=params, + skill_path=args.skill_path, + ) + + configurator = ServiceConfigurator(api_client=self.api) + configurator.apply_args(cast(TaskConfiguration, configuration), args) + configuration.model = args.model + + config_path = save_service_configuration(configuration) + console.print(f"Saved configuration to [code]{config_path}[/]") + + if args.dry_run: + console.print("Dry run complete. Skipping deployment.") + return + + apply_args = argparse.Namespace( + yes=args.yes, + detach=args.detach, + verbose=args.verbose, + force=args.force, + ) + + if args.detach: + deploy_service_configuration( + api=self.api, + configuration=configuration, + configuration_path=config_path, + command_args=apply_args, + configurator_args=args, + ) + return + + deploy_service_with_self_healing( + api=self.api, + configuration=configuration, + params=params, + configuration_path=config_path, + command_args=apply_args, + configurator_args=args, + skill_path=args.skill_path, + max_attempts=args.max_attempts, + ) diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index 32f15a95f8..335f7693f1 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -8,6 +8,7 @@ from dstack._internal.cli.commands.attach import AttachCommand from dstack._internal.cli.commands.completion import CompletionCommand from dstack._internal.cli.commands.delete import DeleteCommand +from dstack._internal.cli.commands.endpoint import EndpointCommand from dstack._internal.cli.commands.event import EventCommand from dstack._internal.cli.commands.export import ExportCommand from dstack._internal.cli.commands.fleet import FleetCommand @@ -67,6 +68,7 @@ def main(): ApplyCommand.register(subparsers) AttachCommand.register(subparsers) DeleteCommand.register(subparsers) + EndpointCommand.register(subparsers) EventCommand.register(subparsers) ExportCommand.register(subparsers) FleetCommand.register(subparsers) diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 16b0f0a87b..adf2f14265 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -203,6 +203,10 @@ def apply_configuration( console.print(detach_message) return + pre_attach_hook = getattr(command_args, "pre_attach_hook", None) + if pre_attach_hook is not None: + pre_attach_hook(run.name) + abort_at_exit = False try: # We can attach to run multiple times if it goes from running to pending (retried). diff --git a/src/dstack/_internal/harness/__init__.py b/src/dstack/_internal/harness/__init__.py new file mode 100644 index 0000000000..3f069be604 --- /dev/null +++ b/src/dstack/_internal/harness/__init__.py @@ -0,0 +1,20 @@ +from dstack._internal.harness.deployer import ( + deploy_service_configuration, + deploy_service_with_self_healing, +) +from dstack._internal.harness.generator import ( + generate_service_configuration, + regenerate_service_configuration, + save_service_configuration, +) +from dstack._internal.harness.models import EndpointCreateParams, default_endpoint_name + +__all__ = [ + "EndpointCreateParams", + "default_endpoint_name", + "deploy_service_configuration", + "deploy_service_with_self_healing", + "generate_service_configuration", + "regenerate_service_configuration", + "save_service_configuration", +] diff --git a/src/dstack/_internal/harness/deployer.py b/src/dstack/_internal/harness/deployer.py new file mode 100644 index 0000000000..6d814fc7f8 --- /dev/null +++ b/src/dstack/_internal/harness/deployer.py @@ -0,0 +1,437 @@ +import argparse +import base64 +import sys +import time +from collections import deque +from datetime import timedelta +from enum import Enum +from pathlib import Path +from typing import Optional, cast + +from dstack._internal.cli.services.configurators.run import ( + ServiceConfigurator, + _print_service_urls, +) +from dstack._internal.cli.utils.common import confirm_ask, console +from dstack._internal.core.models.configurations import ServiceConfiguration, TaskConfiguration +from dstack._internal.core.models.runs import JobStatus, RunStatus +from dstack._internal.harness.generator import ( + regenerate_service_configuration, + save_service_configuration, +) +from dstack._internal.harness.models import EndpointCreateParams +from dstack._internal.server.schemas.logs import PollLogsRequest +from dstack.api._public import Client + +# Markers that indicate the model server finished starting and is serving. +READY_PATTERNS = ( + "Application startup complete", + "Uvicorn running on", + "The server is fired up and ready to roll", # SGLang + "Connected", # TGI +) +# Markers that indicate a fatal container error before the server became ready. +FATAL_PATTERNS = ( + "Traceback (most recent call last)", + "Engine core initialization failed", + "EngineCore failed to start", + "RuntimeError", + "CUDA out of memory", + "torch.cuda.OutOfMemoryError", +) + +DEFAULT_MONITOR_TIMEOUT_SECS = 1800 +MONITOR_POLL_INTERVAL_SECS = 3 +MAX_ERROR_LOG_CHARS = 6000 + + +class _Outcome(Enum): + SUCCESS = "success" + FAILED = "failed" + TIMEOUT = "timeout" + + +def deploy_service_configuration( + api: Client, + configuration: ServiceConfiguration, + configuration_path: Path, + command_args: argparse.Namespace, + configurator_args: argparse.Namespace, +) -> None: + """Submit a service detached and report status (no self-healing).""" + _submit_detached(api, configuration, configuration_path, command_args, configurator_args) + if configuration.name: + _wait_for_service_and_report(api, configuration.name) + + +def deploy_service_with_self_healing( + api: Client, + configuration: ServiceConfiguration, + params: EndpointCreateParams, + configuration_path: Path, + command_args: argparse.Namespace, + configurator_args: argparse.Namespace, + skill_path: Optional[str] = None, + max_attempts: int = 3, + monitor_timeout_secs: int = DEFAULT_MONITOR_TIMEOUT_SECS, +) -> None: + """Submit, monitor logs, and on failure stop, ask the LLM to fix, and redeploy.""" + config_path = configuration_path + try: + _run_self_healing_loop( + api=api, + configuration=configuration, + params=params, + config_path=config_path, + command_args=command_args, + configurator_args=configurator_args, + skill_path=skill_path, + max_attempts=max_attempts, + monitor_timeout_secs=monitor_timeout_secs, + ) + except KeyboardInterrupt: + _handle_detach_on_interrupt(api, configuration.name, command_args.yes) + + +def _print_monitoring_message(run_name: str) -> None: + console.print(f"[code]\\[harness][/] Monitoring logs for [code]{run_name}[/]...") + + +def _run_self_healing_loop( + api: Client, + configuration: ServiceConfiguration, + params: EndpointCreateParams, + config_path: Path, + command_args: argparse.Namespace, + configurator_args: argparse.Namespace, + skill_path: Optional[str], + max_attempts: int, + monitor_timeout_secs: int, +) -> None: + detach = command_args.detach + for attempt in range(1, max_attempts + 1): + if attempt > 1: + console.print( + f"\n[code]\\[harness][/] Attempt {attempt}/{max_attempts}:" + " redeploying updated configuration..." + ) + + if not configuration.name: + return + + # Auto-confirm on retries; the user already approved the first plan. + submit_args = argparse.Namespace( + yes=command_args.yes or attempt > 1, + detach=detach, + verbose=command_args.verbose, + force=command_args.force, + pre_attach_hook=_print_monitoring_message if not detach else None, + ) + token_before = _submission_token(api, configuration.name) + configurator = ServiceConfigurator(api_client=api) + try: + configurator.apply_configuration( + conf=configuration, + configuration_path=str(config_path), + command_args=submit_args, + configurator_args=configurator_args, + ) + except SystemExit as e: + if detach or e.code in (0, None): + raise + if attempt == max_attempts: + console.print( + f"[code]\\[harness][/] Reached the maximum of {max_attempts} attempts." + " Giving up. See the logs above for the last error." + ) + raise + console.print( + f"[code]\\[harness][/] Detected a failure on attempt {attempt}." + f" Stopping run [code]{configuration.name}[/]..." + ) + error_logs = _fetch_recent_logs(api, configuration.name) + # Stop the failed run so the next attempt is a clean fresh deployment. + # Otherwise dstack treats the redeploy as an in-place rolling update of + # the still-active service, which breaks the attached log stream. + _stop_run(api, configuration.name) + console.print( + "[code]\\[harness][/] Asking the model to fix the configuration based on the error..." + ) + configuration, config_path = _regenerate_configuration( + api=api, + configuration=configuration, + params=params, + config_path=config_path, + configurator_args=configurator_args, + error_logs=error_logs, + skill_path=skill_path, + ) + continue + + # If no new submission was created, the user declined the plan (or there + # was nothing to apply). That is a clean exit, not a deployment failure. + token_after = _submission_token(api, configuration.name) + if token_after is None or token_after == token_before: + return + + if detach: + console.print( + f"[code]\\[harness][/] Monitoring logs for [code]{configuration.name}[/]..." + ) + outcome, error_logs = _monitor_run(api, configuration.name, monitor_timeout_secs) + else: + # Attached apply streams logs via dstack until the user detaches or the run ends. + return + + if outcome is _Outcome.SUCCESS: + run = api.runs.get(configuration.name) + if run is not None: + run.refresh() + _print_service_urls(run) + console.print( + f"[code]\\[harness][/] Endpoint [code]{configuration.name}[/] is up and serving." + ) + return + + if outcome is _Outcome.TIMEOUT: + console.print( + f"[code]\\[harness][/] Timed out waiting for [code]{configuration.name}[/]" + " to become ready. Stopping the run." + ) + _stop_run(api, configuration.name) + return + + console.print( + f"[code]\\[harness][/] Detected a failure on attempt {attempt}." + f" Stopping run [code]{configuration.name}[/]..." + ) + _stop_run(api, configuration.name) + + if attempt == max_attempts: + console.print( + f"[code]\\[harness][/] Reached the maximum of {max_attempts} attempts." + " Giving up. See the logs above for the last error." + ) + return + + console.print( + "[code]\\[harness][/] Asking the model to fix the configuration based on the error..." + ) + configuration, config_path = _regenerate_configuration( + api=api, + configuration=configuration, + params=params, + config_path=config_path, + configurator_args=configurator_args, + error_logs=error_logs, + skill_path=skill_path, + ) + + +def _regenerate_configuration( + api: Client, + configuration: ServiceConfiguration, + params: EndpointCreateParams, + config_path: Path, + configurator_args: argparse.Namespace, + error_logs: str, + skill_path: Optional[str], +) -> tuple[ServiceConfiguration, Path]: + previous_yaml = config_path.read_text(encoding="utf-8") + configuration = regenerate_service_configuration( + params=params, + previous_yaml=previous_yaml, + error_logs=error_logs, + skill_path=skill_path, + ) + ServiceConfigurator(api_client=api).apply_args( + cast(TaskConfiguration, configuration), configurator_args + ) + config_path = save_service_configuration(configuration) + console.print(f"[code]\\[harness][/] Saved updated configuration to [code]{config_path}[/]") + return configuration, config_path + + +def _fetch_recent_logs(api: Client, run_name: str) -> str: + run = api.runs.get(run_name) + if run is None: + return "" + try: + submission = run._run.jobs[0].job_submissions[-1] + except (AttributeError, IndexError): + return "" + events = _poll_new_logs(api, run_name, submission.id, None) + text = "".join(base64.b64decode(event.message).decode(errors="replace") for event in events) + return _truncate_logs(text) + + +def _truncate_logs(text: str) -> str: + if len(text) > MAX_ERROR_LOG_CHARS: + return text[-MAX_ERROR_LOG_CHARS:] + return text + + +def _handle_detach_on_interrupt(api: Client, run_name: Optional[str], yes: bool) -> None: + if not run_name: + return + run = api.runs.get(run_name) + if run is None or run.status.is_finished(): + return + try: + if yes or not confirm_ask(f"\nStop the run [code]{run_name}[/] before detaching?"): + console.print("Detached") + return + with console.status("Stopping..."): + api.client.runs.stop(api.project, [run_name], False) + while True: + run = api.runs.get(run_name) + if run is None or run.status.is_finished(): + break + time.sleep(2) + console.print("Stopped") + except KeyboardInterrupt: + with console.status("Aborting..."): + api.client.runs.stop(api.project, [run_name], True) + console.print("[error]Aborted[/]") + + +def _submit_detached( + api: Client, + configuration: ServiceConfiguration, + configuration_path: Path, + command_args: argparse.Namespace, + configurator_args: argparse.Namespace, +) -> None: + configurator = ServiceConfigurator(api_client=api) + submit_args = argparse.Namespace( + yes=command_args.yes, + detach=True, + verbose=command_args.verbose, + force=command_args.force, + ) + configurator.apply_configuration( + conf=configuration, + configuration_path=str(configuration_path), + command_args=submit_args, + configurator_args=configurator_args, + ) + + +def _submission_token(api: Client, run_name: str) -> Optional[tuple]: + """Identify the latest run submission so we can tell if a new run was submitted. + + Returns None when the run does not exist. When `apply` is declined or makes no + change, the token is unchanged; a fresh submission produces a different token. + """ + run = api.runs.get(run_name) + if run is None: + return None + run_id = getattr(run._run, "id", None) + try: + submission = run._run.jobs[0].job_submissions[-1] + except (AttributeError, IndexError): + return (str(run_id), None, None) + return (str(run_id), str(submission.id), submission.deployment_num) + + +def _monitor_run(api: Client, run_name: str, timeout_secs: int) -> tuple[_Outcome, str]: + deadline = time.monotonic() + timeout_secs + log_tail: deque[str] = deque(maxlen=300) + last_timestamp = None + ready_seen = False + fatal_seen = False + + while time.monotonic() < deadline: + run = api.runs.get(run_name) + if run is None: + return _Outcome.FAILED, _format_tail(log_tail) + + submission = run._run.jobs[0].job_submissions[-1] + events = _poll_new_logs(api, run_name, submission.id, last_timestamp) + for event in events: + text = base64.b64decode(event.message).decode(errors="replace") + sys.stdout.write(text) + sys.stdout.flush() + log_tail.append(text) + if any(pattern in text for pattern in READY_PATTERNS): + ready_seen = True + if any(pattern in text for pattern in FATAL_PATTERNS): + fatal_seen = True + last_timestamp = event.timestamp + + if run.status in (RunStatus.FAILED, RunStatus.TERMINATED) or ( + submission.status == JobStatus.FAILED + ): + return _Outcome.FAILED, _format_tail(log_tail) + + if ready_seen and submission.status == JobStatus.RUNNING: + return _Outcome.SUCCESS, _format_tail(log_tail) + + if fatal_seen and not ready_seen: + return _Outcome.FAILED, _format_tail(log_tail) + + time.sleep(MONITOR_POLL_INTERVAL_SECS) + + return _Outcome.TIMEOUT, _format_tail(log_tail) + + +def _poll_new_logs(api: Client, run_name: str, submission_id, start_time): + if start_time is not None: + start_time = start_time + timedelta(microseconds=1) + events = [] + next_token = None + while True: + resp = api.client.logs.poll( + project_name=api.project, + body=PollLogsRequest( + run_name=run_name, + job_submission_id=submission_id, + start_time=start_time, + end_time=None, + descending=False, + limit=1000, + diagnose=False, + next_token=next_token, + ), + ) + events.extend(resp.logs) + next_token = resp.next_token + if next_token is None: + break + return events + + +def _format_tail(log_tail: deque) -> str: + return _truncate_logs("".join(log_tail)) + + +def _stop_run(api: Client, run_name: str) -> None: + with console.status(f"Stopping {run_name}..."): + api.client.runs.stop(api.project, [run_name], False) + while True: + run = api.runs.get(run_name) + if run is None or run.status.is_finished(): + break + time.sleep(2) + console.print(f"[code]\\[harness][/] Stopped [code]{run_name}[/]") + + +def _wait_for_service_and_report(api: Client, run_name: str, attempts: int = 30) -> None: + for _ in range(attempts): + run = api.runs.get(run_name) + if run is None: + return + if run.status in (RunStatus.RUNNING, RunStatus.DONE, RunStatus.FAILED): + console.print() + run.refresh() + _print_service_urls(run) + if run.status == RunStatus.FAILED: + console.print( + f"[error]Run [code]{run_name}[/] failed. Check [code]dstack logs {run_name}[/]." + ) + return + time.sleep(2) + + console.print( + f"Run [code]{run_name}[/] is still provisioning. Check status with [code]dstack ps -v[/]." + ) diff --git a/src/dstack/_internal/harness/generator.py b/src/dstack/_internal/harness/generator.py new file mode 100644 index 0000000000..aff5ef2b36 --- /dev/null +++ b/src/dstack/_internal/harness/generator.py @@ -0,0 +1,191 @@ +import json +import re +from pathlib import Path +from typing import Optional + +import yaml + +from dstack._internal.core.errors import CLIError, ConfigurationError +from dstack._internal.core.models.configurations import ( + ApplyConfigurationType, + ServiceConfiguration, + parse_apply_configuration, +) +from dstack._internal.core.models.envs import Env +from dstack._internal.harness.llm import HarnessLLMClient +from dstack._internal.harness.models import EndpointCreateParams +from dstack._internal.harness.skill import load_skill_content + +HARNESS_CONFIGS_DIR = Path(".dstack-harness-configs") + +SYSTEM_PROMPT_PREFIX = """\ +You generate dstack service configuration files for model inference endpoints. + +Rules: +- Output a single valid YAML document for `type: service` +- Do not wrap the YAML in markdown unless you also include the YAML body in a fenced block +- Use only documented dstack service fields +- Put secret values only as env var names in `env`, never inline values +- Include `model`, `port`, `commands`, and `resources.gpu` when possible +- Prefer `python: "3.12"` unless the user requests a custom image +- User-provided CLI options in the request are mandatory: use the exact GPU, backends, + regions, fleets, CPU, memory, disk, and other resource/profile values given +- Do not substitute different resource sizes or backends than those specified by the user +- Do not invent unsupported CLI flags or YAML properties + +Reference skill: + +""" + +FIX_SYSTEM_PROMPT_PREFIX = """\ +You fix dstack service configurations that failed to start on the GPU instance. + +You are given the previous configuration and the container error logs. Return a +corrected single YAML document for `type: service`. + +Rules: +- Change as little as possible to address the specific error in the logs +- Keep `model`, `name`, and `resources` unless the error requires changing them +- For vLLM KV-cache / out-of-memory errors, prefer adding serve flags such as + `--max-model-len` or `--gpu-memory-utilization` rather than changing the GPU +- Keep secret values as env var names only, never inline values +- Use only documented dstack fields and valid serving CLI flags +- Do not invent unsupported CLI flags or YAML properties + +Reference skill: + +""" + + +def _extract_yaml(text: str) -> str: + fenced = re.search(r"```(?:ya?ml)?\s*\n(.*?)```", text, flags=re.DOTALL | re.IGNORECASE) + if fenced: + return fenced.group(1).strip() + + stripped = text.strip() + if stripped.startswith("type:") or stripped.startswith("name:"): + return stripped + + raise CLIError("Harness LLM response did not contain YAML configuration") + + +def _normalize_env_names(configuration: ServiceConfiguration) -> None: + configuration.env = Env.parse_obj(list(configuration.env)) + + +def _validate_service_configuration(configuration: ServiceConfiguration) -> ServiceConfiguration: + if configuration.type != ApplyConfigurationType.SERVICE.value: + raise CLIError("Generated configuration must have [code]type: service[/]") + if configuration.model is None: + raise CLIError("Generated configuration must include a [code]model[/] field") + _normalize_env_names(configuration) + return configuration + + +def parse_service_yaml(yaml_text: str) -> ServiceConfiguration: + try: + data = yaml.safe_load(yaml_text) + except yaml.YAMLError as e: + raise CLIError(f"Generated YAML is invalid: {e}") from e + if not isinstance(data, dict): + raise CLIError("Generated YAML must be a mapping") + + try: + configuration = parse_apply_configuration(data) + except ConfigurationError as e: + raise CLIError(f"Generated configuration is invalid: {e}") from e + + if not isinstance(configuration, ServiceConfiguration): + raise CLIError("Generated configuration must be a service configuration") + return _validate_service_configuration(configuration) + + +def build_user_prompt(params: EndpointCreateParams) -> str: + return ( + "Generate a dstack service configuration for an inference endpoint.\n" + "The user passed these CLI options. You MUST use them exactly in the YAML." + " Do not substitute different GPU memory, backends, regions, fleets," + " or other resource/profile values.\n" + f"{json.dumps(params.cli_options(), indent=2, default=str)}\n\n" + "Return only the YAML configuration." + ) + + +def build_fix_prompt(params: EndpointCreateParams, previous_yaml: str, error_logs: str) -> str: + return ( + "The following dstack service configuration failed to start:\n" + f"```yaml\n{previous_yaml}\n```\n\n" + "Container error logs (tail):\n" + f"```\n{error_logs}\n```\n\n" + "Return only the corrected YAML configuration." + ) + + +def _apply_param_overrides( + configuration: ServiceConfiguration, params: EndpointCreateParams +) -> None: + if params.name: + configuration.name = params.name + if params.model: + configuration.model = params.model + + +def generate_service_configuration( + params: EndpointCreateParams, + skill_path: Optional[str] = None, + llm_client: Optional[HarnessLLMClient] = None, +) -> ServiceConfiguration: + skill_content = load_skill_content(skill_path) + client = llm_client or HarnessLLMClient() + response = client.chat( + system_prompt=SYSTEM_PROMPT_PREFIX + skill_content, + user_prompt=build_user_prompt(params), + ) + configuration = parse_service_yaml(_extract_yaml(response)) + _apply_param_overrides(configuration, params) + return configuration + + +def regenerate_service_configuration( + params: EndpointCreateParams, + previous_yaml: str, + error_logs: str, + skill_path: Optional[str] = None, + llm_client: Optional[HarnessLLMClient] = None, +) -> ServiceConfiguration: + skill_content = load_skill_content(skill_path) + client = llm_client or HarnessLLMClient() + response = client.chat( + system_prompt=FIX_SYSTEM_PROMPT_PREFIX + skill_content, + user_prompt=build_fix_prompt(params, previous_yaml, error_logs), + ) + configuration = parse_service_yaml(_extract_yaml(response)) + _apply_param_overrides(configuration, params) + return configuration + + +def get_endpoint_path(name: str) -> Path: + HARNESS_CONFIGS_DIR.mkdir(parents=True, exist_ok=True) + return HARNESS_CONFIGS_DIR / f"{name}.dstack.yml" + + +def save_service_configuration(configuration: ServiceConfiguration) -> Path: + if not configuration.name: + raise CLIError("Generated configuration must include a [code]name[/]") + + config_path = get_endpoint_path(configuration.name) + # Round-trip through JSON so enums and other rich types become plain + # primitives that yaml.safe_dump can represent. + config_dict = json.loads(configuration.json(exclude_none=True)) + # Never persist secret values to disk: env is always written as names only, + # even if values were resolved from the environment earlier in the flow. + env_names = list(configuration.env) + if env_names: + config_dict["env"] = env_names + else: + config_dict.pop("env", None) + config_path.write_text( + yaml.safe_dump(config_dict, sort_keys=False), + encoding="utf-8", + ) + return config_path diff --git a/src/dstack/_internal/harness/llm.py b/src/dstack/_internal/harness/llm.py new file mode 100644 index 0000000000..9eafdd265d --- /dev/null +++ b/src/dstack/_internal/harness/llm.py @@ -0,0 +1,173 @@ +import os +from dataclasses import dataclass +from typing import Optional + +import requests + +from dstack._internal.cli.utils.common import console +from dstack._internal.core.errors import CLIError + +REQUEST_TIMEOUT_SECS = 120 +DEFAULT_MAX_TOKENS = 4096 + +PROVIDER_ANTHROPIC = "anthropic" +PROVIDER_OPENAI = "openai" + +DEFAULTS = { + PROVIDER_ANTHROPIC: { + "base_url": "https://api.anthropic.com/v1", + "model": "claude-sonnet-4-6", + }, + PROVIDER_OPENAI: { + "base_url": "https://api.openai.com/v1", + "model": "gpt-4o-mini", + }, +} +ANTHROPIC_VERSION = "2023-06-01" + + +@dataclass(frozen=True) +class LLMUsage: + input_tokens: int + output_tokens: int + + @property + def total_tokens(self) -> int: + return self.input_tokens + self.output_tokens + + +class HarnessLLMClient: + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model: Optional[str] = None, + provider: Optional[str] = None, + max_tokens: int = DEFAULT_MAX_TOKENS, + ): + self.api_key = api_key or os.getenv("DSTACK_HARNESS_API_KEY") + if not self.api_key: + raise CLIError( + "DSTACK_HARNESS_API_KEY is not set." + " Export it before running [code]dstack endpoint create[/]." + ) + self.provider = ( + provider or os.getenv("DSTACK_HARNESS_PROVIDER") or PROVIDER_ANTHROPIC + ).lower() + if self.provider not in DEFAULTS: + raise CLIError( + f"Unsupported harness provider [code]{self.provider}[/]." + f" Supported: {', '.join(DEFAULTS)}." + ) + defaults = DEFAULTS[self.provider] + self.base_url = ( + base_url or os.getenv("DSTACK_HARNESS_BASE_URL") or defaults["base_url"] + ).rstrip("/") + self.model = model or os.getenv("DSTACK_HARNESS_MODEL") or defaults["model"] + self.max_tokens = max_tokens + + def chat(self, system_prompt: str, user_prompt: str) -> str: + if self.provider == PROVIDER_ANTHROPIC: + return self._chat_anthropic(system_prompt, user_prompt) + return self._chat_openai(system_prompt, user_prompt) + + def _chat_anthropic(self, system_prompt: str, user_prompt: str) -> str: + url = f"{self.base_url}/messages" + payload = { + "model": self.model, + "max_tokens": self.max_tokens, + "system": system_prompt, + "messages": [{"role": "user", "content": user_prompt}], + } + headers = { + "x-api-key": self.api_key, + "anthropic-version": ANTHROPIC_VERSION, + "content-type": "application/json", + } + data = self._post(url, payload, headers) + self._print_usage(_parse_anthropic_usage(data)) + try: + return data["content"][0]["text"] + except (KeyError, IndexError, TypeError) as e: + raise CLIError(f"Unexpected harness LLM response: {data}") from e + + def _chat_openai(self, system_prompt: str, user_prompt: str) -> str: + url = f"{self.base_url}/chat/completions" + payload = { + "model": self.model, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + "temperature": 0, + } + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + data = self._post(url, payload, headers) + self._print_usage(_parse_openai_usage(data)) + try: + return data["choices"][0]["message"]["content"] + except (KeyError, IndexError, TypeError) as e: + raise CLIError(f"Unexpected harness LLM response: {data}") from e + + def _print_usage(self, usage: Optional[LLMUsage]) -> None: + provider_label = self.provider.capitalize() + console.print(f"[code]\\[harness][/] LLM Provider: {provider_label}") + console.print(f"[code]\\[harness][/] LLM Model: {self.model}") + if usage is None: + return + console.print( + "[code]\\[harness][/] LLM tokens:" + f" input={usage.input_tokens}," + f" output={usage.output_tokens}," + f" total={usage.total_tokens}" + ) + + def _post(self, url: str, payload: dict, headers: dict) -> dict: + try: + response = requests.post( + url, + json=payload, + headers=headers, + timeout=REQUEST_TIMEOUT_SECS, + ) + except requests.RequestException as e: + raise CLIError(f"Failed to call harness LLM: {e}") from e + + if response.status_code >= 400: + raise CLIError( + f"Harness LLM request failed with status {response.status_code}: {response.text}" + ) + + try: + return response.json() + except ValueError as e: + raise CLIError(f"Unexpected harness LLM response: {response.text}") from e + + +def _parse_anthropic_usage(data: dict) -> Optional[LLMUsage]: + usage = data.get("usage") + if not isinstance(usage, dict): + return None + try: + return LLMUsage( + input_tokens=int(usage["input_tokens"]), + output_tokens=int(usage["output_tokens"]), + ) + except (KeyError, TypeError, ValueError): + return None + + +def _parse_openai_usage(data: dict) -> Optional[LLMUsage]: + usage = data.get("usage") + if not isinstance(usage, dict): + return None + try: + return LLMUsage( + input_tokens=int(usage["prompt_tokens"]), + output_tokens=int(usage["completion_tokens"]), + ) + except (KeyError, TypeError, ValueError): + return None diff --git a/src/dstack/_internal/harness/models.py b/src/dstack/_internal/harness/models.py new file mode 100644 index 0000000000..0bb70815d8 --- /dev/null +++ b/src/dstack/_internal/harness/models.py @@ -0,0 +1,59 @@ +import argparse +import re +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, List, Optional + +from dstack._internal.core.errors import CLIError + + +def default_endpoint_name(model: str) -> str: + """Derive a stable run name from a model id (e.g. org/model -> model, normalized).""" + base = model.rsplit("/", 1)[-1] + name = base.lower().replace(".", "-") + name = re.sub(r"[^a-z0-9-]+", "-", name) + name = re.sub(r"-+", "-", name).strip("-") + if not name: + raise CLIError(f"Cannot derive an endpoint name from model [code]{model}[/]") + return name + + +@dataclass +class EndpointCreateParams: + model: str + name: Optional[str] = None + gpu: Optional[Any] = None + cpu: Optional[Any] = None + memory: Optional[Any] = None + disk: Optional[Any] = None + backends: List[str] = field(default_factory=list) + regions: List[str] = field(default_factory=list) + instance_types: List[str] = field(default_factory=list) + fleets: List[str] = field(default_factory=list) + max_price: Optional[float] = None + max_duration: Optional[int] = None + spot_policy: Optional[str] = None + env_vars: List[str] = field(default_factory=list) + + @classmethod + def from_namespace(cls, args: argparse.Namespace, model: str) -> "EndpointCreateParams": + spot_policy = getattr(args, "spot_policy", None) + return cls( + model=model, + name=getattr(args, "run_name", None) or default_endpoint_name(model), + gpu=getattr(args, "gpu_spec", None), + cpu=getattr(args, "cpu_spec", None), + memory=getattr(args, "memory_spec", None), + disk=getattr(args, "disk_spec", None), + backends=getattr(args, "backends", None) or [], + regions=getattr(args, "regions", None) or [], + instance_types=getattr(args, "instance_types", None) or [], + fleets=getattr(args, "fleets", None) or [], + max_price=getattr(args, "max_price", None), + max_duration=getattr(args, "max_duration", None), + spot_policy=spot_policy.value if spot_policy is not None else None, + env_vars=[item.key for item in getattr(args, "env_vars", [])], + ) + + def cli_options(self) -> Dict[str, Any]: + options = asdict(self) + return {key: value for key, value in options.items() if value not in (None, [], {})} diff --git a/src/dstack/_internal/harness/skill.py b/src/dstack/_internal/harness/skill.py new file mode 100644 index 0000000000..2e16533dd7 --- /dev/null +++ b/src/dstack/_internal/harness/skill.py @@ -0,0 +1,31 @@ +from pathlib import Path +from typing import Optional + +from dstack._internal.core.errors import CLIError + +DEFAULT_SKILL_RELATIVE_PATH = Path("skills/dstack/SKILL.md") + + +def find_skill_path(skill_path: Optional[str] = None) -> Path: + if skill_path is not None: + path = Path(skill_path) + if not path.is_file(): + raise CLIError(f"Skill file not found: {skill_path}") + return path + + candidates = [ + Path.cwd() / DEFAULT_SKILL_RELATIVE_PATH, + Path(__file__).resolve().parents[4] / DEFAULT_SKILL_RELATIVE_PATH, + ] + for path in candidates: + if path.is_file(): + return path + + raise CLIError( + "dstack skill not found. Expected " + f"[code]{DEFAULT_SKILL_RELATIVE_PATH}[/] in the current directory." + ) + + +def load_skill_content(skill_path: Optional[str] = None) -> str: + return find_skill_path(skill_path).read_text(encoding="utf-8") diff --git a/src/tests/_internal/cli/commands/test_endpoint.py b/src/tests/_internal/cli/commands/test_endpoint.py new file mode 100644 index 0000000000..85a3b12494 --- /dev/null +++ b/src/tests/_internal/cli/commands/test_endpoint.py @@ -0,0 +1,21 @@ +from pytest import CaptureFixture + +from tests._internal.cli.common import run_dstack_cli + + +class TestEndpointCommand: + def test_create_help(self, capsys: CaptureFixture): + exit_code = run_dstack_cli(["endpoint", "create", "--help"]) + assert exit_code == 0 + output = capsys.readouterr().out + assert "--model" in output + assert "--gpu" in output + assert "--backend" in output + + def test_requires_harness_api_key(self, capsys: CaptureFixture, monkeypatch): + monkeypatch.delenv("DSTACK_HARNESS_API_KEY", raising=False) + exit_code = run_dstack_cli( + ["endpoint", "create", "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--dry-run"] + ) + assert exit_code == 1 + assert "DSTACK_HARNESS_API_KEY" in capsys.readouterr().out diff --git a/src/tests/_internal/harness/test_deployer.py b/src/tests/_internal/harness/test_deployer.py new file mode 100644 index 0000000000..4ba18dda24 --- /dev/null +++ b/src/tests/_internal/harness/test_deployer.py @@ -0,0 +1,302 @@ +import argparse +import base64 +from datetime import datetime, timezone +from unittest.mock import MagicMock + +from dstack._internal.core.models.runs import JobStatus, RunStatus +from dstack._internal.harness import deployer +from dstack._internal.harness.deployer import ( + _handle_detach_on_interrupt, + _monitor_run, + _Outcome, + _run_self_healing_loop, + _submission_token, +) +from dstack._internal.harness.models import EndpointCreateParams + + +def _log_event(message: str, ts: datetime) -> MagicMock: + event = MagicMock() + event.message = base64.b64encode(message.encode()).decode() + event.timestamp = ts + return event + + +def _make_api(statuses, logs_per_poll): + """Build a fake Client whose run status and logs change per monitor iteration.""" + api = MagicMock() + api.project = "main" + + state = {"i": 0} + + def get_run(_name): + idx = min(state["i"], len(statuses) - 1) + run_status, job_status = statuses[idx] + run = MagicMock() + run.status = run_status + submission = MagicMock() + submission.status = job_status + submission.id = "11111111-1111-4111-8111-111111111111" + run._run.jobs = [MagicMock(job_submissions=[submission])] + return run + + api.runs.get.side_effect = get_run + + def poll(project_name, body): + idx = min(state["i"], len(logs_per_poll) - 1) + resp = MagicMock() + resp.logs = logs_per_poll[idx] + resp.next_token = None + state["i"] += 1 + return resp + + api.client.logs.poll.side_effect = poll + return api + + +class TestMonitorRun: + def test_detects_success_on_ready_marker(self, monkeypatch): + monkeypatch.setattr(deployer.time, "sleep", lambda _s: None) + ts = datetime(2026, 1, 1, tzinfo=timezone.utc) + statuses = [ + (RunStatus.RUNNING, JobStatus.RUNNING), + (RunStatus.RUNNING, JobStatus.RUNNING), + ] + logs = [ + [_log_event("Loading model...\n", ts)], + [_log_event("INFO: Application startup complete\n", ts)], + ] + api = _make_api(statuses, logs) + + outcome, _ = _monitor_run(api, "svc", timeout_secs=60) + assert outcome is _Outcome.SUCCESS + + def test_detects_failure_on_failed_status(self, monkeypatch): + monkeypatch.setattr(deployer.time, "sleep", lambda _s: None) + ts = datetime(2026, 1, 1, tzinfo=timezone.utc) + statuses = [(RunStatus.FAILED, JobStatus.FAILED)] + logs = [[_log_event("ValueError: KV cache\n", ts)]] + api = _make_api(statuses, logs) + + outcome, error_logs = _monitor_run(api, "svc", timeout_secs=60) + assert outcome is _Outcome.FAILED + assert "KV cache" in error_logs + + def test_detects_failure_on_fatal_log_pattern(self, monkeypatch): + monkeypatch.setattr(deployer.time, "sleep", lambda _s: None) + ts = datetime(2026, 1, 1, tzinfo=timezone.utc) + statuses = [(RunStatus.RUNNING, JobStatus.RUNNING)] + logs = [[_log_event("Engine core initialization failed\n", ts)]] + api = _make_api(statuses, logs) + + outcome, error_logs = _monitor_run(api, "svc", timeout_secs=60) + assert outcome is _Outcome.FAILED + assert "Engine core initialization failed" in error_logs + + +def _make_run(run_id: str, submission_id: str, deployment_num: int = 0) -> MagicMock: + run = MagicMock() + run._run.id = run_id + submission = MagicMock() + submission.id = submission_id + submission.deployment_num = deployment_num + run._run.jobs = [MagicMock(job_submissions=[submission])] + return run + + +class TestSubmissionToken: + def test_returns_none_when_run_missing(self): + api = MagicMock() + api.runs.get.return_value = None + assert _submission_token(api, "svc") is None + + def test_changes_when_new_submission(self): + api = MagicMock() + token1 = _make_token(api, "run-1", "sub-1", 0) + token2 = _make_token(api, "run-1", "sub-2", 1) + assert token1 != token2 + + +def _make_token(api, run_id, submission_id, deployment_num): + api.runs.get.return_value = _make_run(run_id, submission_id, deployment_num) + return _submission_token(api, "svc") + + +class TestSelfHealingLoopDeclineExit: + def test_exits_without_monitoring_when_user_declines(self, monkeypatch): + """Declining the plan must not be treated as a deployment failure.""" + api = MagicMock() + # No run exists before or after apply -> user declined the plan. + api.runs.get.return_value = None + + # apply_configuration is a no-op (simulates the declined prompt path). + monkeypatch.setattr( + deployer.ServiceConfigurator, "apply_configuration", lambda *a, **k: None + ) + monitor_called = MagicMock() + monkeypatch.setattr(deployer, "_monitor_run", monitor_called) + stop_called = MagicMock() + monkeypatch.setattr(deployer, "_stop_run", stop_called) + + params = EndpointCreateParams(model="meta-llama/Meta-Llama-3.1-8B-Instruct", name="svc") + command_args = argparse.Namespace(yes=False, verbose=False, force=False, detach=False) + configuration = MagicMock() + configuration.name = "svc" + + _run_self_healing_loop( + api=api, + configuration=configuration, + params=params, + config_path=MagicMock(), + command_args=command_args, + configurator_args=argparse.Namespace(), + skill_path=None, + max_attempts=3, + monitor_timeout_secs=60, + ) + + monitor_called.assert_not_called() + stop_called.assert_not_called() + + +class TestAttachedApply: + def test_uses_pre_attach_hook_and_skips_custom_monitor(self, monkeypatch): + api = MagicMock() + run = _make_run("run-1", "sub-1", 0) + api.runs.get.side_effect = [None, run] + apply_calls = [] + + def fake_apply(_self, conf, configuration_path, command_args, configurator_args): + apply_calls.append(command_args) + if command_args.pre_attach_hook is not None: + command_args.pre_attach_hook(conf.name) + + monkeypatch.setattr(deployer.ServiceConfigurator, "apply_configuration", fake_apply) + monitor_called = MagicMock() + monkeypatch.setattr(deployer, "_monitor_run", monitor_called) + + params = EndpointCreateParams(model="meta-llama/Meta-Llama-3.1-8B-Instruct", name="svc") + command_args = argparse.Namespace(yes=False, verbose=False, force=False, detach=False) + configuration = MagicMock() + configuration.name = "svc" + + _run_self_healing_loop( + api=api, + configuration=configuration, + params=params, + config_path=MagicMock(), + command_args=command_args, + configurator_args=argparse.Namespace(), + skill_path=None, + max_attempts=3, + monitor_timeout_secs=60, + ) + + assert apply_calls[0].detach is False + assert apply_calls[0].pre_attach_hook is deployer._print_monitoring_message + monitor_called.assert_not_called() + + def test_stops_failed_run_before_regenerating_in_attached_mode(self, monkeypatch): + """A failed attached apply must stop the run before redeploying. + + Otherwise dstack treats the next apply as an in-place rolling update of the + still-active service and the attached log stream breaks. + """ + api = MagicMock() + api.runs.get.return_value = _make_run("run-1", "sub-1", 0) + call_order = [] + + attempts = {"n": 0} + + def fake_apply(_self, conf, configuration_path, command_args, configurator_args): + attempts["n"] += 1 + call_order.append(f"apply-{attempts['n']}") + if attempts["n"] == 1: + raise SystemExit(1) + + monkeypatch.setattr(deployer.ServiceConfigurator, "apply_configuration", fake_apply) + monkeypatch.setattr(deployer, "_fetch_recent_logs", lambda *a, **k: "boom") + monkeypatch.setattr(deployer, "_stop_run", lambda *a, **k: call_order.append("stop")) + + new_config = MagicMock() + new_config.name = "svc" + monkeypatch.setattr( + deployer, + "_regenerate_configuration", + lambda **k: (call_order.append("regenerate"), (new_config, k["config_path"]))[1], + ) + + params = EndpointCreateParams(model="meta-llama/Meta-Llama-3.1-8B-Instruct", name="svc") + command_args = argparse.Namespace(yes=False, verbose=False, force=False, detach=False) + configuration = MagicMock() + configuration.name = "svc" + + _run_self_healing_loop( + api=api, + configuration=configuration, + params=params, + config_path=MagicMock(), + command_args=command_args, + configurator_args=argparse.Namespace(), + skill_path=None, + max_attempts=4, + monitor_timeout_secs=60, + ) + + assert call_order == ["apply-1", "stop", "regenerate", "apply-2"] + + +class TestDetachOnInterrupt: + def test_prompts_and_detaches_when_user_declines_stop(self, monkeypatch): + api = MagicMock() + run = MagicMock() + run.status.is_finished.return_value = False + api.runs.get.return_value = run + monkeypatch.setattr(deployer, "confirm_ask", lambda _prompt: False) + stop_called = MagicMock() + monkeypatch.setattr(api.client.runs, "stop", stop_called) + + _handle_detach_on_interrupt(api, "my-run", yes=False) + + stop_called.assert_not_called() + + def test_stops_run_when_user_confirms(self, monkeypatch): + api = MagicMock() + run = MagicMock() + run.status.is_finished.return_value = False + finished_run = MagicMock() + finished_run.status.is_finished.return_value = True + api.runs.get.side_effect = [run, finished_run] + monkeypatch.setattr(deployer, "confirm_ask", lambda _prompt: True) + monkeypatch.setattr(deployer.time, "sleep", lambda _s: None) + + _handle_detach_on_interrupt(api, "my-run", yes=False) + + api.client.runs.stop.assert_called_once_with(api.project, ["my-run"], False) + + def test_skips_prompt_when_run_already_finished(self, monkeypatch): + api = MagicMock() + run = MagicMock() + run.status.is_finished.return_value = True + api.runs.get.return_value = run + confirm_called = MagicMock() + monkeypatch.setattr(deployer, "confirm_ask", confirm_called) + + _handle_detach_on_interrupt(api, "my-run", yes=False) + + confirm_called.assert_not_called() + + def test_detaches_without_prompt_when_yes_flag_set(self, monkeypatch): + api = MagicMock() + run = MagicMock() + run.status.is_finished.return_value = False + api.runs.get.return_value = run + confirm_called = MagicMock() + monkeypatch.setattr(deployer, "confirm_ask", confirm_called) + stop_called = MagicMock() + monkeypatch.setattr(api.client.runs, "stop", stop_called) + + _handle_detach_on_interrupt(api, "my-run", yes=True) + + confirm_called.assert_not_called() + stop_called.assert_not_called() diff --git a/src/tests/_internal/harness/test_generator.py b/src/tests/_internal/harness/test_generator.py new file mode 100644 index 0000000000..5d51428aa8 --- /dev/null +++ b/src/tests/_internal/harness/test_generator.py @@ -0,0 +1,253 @@ +import argparse + +import pytest +import yaml + +from dstack._internal.cli.services.args import gpu_spec +from dstack._internal.core.errors import CLIError +from dstack._internal.harness.generator import ( + _extract_yaml, + generate_service_configuration, + parse_service_yaml, + regenerate_service_configuration, + save_service_configuration, +) +from dstack._internal.harness.models import EndpointCreateParams, default_endpoint_name + + +class _StubLLM: + def __init__(self, response: str): + self.response = response + self.last_system_prompt = None + self.last_user_prompt = None + + def chat(self, system_prompt: str, user_prompt: str) -> str: + self.last_system_prompt = system_prompt + self.last_user_prompt = user_prompt + return self.response + + +class TestDefaultEndpointName: + def test_derives_stable_name_from_model(self): + assert ( + default_endpoint_name("meta-llama/Meta-Llama-3.1-8B-Instruct") + == "meta-llama-3-1-8b-instruct" + ) + + def test_handles_model_without_org(self): + assert default_endpoint_name("gpt-4o") == "gpt-4o" + + +class TestEndpointCreateParams: + def test_from_namespace_includes_cli_options(self): + args = argparse.Namespace( + run_name=None, + gpu_spec=gpu_spec("24GB"), + cpu_spec=None, + memory_spec=None, + disk_spec=None, + backends=["runpod"], + regions=["us-east-1"], + instance_types=None, + fleets=None, + max_price=None, + max_duration=None, + spot_policy=None, + env_vars=[], + ) + params = EndpointCreateParams.from_namespace( + args, model="meta-llama/Meta-Llama-3.1-8B-Instruct" + ) + assert params.gpu is not None + assert params.backends == ["runpod"] + assert params.regions == ["us-east-1"] + assert params.name == "meta-llama-3-1-8b-instruct" + assert "gpu" in params.cli_options() + assert "backends" in params.cli_options() + + def test_apply_resources_args_overrides_llm_gpu(self): + from dstack._internal.cli.services.resources import apply_resources_args + + yaml_text = """ +type: service +name: llama +port: 8000 +model: meta-llama/Meta-Llama-3.1-8B-Instruct +commands: + - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct +resources: + gpu: 80GB +""" + configuration = parse_service_yaml(yaml_text) + args = argparse.Namespace( + cpu_spec=None, + gpu_spec=gpu_spec("24GB"), + memory_spec=None, + disk_spec=None, + ) + apply_resources_args(args, configuration) + assert configuration.resources.gpu.memory.min == 24.0 + assert configuration.resources.gpu.memory.max == 24.0 + + +class TestExtractYaml: + def test_extracts_fenced_yaml(self): + text = """Here is the config: +```yaml +type: service +name: llama +port: 8000 +model: meta-llama/Meta-Llama-3.1-8B-Instruct +commands: + - uv pip install vllm + - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct +resources: + gpu: 80GB +``` +""" + yaml_text = _extract_yaml(text) + assert yaml_text.startswith("type: service") + + def test_raises_when_yaml_missing(self): + with pytest.raises(CLIError): + _extract_yaml("no yaml here") + + +class TestParseServiceYaml: + def test_parses_valid_service_yaml(self): + yaml_text = """ +type: service +name: llama +port: 8000 +model: meta-llama/Meta-Llama-3.1-8B-Instruct +commands: + - uv pip install vllm + - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct +resources: + gpu: 80GB +""" + configuration = parse_service_yaml(yaml_text) + assert configuration.name == "llama" + assert configuration.model.name == "meta-llama/Meta-Llama-3.1-8B-Instruct" + + def test_strips_secret_values_from_env(self): + yaml_text = """ +type: service +name: llama +port: 8000 +model: meta-llama/Meta-Llama-3.1-8B-Instruct +env: + - HF_TOKEN=secret +commands: + - echo hi +resources: + gpu: 80GB +""" + configuration = parse_service_yaml(yaml_text) + assert "HF_TOKEN" in configuration.env + assert configuration.env["HF_TOKEN"].key == "HF_TOKEN" + + +class TestSaveServiceConfiguration: + def test_saves_yaml_with_python_version(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + yaml_text = """ +type: service +name: llama +python: "3.12" +port: 8000 +model: meta-llama/Meta-Llama-3.1-8B-Instruct +commands: + - uv pip install vllm + - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct +resources: + gpu: 80GB +""" + configuration = parse_service_yaml(yaml_text) + + config_path = save_service_configuration(configuration) + + assert config_path.exists() + saved = yaml.safe_load(config_path.read_text()) + assert saved["type"] == "service" + assert saved["python"] == "3.12" + # Re-parsing the saved file should succeed. + parse_service_yaml(config_path.read_text()) + + def test_never_persists_resolved_secret_values(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + yaml_text = """ +type: service +name: llama +port: 8000 +model: meta-llama/Meta-Llama-3.1-8B-Instruct +env: + - HF_TOKEN +commands: + - echo hi +resources: + gpu: 80GB +""" + configuration = parse_service_yaml(yaml_text) + # Simulate env resolution that happens via configurator.apply_args: + # the sentinel is replaced by a real secret value in memory. + configuration.env["HF_TOKEN"] = "hf_super_secret_value" + + config_path = save_service_configuration(configuration) + + content = config_path.read_text() + assert "hf_super_secret_value" not in content + saved = yaml.safe_load(content) + assert saved["env"] == ["HF_TOKEN"] + + +class TestRegenerateServiceConfiguration: + def test_uses_error_logs_and_returns_fixed_config(self, tmp_path): + skill = tmp_path / "SKILL.md" + skill.write_text("dummy skill") + fixed_yaml = """```yaml +type: service +name: llama +port: 8000 +model: meta-llama/Meta-Llama-3.1-8B-Instruct +commands: + - uv pip install vllm + - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --max-model-len 8192 +resources: + gpu: L4:1 +```""" + stub = _StubLLM(fixed_yaml) + params = EndpointCreateParams(model="meta-llama/Meta-Llama-3.1-8B-Instruct", name="llama") + + configuration = regenerate_service_configuration( + params=params, + previous_yaml="type: service\nname: llama\n", + error_logs="ValueError: 16.0 GiB KV cache is needed ... available 5.58 GiB", + skill_path=str(skill), + llm_client=stub, + ) + + assert "--max-model-len 8192" in configuration.commands[-1] + assert "KV cache" in stub.last_user_prompt + assert configuration.name == "llama" + + def test_generate_uses_stub_client(self, tmp_path): + skill = tmp_path / "SKILL.md" + skill.write_text("dummy skill") + generated = """```yaml +type: service +name: llama +port: 8000 +model: meta-llama/Meta-Llama-3.1-8B-Instruct +commands: + - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct +resources: + gpu: L4:1 +```""" + stub = _StubLLM(generated) + params = EndpointCreateParams(model="meta-llama/Meta-Llama-3.1-8B-Instruct", name="llama") + + configuration = generate_service_configuration( + params=params, skill_path=str(skill), llm_client=stub + ) + assert configuration.name == "llama" diff --git a/src/tests/_internal/harness/test_llm.py b/src/tests/_internal/harness/test_llm.py new file mode 100644 index 0000000000..924c5db0dc --- /dev/null +++ b/src/tests/_internal/harness/test_llm.py @@ -0,0 +1,93 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from dstack._internal.core.errors import CLIError +from dstack._internal.harness.llm import HarnessLLMClient + + +def _mock_response(status_code: int, json_body: dict) -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.json.return_value = json_body + response.text = str(json_body) + return response + + +class TestHarnessLLMClient: + def test_requires_api_key(self, monkeypatch): + monkeypatch.delenv("DSTACK_HARNESS_API_KEY", raising=False) + with pytest.raises(CLIError): + HarnessLLMClient() + + def test_anthropic_request_shape(self, monkeypatch): + monkeypatch.delenv("DSTACK_HARNESS_PROVIDER", raising=False) + monkeypatch.delenv("DSTACK_HARNESS_BASE_URL", raising=False) + monkeypatch.delenv("DSTACK_HARNESS_MODEL", raising=False) + client = HarnessLLMClient(api_key="test-key") + assert client.provider == "anthropic" + + with patch("dstack._internal.harness.llm.requests.post") as post: + post.return_value = _mock_response( + 200, + { + "content": [{"type": "text", "text": "yaml here"}], + "usage": {"input_tokens": 10, "output_tokens": 24}, + }, + ) + result = client.chat("system", "user") + + assert result == "yaml here" + called_url = post.call_args.args[0] + called_kwargs = post.call_args.kwargs + assert called_url.endswith("/messages") + assert called_kwargs["headers"]["x-api-key"] == "test-key" + assert called_kwargs["headers"]["anthropic-version"] == "2023-06-01" + assert called_kwargs["json"]["system"] == "system" + assert called_kwargs["json"]["max_tokens"] > 0 + assert called_kwargs["json"]["messages"] == [{"role": "user", "content": "user"}] + + def test_openai_request_shape(self): + client = HarnessLLMClient(api_key="test-key", provider="openai") + + with patch("dstack._internal.harness.llm.requests.post") as post: + post.return_value = _mock_response( + 200, + { + "choices": [{"message": {"content": "yaml here"}}], + "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20}, + }, + ) + result = client.chat("system", "user") + + assert result == "yaml here" + called_url = post.call_args.args[0] + called_kwargs = post.call_args.kwargs + assert called_url.endswith("/chat/completions") + assert called_kwargs["headers"]["Authorization"] == "Bearer test-key" + + def test_raises_on_error_status(self): + client = HarnessLLMClient(api_key="test-key", provider="anthropic") + with patch("dstack._internal.harness.llm.requests.post") as post: + post.return_value = _mock_response(401, {"error": "unauthorized"}) + with pytest.raises(CLIError): + client.chat("system", "user") + + def test_prints_token_usage_for_anthropic(self, capsys): + client = HarnessLLMClient(api_key="test-key", provider="anthropic") + with patch("dstack._internal.harness.llm.requests.post") as post: + post.return_value = _mock_response( + 200, + { + "content": [{"type": "text", "text": "yaml here"}], + "usage": {"input_tokens": 10, "output_tokens": 24}, + }, + ) + client.chat("system", "user") + + output = capsys.readouterr().out + assert "LLM Provider: Anthropic" in output + assert "LLM Model:" in output + assert "input=10" in output + assert "output=24" in output + assert "total=34" in output From a25d1636276a49ca1f0063c8dcedebf0df3586b6 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Thu, 25 Jun 2026 19:57:53 +0545 Subject: [PATCH 2/3] Agent Harness MVP --- .../_internal/cli/commands/test_endpoint.py | 21 -- src/tests/_internal/harness/test_deployer.py | 302 ------------------ src/tests/_internal/harness/test_generator.py | 253 --------------- src/tests/_internal/harness/test_llm.py | 93 ------ 4 files changed, 669 deletions(-) delete mode 100644 src/tests/_internal/cli/commands/test_endpoint.py delete mode 100644 src/tests/_internal/harness/test_deployer.py delete mode 100644 src/tests/_internal/harness/test_generator.py delete mode 100644 src/tests/_internal/harness/test_llm.py diff --git a/src/tests/_internal/cli/commands/test_endpoint.py b/src/tests/_internal/cli/commands/test_endpoint.py deleted file mode 100644 index 85a3b12494..0000000000 --- a/src/tests/_internal/cli/commands/test_endpoint.py +++ /dev/null @@ -1,21 +0,0 @@ -from pytest import CaptureFixture - -from tests._internal.cli.common import run_dstack_cli - - -class TestEndpointCommand: - def test_create_help(self, capsys: CaptureFixture): - exit_code = run_dstack_cli(["endpoint", "create", "--help"]) - assert exit_code == 0 - output = capsys.readouterr().out - assert "--model" in output - assert "--gpu" in output - assert "--backend" in output - - def test_requires_harness_api_key(self, capsys: CaptureFixture, monkeypatch): - monkeypatch.delenv("DSTACK_HARNESS_API_KEY", raising=False) - exit_code = run_dstack_cli( - ["endpoint", "create", "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--dry-run"] - ) - assert exit_code == 1 - assert "DSTACK_HARNESS_API_KEY" in capsys.readouterr().out diff --git a/src/tests/_internal/harness/test_deployer.py b/src/tests/_internal/harness/test_deployer.py deleted file mode 100644 index 4ba18dda24..0000000000 --- a/src/tests/_internal/harness/test_deployer.py +++ /dev/null @@ -1,302 +0,0 @@ -import argparse -import base64 -from datetime import datetime, timezone -from unittest.mock import MagicMock - -from dstack._internal.core.models.runs import JobStatus, RunStatus -from dstack._internal.harness import deployer -from dstack._internal.harness.deployer import ( - _handle_detach_on_interrupt, - _monitor_run, - _Outcome, - _run_self_healing_loop, - _submission_token, -) -from dstack._internal.harness.models import EndpointCreateParams - - -def _log_event(message: str, ts: datetime) -> MagicMock: - event = MagicMock() - event.message = base64.b64encode(message.encode()).decode() - event.timestamp = ts - return event - - -def _make_api(statuses, logs_per_poll): - """Build a fake Client whose run status and logs change per monitor iteration.""" - api = MagicMock() - api.project = "main" - - state = {"i": 0} - - def get_run(_name): - idx = min(state["i"], len(statuses) - 1) - run_status, job_status = statuses[idx] - run = MagicMock() - run.status = run_status - submission = MagicMock() - submission.status = job_status - submission.id = "11111111-1111-4111-8111-111111111111" - run._run.jobs = [MagicMock(job_submissions=[submission])] - return run - - api.runs.get.side_effect = get_run - - def poll(project_name, body): - idx = min(state["i"], len(logs_per_poll) - 1) - resp = MagicMock() - resp.logs = logs_per_poll[idx] - resp.next_token = None - state["i"] += 1 - return resp - - api.client.logs.poll.side_effect = poll - return api - - -class TestMonitorRun: - def test_detects_success_on_ready_marker(self, monkeypatch): - monkeypatch.setattr(deployer.time, "sleep", lambda _s: None) - ts = datetime(2026, 1, 1, tzinfo=timezone.utc) - statuses = [ - (RunStatus.RUNNING, JobStatus.RUNNING), - (RunStatus.RUNNING, JobStatus.RUNNING), - ] - logs = [ - [_log_event("Loading model...\n", ts)], - [_log_event("INFO: Application startup complete\n", ts)], - ] - api = _make_api(statuses, logs) - - outcome, _ = _monitor_run(api, "svc", timeout_secs=60) - assert outcome is _Outcome.SUCCESS - - def test_detects_failure_on_failed_status(self, monkeypatch): - monkeypatch.setattr(deployer.time, "sleep", lambda _s: None) - ts = datetime(2026, 1, 1, tzinfo=timezone.utc) - statuses = [(RunStatus.FAILED, JobStatus.FAILED)] - logs = [[_log_event("ValueError: KV cache\n", ts)]] - api = _make_api(statuses, logs) - - outcome, error_logs = _monitor_run(api, "svc", timeout_secs=60) - assert outcome is _Outcome.FAILED - assert "KV cache" in error_logs - - def test_detects_failure_on_fatal_log_pattern(self, monkeypatch): - monkeypatch.setattr(deployer.time, "sleep", lambda _s: None) - ts = datetime(2026, 1, 1, tzinfo=timezone.utc) - statuses = [(RunStatus.RUNNING, JobStatus.RUNNING)] - logs = [[_log_event("Engine core initialization failed\n", ts)]] - api = _make_api(statuses, logs) - - outcome, error_logs = _monitor_run(api, "svc", timeout_secs=60) - assert outcome is _Outcome.FAILED - assert "Engine core initialization failed" in error_logs - - -def _make_run(run_id: str, submission_id: str, deployment_num: int = 0) -> MagicMock: - run = MagicMock() - run._run.id = run_id - submission = MagicMock() - submission.id = submission_id - submission.deployment_num = deployment_num - run._run.jobs = [MagicMock(job_submissions=[submission])] - return run - - -class TestSubmissionToken: - def test_returns_none_when_run_missing(self): - api = MagicMock() - api.runs.get.return_value = None - assert _submission_token(api, "svc") is None - - def test_changes_when_new_submission(self): - api = MagicMock() - token1 = _make_token(api, "run-1", "sub-1", 0) - token2 = _make_token(api, "run-1", "sub-2", 1) - assert token1 != token2 - - -def _make_token(api, run_id, submission_id, deployment_num): - api.runs.get.return_value = _make_run(run_id, submission_id, deployment_num) - return _submission_token(api, "svc") - - -class TestSelfHealingLoopDeclineExit: - def test_exits_without_monitoring_when_user_declines(self, monkeypatch): - """Declining the plan must not be treated as a deployment failure.""" - api = MagicMock() - # No run exists before or after apply -> user declined the plan. - api.runs.get.return_value = None - - # apply_configuration is a no-op (simulates the declined prompt path). - monkeypatch.setattr( - deployer.ServiceConfigurator, "apply_configuration", lambda *a, **k: None - ) - monitor_called = MagicMock() - monkeypatch.setattr(deployer, "_monitor_run", monitor_called) - stop_called = MagicMock() - monkeypatch.setattr(deployer, "_stop_run", stop_called) - - params = EndpointCreateParams(model="meta-llama/Meta-Llama-3.1-8B-Instruct", name="svc") - command_args = argparse.Namespace(yes=False, verbose=False, force=False, detach=False) - configuration = MagicMock() - configuration.name = "svc" - - _run_self_healing_loop( - api=api, - configuration=configuration, - params=params, - config_path=MagicMock(), - command_args=command_args, - configurator_args=argparse.Namespace(), - skill_path=None, - max_attempts=3, - monitor_timeout_secs=60, - ) - - monitor_called.assert_not_called() - stop_called.assert_not_called() - - -class TestAttachedApply: - def test_uses_pre_attach_hook_and_skips_custom_monitor(self, monkeypatch): - api = MagicMock() - run = _make_run("run-1", "sub-1", 0) - api.runs.get.side_effect = [None, run] - apply_calls = [] - - def fake_apply(_self, conf, configuration_path, command_args, configurator_args): - apply_calls.append(command_args) - if command_args.pre_attach_hook is not None: - command_args.pre_attach_hook(conf.name) - - monkeypatch.setattr(deployer.ServiceConfigurator, "apply_configuration", fake_apply) - monitor_called = MagicMock() - monkeypatch.setattr(deployer, "_monitor_run", monitor_called) - - params = EndpointCreateParams(model="meta-llama/Meta-Llama-3.1-8B-Instruct", name="svc") - command_args = argparse.Namespace(yes=False, verbose=False, force=False, detach=False) - configuration = MagicMock() - configuration.name = "svc" - - _run_self_healing_loop( - api=api, - configuration=configuration, - params=params, - config_path=MagicMock(), - command_args=command_args, - configurator_args=argparse.Namespace(), - skill_path=None, - max_attempts=3, - monitor_timeout_secs=60, - ) - - assert apply_calls[0].detach is False - assert apply_calls[0].pre_attach_hook is deployer._print_monitoring_message - monitor_called.assert_not_called() - - def test_stops_failed_run_before_regenerating_in_attached_mode(self, monkeypatch): - """A failed attached apply must stop the run before redeploying. - - Otherwise dstack treats the next apply as an in-place rolling update of the - still-active service and the attached log stream breaks. - """ - api = MagicMock() - api.runs.get.return_value = _make_run("run-1", "sub-1", 0) - call_order = [] - - attempts = {"n": 0} - - def fake_apply(_self, conf, configuration_path, command_args, configurator_args): - attempts["n"] += 1 - call_order.append(f"apply-{attempts['n']}") - if attempts["n"] == 1: - raise SystemExit(1) - - monkeypatch.setattr(deployer.ServiceConfigurator, "apply_configuration", fake_apply) - monkeypatch.setattr(deployer, "_fetch_recent_logs", lambda *a, **k: "boom") - monkeypatch.setattr(deployer, "_stop_run", lambda *a, **k: call_order.append("stop")) - - new_config = MagicMock() - new_config.name = "svc" - monkeypatch.setattr( - deployer, - "_regenerate_configuration", - lambda **k: (call_order.append("regenerate"), (new_config, k["config_path"]))[1], - ) - - params = EndpointCreateParams(model="meta-llama/Meta-Llama-3.1-8B-Instruct", name="svc") - command_args = argparse.Namespace(yes=False, verbose=False, force=False, detach=False) - configuration = MagicMock() - configuration.name = "svc" - - _run_self_healing_loop( - api=api, - configuration=configuration, - params=params, - config_path=MagicMock(), - command_args=command_args, - configurator_args=argparse.Namespace(), - skill_path=None, - max_attempts=4, - monitor_timeout_secs=60, - ) - - assert call_order == ["apply-1", "stop", "regenerate", "apply-2"] - - -class TestDetachOnInterrupt: - def test_prompts_and_detaches_when_user_declines_stop(self, monkeypatch): - api = MagicMock() - run = MagicMock() - run.status.is_finished.return_value = False - api.runs.get.return_value = run - monkeypatch.setattr(deployer, "confirm_ask", lambda _prompt: False) - stop_called = MagicMock() - monkeypatch.setattr(api.client.runs, "stop", stop_called) - - _handle_detach_on_interrupt(api, "my-run", yes=False) - - stop_called.assert_not_called() - - def test_stops_run_when_user_confirms(self, monkeypatch): - api = MagicMock() - run = MagicMock() - run.status.is_finished.return_value = False - finished_run = MagicMock() - finished_run.status.is_finished.return_value = True - api.runs.get.side_effect = [run, finished_run] - monkeypatch.setattr(deployer, "confirm_ask", lambda _prompt: True) - monkeypatch.setattr(deployer.time, "sleep", lambda _s: None) - - _handle_detach_on_interrupt(api, "my-run", yes=False) - - api.client.runs.stop.assert_called_once_with(api.project, ["my-run"], False) - - def test_skips_prompt_when_run_already_finished(self, monkeypatch): - api = MagicMock() - run = MagicMock() - run.status.is_finished.return_value = True - api.runs.get.return_value = run - confirm_called = MagicMock() - monkeypatch.setattr(deployer, "confirm_ask", confirm_called) - - _handle_detach_on_interrupt(api, "my-run", yes=False) - - confirm_called.assert_not_called() - - def test_detaches_without_prompt_when_yes_flag_set(self, monkeypatch): - api = MagicMock() - run = MagicMock() - run.status.is_finished.return_value = False - api.runs.get.return_value = run - confirm_called = MagicMock() - monkeypatch.setattr(deployer, "confirm_ask", confirm_called) - stop_called = MagicMock() - monkeypatch.setattr(api.client.runs, "stop", stop_called) - - _handle_detach_on_interrupt(api, "my-run", yes=True) - - confirm_called.assert_not_called() - stop_called.assert_not_called() diff --git a/src/tests/_internal/harness/test_generator.py b/src/tests/_internal/harness/test_generator.py deleted file mode 100644 index 5d51428aa8..0000000000 --- a/src/tests/_internal/harness/test_generator.py +++ /dev/null @@ -1,253 +0,0 @@ -import argparse - -import pytest -import yaml - -from dstack._internal.cli.services.args import gpu_spec -from dstack._internal.core.errors import CLIError -from dstack._internal.harness.generator import ( - _extract_yaml, - generate_service_configuration, - parse_service_yaml, - regenerate_service_configuration, - save_service_configuration, -) -from dstack._internal.harness.models import EndpointCreateParams, default_endpoint_name - - -class _StubLLM: - def __init__(self, response: str): - self.response = response - self.last_system_prompt = None - self.last_user_prompt = None - - def chat(self, system_prompt: str, user_prompt: str) -> str: - self.last_system_prompt = system_prompt - self.last_user_prompt = user_prompt - return self.response - - -class TestDefaultEndpointName: - def test_derives_stable_name_from_model(self): - assert ( - default_endpoint_name("meta-llama/Meta-Llama-3.1-8B-Instruct") - == "meta-llama-3-1-8b-instruct" - ) - - def test_handles_model_without_org(self): - assert default_endpoint_name("gpt-4o") == "gpt-4o" - - -class TestEndpointCreateParams: - def test_from_namespace_includes_cli_options(self): - args = argparse.Namespace( - run_name=None, - gpu_spec=gpu_spec("24GB"), - cpu_spec=None, - memory_spec=None, - disk_spec=None, - backends=["runpod"], - regions=["us-east-1"], - instance_types=None, - fleets=None, - max_price=None, - max_duration=None, - spot_policy=None, - env_vars=[], - ) - params = EndpointCreateParams.from_namespace( - args, model="meta-llama/Meta-Llama-3.1-8B-Instruct" - ) - assert params.gpu is not None - assert params.backends == ["runpod"] - assert params.regions == ["us-east-1"] - assert params.name == "meta-llama-3-1-8b-instruct" - assert "gpu" in params.cli_options() - assert "backends" in params.cli_options() - - def test_apply_resources_args_overrides_llm_gpu(self): - from dstack._internal.cli.services.resources import apply_resources_args - - yaml_text = """ -type: service -name: llama -port: 8000 -model: meta-llama/Meta-Llama-3.1-8B-Instruct -commands: - - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct -resources: - gpu: 80GB -""" - configuration = parse_service_yaml(yaml_text) - args = argparse.Namespace( - cpu_spec=None, - gpu_spec=gpu_spec("24GB"), - memory_spec=None, - disk_spec=None, - ) - apply_resources_args(args, configuration) - assert configuration.resources.gpu.memory.min == 24.0 - assert configuration.resources.gpu.memory.max == 24.0 - - -class TestExtractYaml: - def test_extracts_fenced_yaml(self): - text = """Here is the config: -```yaml -type: service -name: llama -port: 8000 -model: meta-llama/Meta-Llama-3.1-8B-Instruct -commands: - - uv pip install vllm - - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct -resources: - gpu: 80GB -``` -""" - yaml_text = _extract_yaml(text) - assert yaml_text.startswith("type: service") - - def test_raises_when_yaml_missing(self): - with pytest.raises(CLIError): - _extract_yaml("no yaml here") - - -class TestParseServiceYaml: - def test_parses_valid_service_yaml(self): - yaml_text = """ -type: service -name: llama -port: 8000 -model: meta-llama/Meta-Llama-3.1-8B-Instruct -commands: - - uv pip install vllm - - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct -resources: - gpu: 80GB -""" - configuration = parse_service_yaml(yaml_text) - assert configuration.name == "llama" - assert configuration.model.name == "meta-llama/Meta-Llama-3.1-8B-Instruct" - - def test_strips_secret_values_from_env(self): - yaml_text = """ -type: service -name: llama -port: 8000 -model: meta-llama/Meta-Llama-3.1-8B-Instruct -env: - - HF_TOKEN=secret -commands: - - echo hi -resources: - gpu: 80GB -""" - configuration = parse_service_yaml(yaml_text) - assert "HF_TOKEN" in configuration.env - assert configuration.env["HF_TOKEN"].key == "HF_TOKEN" - - -class TestSaveServiceConfiguration: - def test_saves_yaml_with_python_version(self, tmp_path, monkeypatch): - monkeypatch.chdir(tmp_path) - yaml_text = """ -type: service -name: llama -python: "3.12" -port: 8000 -model: meta-llama/Meta-Llama-3.1-8B-Instruct -commands: - - uv pip install vllm - - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct -resources: - gpu: 80GB -""" - configuration = parse_service_yaml(yaml_text) - - config_path = save_service_configuration(configuration) - - assert config_path.exists() - saved = yaml.safe_load(config_path.read_text()) - assert saved["type"] == "service" - assert saved["python"] == "3.12" - # Re-parsing the saved file should succeed. - parse_service_yaml(config_path.read_text()) - - def test_never_persists_resolved_secret_values(self, tmp_path, monkeypatch): - monkeypatch.chdir(tmp_path) - yaml_text = """ -type: service -name: llama -port: 8000 -model: meta-llama/Meta-Llama-3.1-8B-Instruct -env: - - HF_TOKEN -commands: - - echo hi -resources: - gpu: 80GB -""" - configuration = parse_service_yaml(yaml_text) - # Simulate env resolution that happens via configurator.apply_args: - # the sentinel is replaced by a real secret value in memory. - configuration.env["HF_TOKEN"] = "hf_super_secret_value" - - config_path = save_service_configuration(configuration) - - content = config_path.read_text() - assert "hf_super_secret_value" not in content - saved = yaml.safe_load(content) - assert saved["env"] == ["HF_TOKEN"] - - -class TestRegenerateServiceConfiguration: - def test_uses_error_logs_and_returns_fixed_config(self, tmp_path): - skill = tmp_path / "SKILL.md" - skill.write_text("dummy skill") - fixed_yaml = """```yaml -type: service -name: llama -port: 8000 -model: meta-llama/Meta-Llama-3.1-8B-Instruct -commands: - - uv pip install vllm - - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --max-model-len 8192 -resources: - gpu: L4:1 -```""" - stub = _StubLLM(fixed_yaml) - params = EndpointCreateParams(model="meta-llama/Meta-Llama-3.1-8B-Instruct", name="llama") - - configuration = regenerate_service_configuration( - params=params, - previous_yaml="type: service\nname: llama\n", - error_logs="ValueError: 16.0 GiB KV cache is needed ... available 5.58 GiB", - skill_path=str(skill), - llm_client=stub, - ) - - assert "--max-model-len 8192" in configuration.commands[-1] - assert "KV cache" in stub.last_user_prompt - assert configuration.name == "llama" - - def test_generate_uses_stub_client(self, tmp_path): - skill = tmp_path / "SKILL.md" - skill.write_text("dummy skill") - generated = """```yaml -type: service -name: llama -port: 8000 -model: meta-llama/Meta-Llama-3.1-8B-Instruct -commands: - - uv run vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct -resources: - gpu: L4:1 -```""" - stub = _StubLLM(generated) - params = EndpointCreateParams(model="meta-llama/Meta-Llama-3.1-8B-Instruct", name="llama") - - configuration = generate_service_configuration( - params=params, skill_path=str(skill), llm_client=stub - ) - assert configuration.name == "llama" diff --git a/src/tests/_internal/harness/test_llm.py b/src/tests/_internal/harness/test_llm.py deleted file mode 100644 index 924c5db0dc..0000000000 --- a/src/tests/_internal/harness/test_llm.py +++ /dev/null @@ -1,93 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from dstack._internal.core.errors import CLIError -from dstack._internal.harness.llm import HarnessLLMClient - - -def _mock_response(status_code: int, json_body: dict) -> MagicMock: - response = MagicMock() - response.status_code = status_code - response.json.return_value = json_body - response.text = str(json_body) - return response - - -class TestHarnessLLMClient: - def test_requires_api_key(self, monkeypatch): - monkeypatch.delenv("DSTACK_HARNESS_API_KEY", raising=False) - with pytest.raises(CLIError): - HarnessLLMClient() - - def test_anthropic_request_shape(self, monkeypatch): - monkeypatch.delenv("DSTACK_HARNESS_PROVIDER", raising=False) - monkeypatch.delenv("DSTACK_HARNESS_BASE_URL", raising=False) - monkeypatch.delenv("DSTACK_HARNESS_MODEL", raising=False) - client = HarnessLLMClient(api_key="test-key") - assert client.provider == "anthropic" - - with patch("dstack._internal.harness.llm.requests.post") as post: - post.return_value = _mock_response( - 200, - { - "content": [{"type": "text", "text": "yaml here"}], - "usage": {"input_tokens": 10, "output_tokens": 24}, - }, - ) - result = client.chat("system", "user") - - assert result == "yaml here" - called_url = post.call_args.args[0] - called_kwargs = post.call_args.kwargs - assert called_url.endswith("/messages") - assert called_kwargs["headers"]["x-api-key"] == "test-key" - assert called_kwargs["headers"]["anthropic-version"] == "2023-06-01" - assert called_kwargs["json"]["system"] == "system" - assert called_kwargs["json"]["max_tokens"] > 0 - assert called_kwargs["json"]["messages"] == [{"role": "user", "content": "user"}] - - def test_openai_request_shape(self): - client = HarnessLLMClient(api_key="test-key", provider="openai") - - with patch("dstack._internal.harness.llm.requests.post") as post: - post.return_value = _mock_response( - 200, - { - "choices": [{"message": {"content": "yaml here"}}], - "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20}, - }, - ) - result = client.chat("system", "user") - - assert result == "yaml here" - called_url = post.call_args.args[0] - called_kwargs = post.call_args.kwargs - assert called_url.endswith("/chat/completions") - assert called_kwargs["headers"]["Authorization"] == "Bearer test-key" - - def test_raises_on_error_status(self): - client = HarnessLLMClient(api_key="test-key", provider="anthropic") - with patch("dstack._internal.harness.llm.requests.post") as post: - post.return_value = _mock_response(401, {"error": "unauthorized"}) - with pytest.raises(CLIError): - client.chat("system", "user") - - def test_prints_token_usage_for_anthropic(self, capsys): - client = HarnessLLMClient(api_key="test-key", provider="anthropic") - with patch("dstack._internal.harness.llm.requests.post") as post: - post.return_value = _mock_response( - 200, - { - "content": [{"type": "text", "text": "yaml here"}], - "usage": {"input_tokens": 10, "output_tokens": 24}, - }, - ) - client.chat("system", "user") - - output = capsys.readouterr().out - assert "LLM Provider: Anthropic" in output - assert "LLM Model:" in output - assert "input=10" in output - assert "output=24" in output - assert "total=34" in output From ddeb66bd932ff9dbb07bc87989540fc75f56a0a4 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Thu, 25 Jun 2026 20:09:29 +0545 Subject: [PATCH 3/3] Agent Harness MVP --- mkdocs/docs/guides/endpoint-harness.md | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mkdocs/docs/guides/endpoint-harness.md b/mkdocs/docs/guides/endpoint-harness.md index a7e3b6df2e..5f3478b5e1 100644 --- a/mkdocs/docs/guides/endpoint-harness.md +++ b/mkdocs/docs/guides/endpoint-harness.md @@ -19,12 +19,6 @@ You describe what to deploy (model, GPU, backends, and other profile options). T The harness does **not** pick cloud offers or provision instances. dstack's scheduler does that after submission, the same way it does for a hand-written service config. -??? info "Prerequisites" - - [dstack server and CLI](../installation.md) configured for your project - - At least one [fleet](../concepts/fleets.md) - - `DSTACK_HARNESS_API_KEY` set (see [LLM configuration](#llm-configuration)) - - [`skills/dstack/SKILL.md`](https://github.com/dstackai/dstack/blob/master/skills/dstack/SKILL.md) - in the project, or pass `--skill-path` ## Quick start