diff --git a/mkdocs.yml b/mkdocs.yml index 2fc74935a..894318ce2 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -305,6 +305,7 @@ nav: - Exports: docs/concepts/exports.md - Guides: - CLI & API: docs/guides/cli-api.md + - Endpoint harness: docs/guides/endpoint-harness.md - Server deployment: docs/guides/server-deployment.md - Troubleshooting: docs/guides/troubleshooting.md - More: diff --git a/mkdocs/docs/guides/endpoint-harness.md b/mkdocs/docs/guides/endpoint-harness.md new file mode 100644 index 000000000..5f3478b5e --- /dev/null +++ b/mkdocs/docs/guides/endpoint-harness.md @@ -0,0 +1,177 @@ +--- +title: Endpoint harness +description: Deploy inference endpoints with dstack endpoint create and the agent harness +--- + +# Endpoint harness + +The endpoint harness powers `dstack endpoint create`. +It uses an LLM to generate a [`type: service`](../concepts/services.md) configuration, +then deploys it through the same code path as [`dstack apply`](../reference/cli/dstack/apply.md). + +You describe what to deploy (model, GPU, backends, and other profile options). The harness: + +1. Asks an LLM to produce a service YAML (including container `commands`) +2. Validates and saves the configuration +3. Submits the run via dstack +4. Monitors logs and, on failure, may ask the LLM to fix the config and redeploy + +The harness does **not** pick cloud offers or provision instances. dstack's scheduler +does that after submission, the same way it does for a hand-written service config. + + +## Quick start + +
+ +```shell +$ export DSTACK_HARNESS_API_KEY=sk-ant-... +$ export DSTACK_HARNESS_MODEL=claude-sonnet-4-8 +$ dstack endpoint create \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --gpu 24GB \ + --max-attempts 3 \ + -y +``` + +
+ +`DSTACK_HARNESS_MODEL` is optional. If unset, the harness defaults to `claude-sonnet-4-6` +for Anthropic. + +!!! note "`--max-attempts`" + Controls how many times the harness tries to deploy the endpoint. If the container + fails to start, it stops the run, asks the LLM to fix the configuration from the + error logs, and redeploys. Default is `3`. Set `--max-attempts 1` for a single + attempt with no retries. + +The command accepts the same resource and profile flags as [`dstack apply`](../reference/cli/dstack/apply.md) +for services (`--gpu`, `--cpu`, `--memory`, `--disk`, `--backend`, `--region`, `--fleet`, +`--max-price`, `--spot-policy`, and others). Run `dstack endpoint create --help` for the full list. + +## How it works + +```mermaid +flowchart TD + A[dstack endpoint create] --> B[Build EndpointCreateParams from CLI] + B --> C["LLM: generate service YAML"] + C --> D[Validate with parse_apply_configuration] + D --> E[Apply CLI overrides via ServiceConfigurator] + E --> F["Save to .dstack-harness-configs/"] + F --> I[ServiceConfigurator.apply_configuration] + I --> M[Monitor container logs] + M --> N{Ready?} + N -->|yes| O[Print service URLs] + N -->|failed| P[Stop run] + P --> Q["LLM: fix YAML from error logs"] + Q --> R{attempts left?} + R -->|yes| I + R -->|no| S[Give up] +``` + +Orchestration is **programmatic** (Python via `ServiceConfigurator`), not LLM-generated +`dstack` shell commands. The LLM only authors the service configuration and container +`commands` that run on the GPU instance. + + +## Relationship to `skills/dstack/SKILL.md` + +On every LLM call, the harness loads `skills/dstack/SKILL.md` and appends it to the system +prompt. + +## Prompts sent to the LLM + +### Call 1: Generate configuration + +Fixed prefix: + +``` +You generate dstack service configuration files for model inference endpoints. 
+ +Rules: +- Output a single valid YAML document for `type: service` +- Do not wrap the YAML in markdown unless you also include the YAML body in a fenced block +- Use only documented dstack service fields +- Put secret values only as env var names in `env`, never inline values +- Include `model`, `port`, `commands`, and `resources.gpu` when possible +- Prefer `python: "3.12"` unless the user requests a custom image +- User-provided CLI options in the request are mandatory: use the exact GPU, backends, + regions, fleets, CPU, memory, disk, and other resource/profile values given +- Do not substitute different resource sizes or backends than those specified by the user +- Do not invent unsupported CLI flags or YAML properties + +Reference skill: + + +``` + + +CLI options: + +``` +Generate a dstack service configuration for an inference endpoint. +The user passed these CLI options. You MUST use them exactly in the YAML. Do not substitute different GPU memory, backends, regions, fleets, or other resource/profile values. +{ + "model": "meta-llama/Llama-3.1-8B-Instruct", + "name": "meta-llama-3-1-8b-instruct", + "gpu": "24GB" +} + +Return only the YAML configuration. +``` + + +### Call 2: Fix configuration + +Fixed prefix: + +``` +You fix dstack service configurations that failed to start on the GPU instance. + +You are given the previous configuration and the container error logs. Return a +corrected single YAML document for `type: service`. 
+ +Rules: +- Change as little as possible to address the specific error in the logs +- Keep `model`, `name`, and `resources` unless the error requires changing them +- For vLLM KV-cache / out-of-memory errors, prefer adding serve flags such as + `--max-model-len` or `--gpu-memory-utilization` rather than changing the GPU +- Keep secret values as env var names only, never inline values +- Use only documented dstack fields and valid serving CLI flags +- Do not invent unsupported CLI flags or YAML properties + +Reference skill: + + +``` + + + +Error logs: + +```` +The following dstack service configuration failed to start: +```yaml +type: service +name: meta-llama-3-1-8b-instruct +model: meta-llama/Llama-3.1-8B-Instruct +python: "3.12" +port: 8000 +commands: + - | + pip install vllm + vllm serve meta-llama/Llama-3.1-8B-Instruct \ + --host 0.0.0.0 \ + --port 8000 +resources: + gpu: 24GB:1 +``` + +Container error logs (tail): +``` +... +torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate ... +``` + +Return only the corrected YAML configuration. 
+```` diff --git a/skills/dstack/SKILL.md b/skills/dstack/SKILL.md index 4df1843b7..b434ea6c8 100644 --- a/skills/dstack/SKILL.md +++ b/skills/dstack/SKILL.md @@ -237,10 +237,17 @@ port: 8000 model: meta-llama/Meta-Llama-3.1-8B-Instruct resources: - gpu: 80GB + gpu: 24GB disk: 200GB ``` +**GPU sizing rule** + +If `--gpu` is not provided: +- 7B/8B models -> `gpu: 24GB` +- 13B/14B models -> `gpu: 40GB` or `48GB` +- 30B+ models -> `gpu: 80GB` + **Service endpoints:** - Without gateway: `/proxy/services///` - With gateway: `https://./` diff --git a/src/dstack/_internal/cli/commands/endpoint.py b/src/dstack/_internal/cli/commands/endpoint.py new file mode 100644 index 000000000..758a3cdac --- /dev/null +++ b/src/dstack/_internal/cli/commands/endpoint.py @@ -0,0 +1,147 @@ +import argparse +import shlex +from typing import cast + +from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.configurators.run import ServiceConfigurator +from dstack._internal.cli.utils.common import console +from dstack._internal.core.errors import CLIError +from dstack._internal.core.models.configurations import TaskConfiguration +from dstack._internal.harness import ( + EndpointCreateParams, + deploy_service_configuration, + deploy_service_with_self_healing, +) +from dstack._internal.harness.generator import ( + generate_service_configuration, + save_service_configuration, +) + + +class EndpointCommand(APIBaseCommand): + NAME = "endpoint" + DESCRIPTION = "Manage inference endpoints" + ACCEPT_EXTRA_ARGS = True + + def _register(self): + super()._register() + self._parser.set_defaults(subfunc=self._print_help) + subparsers = self._parser.add_subparsers(dest="action") + + create_parser = subparsers.add_parser( + "create", + help="Create an inference endpoint", + formatter_class=self._parser.formatter_class, + ) + create_parser.add_argument( + "--model", + required=True, + metavar="NAME", + help="The model to deploy", + ) + create_parser.add_argument( + 
"--skill-path", + metavar="PATH", + help="Path to [code]skills/dstack/SKILL.md[/]. Defaults to project skill file", + ) + create_parser.add_argument( + "--dry-run", + action="store_true", + help="Generate and save the configuration without deploying", + ) + create_parser.add_argument( + "-y", + "--yes", + help="Do not ask for confirmation", + action="store_true", + ) + create_parser.add_argument( + "-d", + "--detach", + help="Exit immediately after submitting instead of streaming container logs", + action="store_true", + ) + create_parser.add_argument( + "-v", + "--verbose", + help="Show all plan properties including those with default values", + action="store_true", + ) + create_parser.add_argument( + "--force", + help="Force apply when no changes detected", + action="store_true", + ) + create_parser.add_argument( + "--max-attempts", + type=int, + default=3, + metavar="N", + help=( + "Max deploy attempts. On container failure, the harness stops the run," + " asks the model to fix the configuration from the error logs, and redeploys." 
+ " Set to 1 to disable self-healing" + ), + ) + ServiceConfigurator.register_args(create_parser) + create_parser.set_defaults(subfunc=self._create) + + def _command(self, args: argparse.Namespace): + super()._command(args) + args.subfunc(args) + + def _print_help(self, args: argparse.Namespace): + self._parser.print_help() + + def _create(self, args: argparse.Namespace): + configurator_parser = ServiceConfigurator.get_parser() + _, unknown_args = configurator_parser.parse_known_args(args.extra_args) + if unknown_args: + raise CLIError(f"Unrecognized arguments: {shlex.join(unknown_args)}") + + params = EndpointCreateParams.from_namespace(args, model=args.model) + + with console.status("Generating service configuration..."): + configuration = generate_service_configuration( + params=params, + skill_path=args.skill_path, + ) + + configurator = ServiceConfigurator(api_client=self.api) + configurator.apply_args(cast(TaskConfiguration, configuration), args) + configuration.model = args.model + + config_path = save_service_configuration(configuration) + console.print(f"Saved configuration to [code]{config_path}[/]") + + if args.dry_run: + console.print("Dry run complete. 
Skipping deployment.") + return + + apply_args = argparse.Namespace( + yes=args.yes, + detach=args.detach, + verbose=args.verbose, + force=args.force, + ) + + if args.detach: + deploy_service_configuration( + api=self.api, + configuration=configuration, + configuration_path=config_path, + command_args=apply_args, + configurator_args=args, + ) + return + + deploy_service_with_self_healing( + api=self.api, + configuration=configuration, + params=params, + configuration_path=config_path, + command_args=apply_args, + configurator_args=args, + skill_path=args.skill_path, + max_attempts=args.max_attempts, + ) diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index 32f15a95f..335f7693f 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -8,6 +8,7 @@ from dstack._internal.cli.commands.attach import AttachCommand from dstack._internal.cli.commands.completion import CompletionCommand from dstack._internal.cli.commands.delete import DeleteCommand +from dstack._internal.cli.commands.endpoint import EndpointCommand from dstack._internal.cli.commands.event import EventCommand from dstack._internal.cli.commands.export import ExportCommand from dstack._internal.cli.commands.fleet import FleetCommand @@ -67,6 +68,7 @@ def main(): ApplyCommand.register(subparsers) AttachCommand.register(subparsers) DeleteCommand.register(subparsers) + EndpointCommand.register(subparsers) EventCommand.register(subparsers) ExportCommand.register(subparsers) FleetCommand.register(subparsers) diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 16b0f0a87..adf2f1426 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -203,6 +203,10 @@ def apply_configuration( console.print(detach_message) return + pre_attach_hook = getattr(command_args, "pre_attach_hook", None) + if pre_attach_hook 
is not None: + pre_attach_hook(run.name) + abort_at_exit = False try: # We can attach to run multiple times if it goes from running to pending (retried). diff --git a/src/dstack/_internal/harness/__init__.py b/src/dstack/_internal/harness/__init__.py new file mode 100644 index 000000000..3f069be60 --- /dev/null +++ b/src/dstack/_internal/harness/__init__.py @@ -0,0 +1,20 @@ +from dstack._internal.harness.deployer import ( + deploy_service_configuration, + deploy_service_with_self_healing, +) +from dstack._internal.harness.generator import ( + generate_service_configuration, + regenerate_service_configuration, + save_service_configuration, +) +from dstack._internal.harness.models import EndpointCreateParams, default_endpoint_name + +__all__ = [ + "EndpointCreateParams", + "default_endpoint_name", + "deploy_service_configuration", + "deploy_service_with_self_healing", + "generate_service_configuration", + "regenerate_service_configuration", + "save_service_configuration", +] diff --git a/src/dstack/_internal/harness/deployer.py b/src/dstack/_internal/harness/deployer.py new file mode 100644 index 000000000..6d814fc7f --- /dev/null +++ b/src/dstack/_internal/harness/deployer.py @@ -0,0 +1,437 @@ +import argparse +import base64 +import sys +import time +from collections import deque +from datetime import timedelta +from enum import Enum +from pathlib import Path +from typing import Optional, cast + +from dstack._internal.cli.services.configurators.run import ( + ServiceConfigurator, + _print_service_urls, +) +from dstack._internal.cli.utils.common import confirm_ask, console +from dstack._internal.core.models.configurations import ServiceConfiguration, TaskConfiguration +from dstack._internal.core.models.runs import JobStatus, RunStatus +from dstack._internal.harness.generator import ( + regenerate_service_configuration, + save_service_configuration, +) +from dstack._internal.harness.models import EndpointCreateParams +from dstack._internal.server.schemas.logs import 
PollLogsRequest +from dstack.api._public import Client + +# Markers that indicate the model server finished starting and is serving. +READY_PATTERNS = ( + "Application startup complete", + "Uvicorn running on", + "The server is fired up and ready to roll", # SGLang + "Connected", # TGI +) +# Markers that indicate a fatal container error before the server became ready. +FATAL_PATTERNS = ( + "Traceback (most recent call last)", + "Engine core initialization failed", + "EngineCore failed to start", + "RuntimeError", + "CUDA out of memory", + "torch.cuda.OutOfMemoryError", +) + +DEFAULT_MONITOR_TIMEOUT_SECS = 1800 +MONITOR_POLL_INTERVAL_SECS = 3 +MAX_ERROR_LOG_CHARS = 6000 + + +class _Outcome(Enum): + SUCCESS = "success" + FAILED = "failed" + TIMEOUT = "timeout" + + +def deploy_service_configuration( + api: Client, + configuration: ServiceConfiguration, + configuration_path: Path, + command_args: argparse.Namespace, + configurator_args: argparse.Namespace, +) -> None: + """Submit a service detached and report status (no self-healing).""" + _submit_detached(api, configuration, configuration_path, command_args, configurator_args) + if configuration.name: + _wait_for_service_and_report(api, configuration.name) + + +def deploy_service_with_self_healing( + api: Client, + configuration: ServiceConfiguration, + params: EndpointCreateParams, + configuration_path: Path, + command_args: argparse.Namespace, + configurator_args: argparse.Namespace, + skill_path: Optional[str] = None, + max_attempts: int = 3, + monitor_timeout_secs: int = DEFAULT_MONITOR_TIMEOUT_SECS, +) -> None: + """Submit, monitor logs, and on failure stop, ask the LLM to fix, and redeploy.""" + config_path = configuration_path + try: + _run_self_healing_loop( + api=api, + configuration=configuration, + params=params, + config_path=config_path, + command_args=command_args, + configurator_args=configurator_args, + skill_path=skill_path, + max_attempts=max_attempts, + monitor_timeout_secs=monitor_timeout_secs, + ) 
+ except KeyboardInterrupt: + _handle_detach_on_interrupt(api, configuration.name, command_args.yes) + + +def _print_monitoring_message(run_name: str) -> None: + console.print(f"[code]\\[harness][/] Monitoring logs for [code]{run_name}[/]...") + + +def _run_self_healing_loop( + api: Client, + configuration: ServiceConfiguration, + params: EndpointCreateParams, + config_path: Path, + command_args: argparse.Namespace, + configurator_args: argparse.Namespace, + skill_path: Optional[str], + max_attempts: int, + monitor_timeout_secs: int, +) -> None: + detach = command_args.detach + for attempt in range(1, max_attempts + 1): + if attempt > 1: + console.print( + f"\n[code]\\[harness][/] Attempt {attempt}/{max_attempts}:" + " redeploying updated configuration..." + ) + + if not configuration.name: + return + + # Auto-confirm on retries; the user already approved the first plan. + submit_args = argparse.Namespace( + yes=command_args.yes or attempt > 1, + detach=detach, + verbose=command_args.verbose, + force=command_args.force, + pre_attach_hook=_print_monitoring_message if not detach else None, + ) + token_before = _submission_token(api, configuration.name) + configurator = ServiceConfigurator(api_client=api) + try: + configurator.apply_configuration( + conf=configuration, + configuration_path=str(config_path), + command_args=submit_args, + configurator_args=configurator_args, + ) + except SystemExit as e: + if detach or e.code in (0, None): + raise + if attempt == max_attempts: + console.print( + f"[code]\\[harness][/] Reached the maximum of {max_attempts} attempts." + " Giving up. See the logs above for the last error." + ) + raise + console.print( + f"[code]\\[harness][/] Detected a failure on attempt {attempt}." + f" Stopping run [code]{configuration.name}[/]..." + ) + error_logs = _fetch_recent_logs(api, configuration.name) + # Stop the failed run so the next attempt is a clean fresh deployment. 
+ # Otherwise dstack treats the redeploy as an in-place rolling update of + # the still-active service, which breaks the attached log stream. + _stop_run(api, configuration.name) + console.print( + "[code]\\[harness][/] Asking the model to fix the configuration based on the error..." + ) + configuration, config_path = _regenerate_configuration( + api=api, + configuration=configuration, + params=params, + config_path=config_path, + configurator_args=configurator_args, + error_logs=error_logs, + skill_path=skill_path, + ) + continue + + # If no new submission was created, the user declined the plan (or there + # was nothing to apply). That is a clean exit, not a deployment failure. + token_after = _submission_token(api, configuration.name) + if token_after is None or token_after == token_before: + return + + if detach: + console.print( + f"[code]\\[harness][/] Monitoring logs for [code]{configuration.name}[/]..." + ) + outcome, error_logs = _monitor_run(api, configuration.name, monitor_timeout_secs) + else: + # Attached apply streams logs via dstack until the user detaches or the run ends. + return + + if outcome is _Outcome.SUCCESS: + run = api.runs.get(configuration.name) + if run is not None: + run.refresh() + _print_service_urls(run) + console.print( + f"[code]\\[harness][/] Endpoint [code]{configuration.name}[/] is up and serving." + ) + return + + if outcome is _Outcome.TIMEOUT: + console.print( + f"[code]\\[harness][/] Timed out waiting for [code]{configuration.name}[/]" + " to become ready. Stopping the run." + ) + _stop_run(api, configuration.name) + return + + console.print( + f"[code]\\[harness][/] Detected a failure on attempt {attempt}." + f" Stopping run [code]{configuration.name}[/]..." + ) + _stop_run(api, configuration.name) + + if attempt == max_attempts: + console.print( + f"[code]\\[harness][/] Reached the maximum of {max_attempts} attempts." + " Giving up. See the logs above for the last error." 
+ ) + return + + console.print( + "[code]\\[harness][/] Asking the model to fix the configuration based on the error..." + ) + configuration, config_path = _regenerate_configuration( + api=api, + configuration=configuration, + params=params, + config_path=config_path, + configurator_args=configurator_args, + error_logs=error_logs, + skill_path=skill_path, + ) + + +def _regenerate_configuration( + api: Client, + configuration: ServiceConfiguration, + params: EndpointCreateParams, + config_path: Path, + configurator_args: argparse.Namespace, + error_logs: str, + skill_path: Optional[str], +) -> tuple[ServiceConfiguration, Path]: + previous_yaml = config_path.read_text(encoding="utf-8") + configuration = regenerate_service_configuration( + params=params, + previous_yaml=previous_yaml, + error_logs=error_logs, + skill_path=skill_path, + ) + ServiceConfigurator(api_client=api).apply_args( + cast(TaskConfiguration, configuration), configurator_args + ) + config_path = save_service_configuration(configuration) + console.print(f"[code]\\[harness][/] Saved updated configuration to [code]{config_path}[/]") + return configuration, config_path + + +def _fetch_recent_logs(api: Client, run_name: str) -> str: + run = api.runs.get(run_name) + if run is None: + return "" + try: + submission = run._run.jobs[0].job_submissions[-1] + except (AttributeError, IndexError): + return "" + events = _poll_new_logs(api, run_name, submission.id, None) + text = "".join(base64.b64decode(event.message).decode(errors="replace") for event in events) + return _truncate_logs(text) + + +def _truncate_logs(text: str) -> str: + if len(text) > MAX_ERROR_LOG_CHARS: + return text[-MAX_ERROR_LOG_CHARS:] + return text + + +def _handle_detach_on_interrupt(api: Client, run_name: Optional[str], yes: bool) -> None: + if not run_name: + return + run = api.runs.get(run_name) + if run is None or run.status.is_finished(): + return + try: + if yes or not confirm_ask(f"\nStop the run [code]{run_name}[/] before 
detaching?"): + console.print("Detached") + return + with console.status("Stopping..."): + api.client.runs.stop(api.project, [run_name], False) + while True: + run = api.runs.get(run_name) + if run is None or run.status.is_finished(): + break + time.sleep(2) + console.print("Stopped") + except KeyboardInterrupt: + with console.status("Aborting..."): + api.client.runs.stop(api.project, [run_name], True) + console.print("[error]Aborted[/]") + + +def _submit_detached( + api: Client, + configuration: ServiceConfiguration, + configuration_path: Path, + command_args: argparse.Namespace, + configurator_args: argparse.Namespace, +) -> None: + configurator = ServiceConfigurator(api_client=api) + submit_args = argparse.Namespace( + yes=command_args.yes, + detach=True, + verbose=command_args.verbose, + force=command_args.force, + ) + configurator.apply_configuration( + conf=configuration, + configuration_path=str(configuration_path), + command_args=submit_args, + configurator_args=configurator_args, + ) + + +def _submission_token(api: Client, run_name: str) -> Optional[tuple]: + """Identify the latest run submission so we can tell if a new run was submitted. + + Returns None when the run does not exist. When `apply` is declined or makes no + change, the token is unchanged; a fresh submission produces a different token. 
+ """ + run = api.runs.get(run_name) + if run is None: + return None + run_id = getattr(run._run, "id", None) + try: + submission = run._run.jobs[0].job_submissions[-1] + except (AttributeError, IndexError): + return (str(run_id), None, None) + return (str(run_id), str(submission.id), submission.deployment_num) + + +def _monitor_run(api: Client, run_name: str, timeout_secs: int) -> tuple[_Outcome, str]: + deadline = time.monotonic() + timeout_secs + log_tail: deque[str] = deque(maxlen=300) + last_timestamp = None + ready_seen = False + fatal_seen = False + + while time.monotonic() < deadline: + run = api.runs.get(run_name) + if run is None: + return _Outcome.FAILED, _format_tail(log_tail) + + submission = run._run.jobs[0].job_submissions[-1] + events = _poll_new_logs(api, run_name, submission.id, last_timestamp) + for event in events: + text = base64.b64decode(event.message).decode(errors="replace") + sys.stdout.write(text) + sys.stdout.flush() + log_tail.append(text) + if any(pattern in text for pattern in READY_PATTERNS): + ready_seen = True + if any(pattern in text for pattern in FATAL_PATTERNS): + fatal_seen = True + last_timestamp = event.timestamp + + if run.status in (RunStatus.FAILED, RunStatus.TERMINATED) or ( + submission.status == JobStatus.FAILED + ): + return _Outcome.FAILED, _format_tail(log_tail) + + if ready_seen and submission.status == JobStatus.RUNNING: + return _Outcome.SUCCESS, _format_tail(log_tail) + + if fatal_seen and not ready_seen: + return _Outcome.FAILED, _format_tail(log_tail) + + time.sleep(MONITOR_POLL_INTERVAL_SECS) + + return _Outcome.TIMEOUT, _format_tail(log_tail) + + +def _poll_new_logs(api: Client, run_name: str, submission_id, start_time): + if start_time is not None: + start_time = start_time + timedelta(microseconds=1) + events = [] + next_token = None + while True: + resp = api.client.logs.poll( + project_name=api.project, + body=PollLogsRequest( + run_name=run_name, + job_submission_id=submission_id, + 
start_time=start_time, + end_time=None, + descending=False, + limit=1000, + diagnose=False, + next_token=next_token, + ), + ) + events.extend(resp.logs) + next_token = resp.next_token + if next_token is None: + break + return events + + +def _format_tail(log_tail: deque) -> str: + return _truncate_logs("".join(log_tail)) + + +def _stop_run(api: Client, run_name: str) -> None: + with console.status(f"Stopping {run_name}..."): + api.client.runs.stop(api.project, [run_name], False) + while True: + run = api.runs.get(run_name) + if run is None or run.status.is_finished(): + break + time.sleep(2) + console.print(f"[code]\\[harness][/] Stopped [code]{run_name}[/]") + + +def _wait_for_service_and_report(api: Client, run_name: str, attempts: int = 30) -> None: + for _ in range(attempts): + run = api.runs.get(run_name) + if run is None: + return + if run.status in (RunStatus.RUNNING, RunStatus.DONE, RunStatus.FAILED): + console.print() + run.refresh() + _print_service_urls(run) + if run.status == RunStatus.FAILED: + console.print( + f"[error]Run [code]{run_name}[/] failed. Check [code]dstack logs {run_name}[/]." + ) + return + time.sleep(2) + + console.print( + f"Run [code]{run_name}[/] is still provisioning. Check status with [code]dstack ps -v[/]." 
+ ) diff --git a/src/dstack/_internal/harness/generator.py b/src/dstack/_internal/harness/generator.py new file mode 100644 index 000000000..aff5ef2b3 --- /dev/null +++ b/src/dstack/_internal/harness/generator.py @@ -0,0 +1,191 @@ +import json +import re +from pathlib import Path +from typing import Optional + +import yaml + +from dstack._internal.core.errors import CLIError, ConfigurationError +from dstack._internal.core.models.configurations import ( + ApplyConfigurationType, + ServiceConfiguration, + parse_apply_configuration, +) +from dstack._internal.core.models.envs import Env +from dstack._internal.harness.llm import HarnessLLMClient +from dstack._internal.harness.models import EndpointCreateParams +from dstack._internal.harness.skill import load_skill_content + +HARNESS_CONFIGS_DIR = Path(".dstack-harness-configs") + +SYSTEM_PROMPT_PREFIX = """\ +You generate dstack service configuration files for model inference endpoints. + +Rules: +- Output a single valid YAML document for `type: service` +- Do not wrap the YAML in markdown unless you also include the YAML body in a fenced block +- Use only documented dstack service fields +- Put secret values only as env var names in `env`, never inline values +- Include `model`, `port`, `commands`, and `resources.gpu` when possible +- Prefer `python: "3.12"` unless the user requests a custom image +- User-provided CLI options in the request are mandatory: use the exact GPU, backends, + regions, fleets, CPU, memory, disk, and other resource/profile values given +- Do not substitute different resource sizes or backends than those specified by the user +- Do not invent unsupported CLI flags or YAML properties + +Reference skill: + +""" + +FIX_SYSTEM_PROMPT_PREFIX = """\ +You fix dstack service configurations that failed to start on the GPU instance. + +You are given the previous configuration and the container error logs. Return a +corrected single YAML document for `type: service`. 
+ +Rules: +- Change as little as possible to address the specific error in the logs +- Keep `model`, `name`, and `resources` unless the error requires changing them +- For vLLM KV-cache / out-of-memory errors, prefer adding serve flags such as + `--max-model-len` or `--gpu-memory-utilization` rather than changing the GPU +- Keep secret values as env var names only, never inline values +- Use only documented dstack fields and valid serving CLI flags +- Do not invent unsupported CLI flags or YAML properties + +Reference skill: + +""" + + +def _extract_yaml(text: str) -> str: + fenced = re.search(r"```(?:ya?ml)?\s*\n(.*?)```", text, flags=re.DOTALL | re.IGNORECASE) + if fenced: + return fenced.group(1).strip() + + stripped = text.strip() + if stripped.startswith("type:") or stripped.startswith("name:"): + return stripped + + raise CLIError("Harness LLM response did not contain YAML configuration") + + +def _normalize_env_names(configuration: ServiceConfiguration) -> None: + configuration.env = Env.parse_obj(list(configuration.env)) + + +def _validate_service_configuration(configuration: ServiceConfiguration) -> ServiceConfiguration: + if configuration.type != ApplyConfigurationType.SERVICE.value: + raise CLIError("Generated configuration must have [code]type: service[/]") + if configuration.model is None: + raise CLIError("Generated configuration must include a [code]model[/] field") + _normalize_env_names(configuration) + return configuration + + +def parse_service_yaml(yaml_text: str) -> ServiceConfiguration: + try: + data = yaml.safe_load(yaml_text) + except yaml.YAMLError as e: + raise CLIError(f"Generated YAML is invalid: {e}") from e + if not isinstance(data, dict): + raise CLIError("Generated YAML must be a mapping") + + try: + configuration = parse_apply_configuration(data) + except ConfigurationError as e: + raise CLIError(f"Generated configuration is invalid: {e}") from e + + if not isinstance(configuration, ServiceConfiguration): + raise CLIError("Generated 
configuration must be a service configuration") + return _validate_service_configuration(configuration) + + +def build_user_prompt(params: EndpointCreateParams) -> str: + return ( + "Generate a dstack service configuration for an inference endpoint.\n" + "The user passed these CLI options. You MUST use them exactly in the YAML." + " Do not substitute different GPU memory, backends, regions, fleets," + " or other resource/profile values.\n" + f"{json.dumps(params.cli_options(), indent=2, default=str)}\n\n" + "Return only the YAML configuration." + ) + + +def build_fix_prompt(params: EndpointCreateParams, previous_yaml: str, error_logs: str) -> str: + return ( + "The following dstack service configuration failed to start:\n" + f"```yaml\n{previous_yaml}\n```\n\n" + "Container error logs (tail):\n" + f"```\n{error_logs}\n```\n\n" + "Return only the corrected YAML configuration." + ) + + +def _apply_param_overrides( + configuration: ServiceConfiguration, params: EndpointCreateParams +) -> None: + if params.name: + configuration.name = params.name + if params.model: + configuration.model = params.model + + +def generate_service_configuration( + params: EndpointCreateParams, + skill_path: Optional[str] = None, + llm_client: Optional[HarnessLLMClient] = None, +) -> ServiceConfiguration: + skill_content = load_skill_content(skill_path) + client = llm_client or HarnessLLMClient() + response = client.chat( + system_prompt=SYSTEM_PROMPT_PREFIX + skill_content, + user_prompt=build_user_prompt(params), + ) + configuration = parse_service_yaml(_extract_yaml(response)) + _apply_param_overrides(configuration, params) + return configuration + + +def regenerate_service_configuration( + params: EndpointCreateParams, + previous_yaml: str, + error_logs: str, + skill_path: Optional[str] = None, + llm_client: Optional[HarnessLLMClient] = None, +) -> ServiceConfiguration: + skill_content = load_skill_content(skill_path) + client = llm_client or HarnessLLMClient() + response = client.chat( 
+ system_prompt=FIX_SYSTEM_PROMPT_PREFIX + skill_content, + user_prompt=build_fix_prompt(params, previous_yaml, error_logs), + ) + configuration = parse_service_yaml(_extract_yaml(response)) + _apply_param_overrides(configuration, params) + return configuration + + +def get_endpoint_path(name: str) -> Path: + HARNESS_CONFIGS_DIR.mkdir(parents=True, exist_ok=True) + return HARNESS_CONFIGS_DIR / f"{name}.dstack.yml" + + +def save_service_configuration(configuration: ServiceConfiguration) -> Path: + if not configuration.name: + raise CLIError("Generated configuration must include a [code]name[/]") + + config_path = get_endpoint_path(configuration.name) + # Round-trip through JSON so enums and other rich types become plain + # primitives that yaml.safe_dump can represent. + config_dict = json.loads(configuration.json(exclude_none=True)) + # Never persist secret values to disk: env is always written as names only, + # even if values were resolved from the environment earlier in the flow. + env_names = list(configuration.env) + if env_names: + config_dict["env"] = env_names + else: + config_dict.pop("env", None) + config_path.write_text( + yaml.safe_dump(config_dict, sort_keys=False), + encoding="utf-8", + ) + return config_path diff --git a/src/dstack/_internal/harness/llm.py b/src/dstack/_internal/harness/llm.py new file mode 100644 index 000000000..9eafdd265 --- /dev/null +++ b/src/dstack/_internal/harness/llm.py @@ -0,0 +1,173 @@ +import os +from dataclasses import dataclass +from typing import Optional + +import requests + +from dstack._internal.cli.utils.common import console +from dstack._internal.core.errors import CLIError + +REQUEST_TIMEOUT_SECS = 120 +DEFAULT_MAX_TOKENS = 4096 + +PROVIDER_ANTHROPIC = "anthropic" +PROVIDER_OPENAI = "openai" + +DEFAULTS = { + PROVIDER_ANTHROPIC: { + "base_url": "https://api.anthropic.com/v1", + "model": "claude-sonnet-4-6", + }, + PROVIDER_OPENAI: { + "base_url": "https://api.openai.com/v1", + "model": "gpt-4o-mini", + }, +} 
+ANTHROPIC_VERSION = "2023-06-01" + + +@dataclass(frozen=True) +class LLMUsage: + input_tokens: int + output_tokens: int + + @property + def total_tokens(self) -> int: + return self.input_tokens + self.output_tokens + + +class HarnessLLMClient: + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model: Optional[str] = None, + provider: Optional[str] = None, + max_tokens: int = DEFAULT_MAX_TOKENS, + ): + self.api_key = api_key or os.getenv("DSTACK_HARNESS_API_KEY") + if not self.api_key: + raise CLIError( + "DSTACK_HARNESS_API_KEY is not set." + " Export it before running [code]dstack endpoint create[/]." + ) + self.provider = ( + provider or os.getenv("DSTACK_HARNESS_PROVIDER") or PROVIDER_ANTHROPIC + ).lower() + if self.provider not in DEFAULTS: + raise CLIError( + f"Unsupported harness provider [code]{self.provider}[/]." + f" Supported: {', '.join(DEFAULTS)}." + ) + defaults = DEFAULTS[self.provider] + self.base_url = ( + base_url or os.getenv("DSTACK_HARNESS_BASE_URL") or defaults["base_url"] + ).rstrip("/") + self.model = model or os.getenv("DSTACK_HARNESS_MODEL") or defaults["model"] + self.max_tokens = max_tokens + + def chat(self, system_prompt: str, user_prompt: str) -> str: + if self.provider == PROVIDER_ANTHROPIC: + return self._chat_anthropic(system_prompt, user_prompt) + return self._chat_openai(system_prompt, user_prompt) + + def _chat_anthropic(self, system_prompt: str, user_prompt: str) -> str: + url = f"{self.base_url}/messages" + payload = { + "model": self.model, + "max_tokens": self.max_tokens, + "system": system_prompt, + "messages": [{"role": "user", "content": user_prompt}], + } + headers = { + "x-api-key": self.api_key, + "anthropic-version": ANTHROPIC_VERSION, + "content-type": "application/json", + } + data = self._post(url, payload, headers) + self._print_usage(_parse_anthropic_usage(data)) + try: + return data["content"][0]["text"] + except (KeyError, IndexError, TypeError) as e: + raise 
CLIError(f"Unexpected harness LLM response: {data}") from e + + def _chat_openai(self, system_prompt: str, user_prompt: str) -> str: + url = f"{self.base_url}/chat/completions" + payload = { + "model": self.model, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + "temperature": 0, + } + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + data = self._post(url, payload, headers) + self._print_usage(_parse_openai_usage(data)) + try: + return data["choices"][0]["message"]["content"] + except (KeyError, IndexError, TypeError) as e: + raise CLIError(f"Unexpected harness LLM response: {data}") from e + + def _print_usage(self, usage: Optional[LLMUsage]) -> None: + provider_label = self.provider.capitalize() + console.print(f"[code]\\[harness][/] LLM Provider: {provider_label}") + console.print(f"[code]\\[harness][/] LLM Model: {self.model}") + if usage is None: + return + console.print( + "[code]\\[harness][/] LLM tokens:" + f" input={usage.input_tokens}," + f" output={usage.output_tokens}," + f" total={usage.total_tokens}" + ) + + def _post(self, url: str, payload: dict, headers: dict) -> dict: + try: + response = requests.post( + url, + json=payload, + headers=headers, + timeout=REQUEST_TIMEOUT_SECS, + ) + except requests.RequestException as e: + raise CLIError(f"Failed to call harness LLM: {e}") from e + + if response.status_code >= 400: + raise CLIError( + f"Harness LLM request failed with status {response.status_code}: {response.text}" + ) + + try: + return response.json() + except ValueError as e: + raise CLIError(f"Unexpected harness LLM response: {response.text}") from e + + +def _parse_anthropic_usage(data: dict) -> Optional[LLMUsage]: + usage = data.get("usage") + if not isinstance(usage, dict): + return None + try: + return LLMUsage( + input_tokens=int(usage["input_tokens"]), + output_tokens=int(usage["output_tokens"]), + ) + except (KeyError, 
TypeError, ValueError): + return None + + +def _parse_openai_usage(data: dict) -> Optional[LLMUsage]: + usage = data.get("usage") + if not isinstance(usage, dict): + return None + try: + return LLMUsage( + input_tokens=int(usage["prompt_tokens"]), + output_tokens=int(usage["completion_tokens"]), + ) + except (KeyError, TypeError, ValueError): + return None diff --git a/src/dstack/_internal/harness/models.py b/src/dstack/_internal/harness/models.py new file mode 100644 index 000000000..0bb70815d --- /dev/null +++ b/src/dstack/_internal/harness/models.py @@ -0,0 +1,59 @@ +import argparse +import re +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, List, Optional + +from dstack._internal.core.errors import CLIError + + +def default_endpoint_name(model: str) -> str: + """Derive a stable run name from a model id (e.g. org/model -> model, normalized).""" + base = model.rsplit("/", 1)[-1] + name = base.lower().replace(".", "-") + name = re.sub(r"[^a-z0-9-]+", "-", name) + name = re.sub(r"-+", "-", name).strip("-") + if not name: + raise CLIError(f"Cannot derive an endpoint name from model [code]{model}[/]") + return name + + +@dataclass +class EndpointCreateParams: + model: str + name: Optional[str] = None + gpu: Optional[Any] = None + cpu: Optional[Any] = None + memory: Optional[Any] = None + disk: Optional[Any] = None + backends: List[str] = field(default_factory=list) + regions: List[str] = field(default_factory=list) + instance_types: List[str] = field(default_factory=list) + fleets: List[str] = field(default_factory=list) + max_price: Optional[float] = None + max_duration: Optional[int] = None + spot_policy: Optional[str] = None + env_vars: List[str] = field(default_factory=list) + + @classmethod + def from_namespace(cls, args: argparse.Namespace, model: str) -> "EndpointCreateParams": + spot_policy = getattr(args, "spot_policy", None) + return cls( + model=model, + name=getattr(args, "run_name", None) or default_endpoint_name(model), + 
gpu=getattr(args, "gpu_spec", None), + cpu=getattr(args, "cpu_spec", None), + memory=getattr(args, "memory_spec", None), + disk=getattr(args, "disk_spec", None), + backends=getattr(args, "backends", None) or [], + regions=getattr(args, "regions", None) or [], + instance_types=getattr(args, "instance_types", None) or [], + fleets=getattr(args, "fleets", None) or [], + max_price=getattr(args, "max_price", None), + max_duration=getattr(args, "max_duration", None), + spot_policy=spot_policy.value if spot_policy is not None else None, + env_vars=[item.key for item in getattr(args, "env_vars", None) or []], + ) + + def cli_options(self) -> Dict[str, Any]: + options = asdict(self) + return {key: value for key, value in options.items() if value not in (None, [], {})} diff --git a/src/dstack/_internal/harness/skill.py b/src/dstack/_internal/harness/skill.py new file mode 100644 index 000000000..2e16533dd --- /dev/null +++ b/src/dstack/_internal/harness/skill.py @@ -0,0 +1,31 @@ +from pathlib import Path +from typing import Optional + +from dstack._internal.core.errors import CLIError + +DEFAULT_SKILL_RELATIVE_PATH = Path("skills/dstack/SKILL.md") + + +def find_skill_path(skill_path: Optional[str] = None) -> Path: + if skill_path is not None: + path = Path(skill_path) + if not path.is_file(): + raise CLIError(f"Skill file not found: {skill_path}") + return path + + candidates = [ + Path.cwd() / DEFAULT_SKILL_RELATIVE_PATH, + Path(__file__).resolve().parents[4] / DEFAULT_SKILL_RELATIVE_PATH, + ] + for path in candidates: + if path.is_file(): + return path + + raise CLIError( + "dstack skill not found. Expected " + f"[code]{DEFAULT_SKILL_RELATIVE_PATH}[/] in the current directory." + ) + + +def load_skill_content(skill_path: Optional[str] = None) -> str: + return find_skill_path(skill_path).read_text(encoding="utf-8")
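
Review note on `find_skill_path`: the lookup order (an explicit path wins, then the current directory, then the source tree root) is easier to verify when the search roots can be injected. The sketch below is a hypothetical stand-in, not the code in this diff: `resolve_skill` and its `roots` parameter are names invented for illustration, and it raises `FileNotFoundError` instead of `CLIError` so that it runs without dstack installed.

```python
from pathlib import Path
from typing import List, Optional

# Same relative path as DEFAULT_SKILL_RELATIVE_PATH in skill.py.
SKILL_RELATIVE_PATH = Path("skills/dstack/SKILL.md")


def resolve_skill(explicit: Optional[str], roots: List[Path]) -> Path:
    """Hypothetical variant of find_skill_path with injectable search roots."""
    if explicit is not None:
        # An explicit path is never combined with the fallback roots:
        # it either exists or the lookup fails immediately.
        path = Path(explicit)
        if not path.is_file():
            raise FileNotFoundError(f"Skill file not found: {explicit}")
        return path

    # Roots are tried in order; the first one containing the file wins.
    for root in roots:
        candidate = root / SKILL_RELATIVE_PATH
        if candidate.is_file():
            return candidate

    raise FileNotFoundError(
        f"{SKILL_RELATIVE_PATH} not found under: {[str(r) for r in roots]}"
    )
```

Under these assumptions, calling `resolve_skill(None, [Path.cwd(), repo_root])` would mirror the two candidates in the diff, and a unit test can pass temporary directories instead of relying on the process working directory.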