From 4797a726bcb290155ec44bbb933884d446f965dd Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 23 Jun 2026 01:45:16 +0200 Subject: [PATCH] Add gateway replica statuses and pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Introduce gateway replica statuses. - Provision and terminate gateway replicas independently from each other, in a separate pipeline. In this version, the pipelines have the following responsibilities. Gateway pipeline: - `SUBMITTED` — create replica DB records, transition to `PROVISIONING`; in a future version — create the load balancer (e.g., AWS ALB). - `PROVISIONING` — once all replicas reach `RUNNING`, transition to `RUNNING`; if any replica enters `TERMINATING` or `TERMINATED`, transition to `FAILED`. - `RUNNING`, `FAILED` — delete the gateway if deletion requested and all replicas are `TERMINATED`. Gateway replica pipeline: - `SUBMITTED` — call backend to create the cloud instance, transition to `PROVISIONING` on success, or `TERMINATED` on failure. - `PROVISIONING` — SSH-connect to the instance and configure the gateway, transition to `RUNNING` on success, or `TERMINATING` on failure. - `RUNNING` — nothing to do. - `TERMINATING` — call backend to destroy the cloud instance, transition to `TERMINATED`. - `TERMINATED` — nothing to do. - also `SUBMITTED`, `PROVISIONING`, or `RUNNING` — transition to `TERMINATING` or `TERMINATED` if the gateway is `FAILED` or gateway deletion is requested. --- mkdocs/docs/concepts/gateways.md | 8 +- src/dstack/_internal/cli/utils/gateway.py | 6 + src/dstack/_internal/core/models/gateways.py | 17 +- .../background/pipeline_tasks/__init__.py | 4 + .../pipeline_tasks/gateway_replicas.py | 662 ++++++++++++ .../background/pipeline_tasks/gateways.py | 250 +---- .../background/scheduled_tasks/gateways.py | 2 +- .../server/compatibility/gateways.py | 10 +- ...7d8fa7fcc5_add_gateway_replica_pipeline.py | 120 +++ ...dd_ix_gateway_computes_pipeline_fetch_q.py | 45 + src/dstack/_internal/server/models.py | 34 +- .../server/services/backends/__init__.py | 21 + .../server/services/gateways/__init__.py | 107 +- src/dstack/_internal/server/testing/common.py | 18 +- .../pipeline_tasks/test_gateway_replicas.py | 951 ++++++++++++++++++ .../pipeline_tasks/test_gateways.py | 514 ++++------ .../server/compatibility/test_gateways.py | 2 + .../_internal/server/routers/test_gateways.py | 8 + 18 files changed, 2212 insertions(+), 567 deletions(-) create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/gateway_replicas.py create mode 100644 src/dstack/_internal/server/migrations/versions/2026/06_19_0709_857d8fa7fcc5_add_gateway_replica_pipeline.py create mode 100644 src/dstack/_internal/server/migrations/versions/2026/06_24_1626_e9c5e7e26c78_add_ix_gateway_computes_pipeline_fetch_q.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_gateway_replicas.py diff --git a/mkdocs/docs/concepts/gateways.md b/mkdocs/docs/concepts/gateways.md index b71a23d7b6..a959ab104c 100644 --- a/mkdocs/docs/concepts/gateways.md +++ b/mkdocs/docs/concepts/gateways.md @@ -43,8 +43,8 @@ The example-gateway doesn't exist. Create it? [y/n]: y Provisioning... ---> 100% - BACKEND REGION NAME HOSTNAME DOMAIN DEFAULT STATUS - aws eu-west-1 example-gateway example.com ✓ submitted + NAME BACKEND HOSTNAME DOMAIN DEFAULT STATUS + example-gateway aws (eu-west-1) 34.244.128.46 example.com ✓ running ``` @@ -211,8 +211,8 @@ To balance requests between gateway replicas, add DNS records for each replica o $ dstack gateway list NAME BACKEND HOSTNAME DOMAIN DEFAULT STATUS example-gateway example.com ✓ running - replica=0 aws (eu-west-1) 34.244.128.46 - replica=1 aws (eu-west-1) 18.201.201.174 + replica=0 aws (eu-west-1) 34.244.128.46 running + replica=1 aws (eu-west-1) 18.201.201.174 running ``` diff --git a/src/dstack/_internal/cli/utils/gateway.py b/src/dstack/_internal/cli/utils/gateway.py index 0d873a9a5d..4c80aaaa8c 100644 --- a/src/dstack/_internal/cli/utils/gateway.py +++ b/src/dstack/_internal/cli/utils/gateway.py @@ -118,6 +118,10 @@ def get_gateways_table( gateway.replicas[0].backend, gateway.replicas[0].region ) gateway_row["HOSTNAME"] = gateway_row.get("HOSTNAME", gateway.replicas[0].hostname) + gateway_row["STATUS"] = gateway.replicas[0].status or gateway.status + gateway_row["ERROR"] = ". ".join( + m for m in [gateway.status_message, gateway.replicas[0].status_message] if m + ) add_row_from_dict(table, gateway_row) if len(gateway.replicas) > 1: @@ -126,7 +130,9 @@ def get_gateways_table( "NAME": f" replica={replica.replica_num}", "BACKEND": format_backend(replica.backend, replica.region), "HOSTNAME": replica.hostname, + "STATUS": replica.status, "CREATED": format_date(replica.created_at), + "ERROR": replica.status_message, } add_row_from_dict(table, replica_row, style="secondary") diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 74b3f4e835..1bd97eea45 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -21,6 +21,14 @@ class GatewayStatus(str, Enum): FAILED = "failed" +class GatewayReplicaStatus(str, Enum): + SUBMITTED = "submitted" + PROVISIONING = "provisioning" + RUNNING = "running" + TERMINATING = "terminating" + TERMINATED = "terminated" + + class LetsEncryptGatewayCertificate(CoreModel): type: Annotated[ Literal["lets-encrypt"], Field(description="Automatic certificates by Let's Encrypt") @@ -119,11 +127,14 @@ class GatewaySpec(CoreModel): class GatewayReplica(CoreModel): - hostname: str + hostname: Optional[str] = None replica_num: int - backend: BackendType - region: str + backend: Optional[BackendType] = None + region: Optional[str] = None created_at: datetime.datetime + status: Optional[GatewayReplicaStatus] = None + """`status` is only optional on the client side for compatibility with 0.20.25 servers""" + status_message: Optional[str] = None class Gateway(CoreModel): diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index a5e5164792..2e83e780c0 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -3,6 +3,9 @@ from dstack._internal.server.background.pipeline_tasks.base import Pipeline from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline +from dstack._internal.server.background.pipeline_tasks.gateway_replicas import ( + GatewayReplicaPipeline, +) from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline from dstack._internal.server.background.pipeline_tasks.instances import InstancePipeline from dstack._internal.server.background.pipeline_tasks.jobs_running import JobRunningPipeline @@ -33,6 +36,7 @@ def __init__(self) -> None: ComputeGroupPipeline(pipeline_hinter=self._hinter), FleetPipeline(pipeline_hinter=self._hinter), GatewayPipeline(pipeline_hinter=self._hinter), + GatewayReplicaPipeline(pipeline_hinter=self._hinter), JobSubmittedPipeline(pipeline_hinter=self._hinter), JobRunningPipeline(pipeline_hinter=self._hinter), JobTerminatingPipeline(pipeline_hinter=self._hinter), diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateway_replicas.py b/src/dstack/_internal/server/background/pipeline_tasks/gateway_replicas.py new file mode 100644 index 0000000000..0e1215c7a1 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateway_replicas.py @@ -0,0 +1,662 @@ +import asyncio +import uuid +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Optional, Sequence + +from sqlalchemy import and_, or_, select, update +from sqlalchemy.orm import InstrumentedAttribute, joinedload, load_only +from sqlalchemy.sql.base import ExecutableOption + +from dstack._internal.core.backends.base.compute import ComputeWithGatewaySupport +from dstack._internal.core.errors import BackendError, BackendNotAvailable +from dstack._internal.core.models.gateways import GatewayReplicaStatus, GatewayStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + ItemUpdateMap, + Pipeline, + PipelineItem, + Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + BackendModel, + GatewayComputeModel, + GatewayModel, + ProjectModel, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services import gateways as gateways_services +from dstack._internal.server.services.gateways import get_gateway_compute_configuration +from dstack._internal.server.services.gateways.pool import gateway_connections_pool +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.pipelines import PipelineHinterProtocol +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class GatewayReplicaPipelineItem(PipelineItem): + status: GatewayReplicaStatus + + +class GatewayReplicaPipeline(Pipeline[GatewayReplicaPipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=15), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + *, + pipeline_hinter: PipelineHinterProtocol, + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[GatewayReplicaPipelineItem]( + model_type=GatewayComputeModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = GatewayReplicaFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + GatewayReplicaWorker( + queue=self._queue, + heartbeater=self._heartbeater, + pipeline_hinter=pipeline_hinter, + ) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return GatewayComputeModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[GatewayReplicaPipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[GatewayReplicaPipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["GatewayReplicaWorker"]: + return self.__workers + + +class GatewayReplicaFetcher(Fetcher[GatewayReplicaPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[GatewayReplicaPipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[GatewayReplicaPipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_pipeline_task("GatewayReplicaFetcher.fetch") + async def fetch(self, limit: int) -> list[GatewayReplicaPipelineItem]: + replica_lock, _ = get_locker(get_db().dialect_name).get_lockset( + GatewayComputeModel.__tablename__ + ) + async with replica_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(GatewayComputeModel) + .outerjoin( + GatewayModel, + or_( + GatewayModel.id == GatewayComputeModel.gateway_id, + GatewayModel.gateway_compute_id == GatewayComputeModel.id, + ), + ) + .where( + or_( + GatewayComputeModel.status.in_( + [ + GatewayReplicaStatus.SUBMITTED, + GatewayReplicaStatus.PROVISIONING, + GatewayReplicaStatus.TERMINATING, + ] + ), + and_( + GatewayComputeModel.status == GatewayReplicaStatus.RUNNING, + or_( + GatewayModel.to_be_deleted == True, + GatewayModel.status == GatewayStatus.FAILED, + # Gateway was hard-deleted (unexpected, fetch to log an error) + GatewayModel.id.is_(None), + ), + ), + ), + or_( + GatewayComputeModel.last_processed_at + <= now - self._min_processing_interval, + GatewayComputeModel.last_processed_at + == GatewayComputeModel.created_at, + ), + or_( + GatewayComputeModel.lock_expires_at.is_(None), + GatewayComputeModel.lock_expires_at < now, + ), + or_( + GatewayComputeModel.lock_owner.is_(None), + GatewayComputeModel.lock_owner == GatewayReplicaPipeline.__name__, + ), + ) + .order_by(GatewayComputeModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True, of=GatewayComputeModel) + .options( + load_only( + GatewayComputeModel.id, + GatewayComputeModel.lock_token, + GatewayComputeModel.lock_expires_at, + GatewayComputeModel.status, + ) + ) + ) + replica_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for replica_model in replica_models: + prev_lock_expired = replica_model.lock_expires_at is not None + replica_model.lock_expires_at = lock_expires_at + replica_model.lock_token = lock_token + replica_model.lock_owner = GatewayReplicaPipeline.__name__ + items.append( + GatewayReplicaPipelineItem( + __tablename__=GatewayComputeModel.__tablename__, + id=replica_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + status=replica_model.status, + ) + ) + await session.commit() + return items + + +class GatewayReplicaWorker(Worker[GatewayReplicaPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[GatewayReplicaPipelineItem], + heartbeater: Heartbeater[GatewayReplicaPipelineItem], + pipeline_hinter: PipelineHinterProtocol, + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + pipeline_hinter=pipeline_hinter, + ) + + @sentry_utils.instrument_pipeline_task("GatewayReplicaWorker.process") + async def process(self, item: GatewayReplicaPipelineItem): + if item.status == GatewayReplicaStatus.SUBMITTED: + await _process_submitted_item(item) + elif item.status == GatewayReplicaStatus.PROVISIONING: + await _process_provisioning_item(item) + elif item.status == GatewayReplicaStatus.RUNNING: + await _process_running_item(item) + elif item.status == GatewayReplicaStatus.TERMINATING: + await _process_terminating_item(item) + + +class _GatewayReplicaUpdateMap(ItemUpdateMap, total=False): + status: GatewayReplicaStatus + status_message: Optional[str] + active: bool + deleted: bool + instance_id: Optional[str] + ip_address: Optional[str] + region: Optional[str] + hostname: Optional[str] + backend_data: Optional[str] + + +_REPLICA_FIELDS_MIN: list[InstrumentedAttribute[Any]] = [ + GatewayComputeModel.id, + GatewayComputeModel.lock_token, + GatewayComputeModel.status, + GatewayComputeModel.replica_num, +] + +_GATEWAY_FIELDS_MIN: list[InstrumentedAttribute[Any]] = [ + GatewayModel.id, + GatewayModel.name, + GatewayModel.to_be_deleted, + GatewayModel.status, +] + + +async def _load_gateway_replica( + item: GatewayReplicaPipelineItem, + replica_fields: list[InstrumentedAttribute[Any]], + gateway_fields: list[InstrumentedAttribute[Any]], + load_backends: bool = False, + load_gateway_backend_type: bool = False, +) -> Optional[GatewayComputeModel]: + def build_gateway_options( + gateway_attr: InstrumentedAttribute[GatewayModel | None], + ) -> list[ExecutableOption]: + gateway_load = joinedload(gateway_attr).load_only(*gateway_fields) + options: list[ExecutableOption] = [gateway_load] + if load_backends: + options.append( + gateway_load.joinedload(GatewayModel.project).selectinload(ProjectModel.backends) + ) + if load_gateway_backend_type: + options.append( + gateway_load.joinedload(GatewayModel.backend).load_only(BackendModel.type) + ) + return options + + async with get_session_ctx() as session: + stmt = ( + select(GatewayComputeModel) + .where( + GatewayComputeModel.id == item.id, + GatewayComputeModel.lock_token == item.lock_token, + ) + .options( + load_only(*replica_fields), + *build_gateway_options(GatewayComputeModel.gateway), + *build_gateway_options(GatewayComputeModel.legacy_gateway), + ) + ) + res = await session.execute(stmt) + replica_model = res.unique().scalar_one_or_none() + + if replica_model is None: + log_lock_token_mismatch(logger, item) + return None + return replica_model + + +def _get_loaded_gateway_model(replica_model: GatewayComputeModel) -> Optional[GatewayModel]: + gateway_model = replica_model.gateway or replica_model.legacy_gateway + if gateway_model is None: + logger.error("Gateway replica %s is not attached to a gateway", replica_model.id) + return gateway_model + + +def _mark_terminating_if_gateway_terminating( + gateway_model: GatewayModel, replica_model: GatewayComputeModel +) -> Optional[_GatewayReplicaUpdateMap]: + if gateway_model.to_be_deleted or gateway_model.status == GatewayStatus.FAILED: + if replica_model.status == GatewayReplicaStatus.SUBMITTED: + new_status = GatewayReplicaStatus.TERMINATED + deleted = True + else: + new_status = GatewayReplicaStatus.TERMINATING + deleted = False + logger.info( + "%s replica %d: marked %s, gateway is being deleted or failed", + fmt(gateway_model), + replica_model.replica_num, + new_status.value, + ) + return _GatewayReplicaUpdateMap(status=new_status, active=False, deleted=deleted) + return None + + +async def _commit_update( + item: GatewayReplicaPipelineItem, + replica_model: GatewayComputeModel, + update_map: _GatewayReplicaUpdateMap, +) -> None: + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(update_map, now=now) + res = await session.execute( + update(GatewayComputeModel) + .where( + GatewayComputeModel.id == replica_model.id, + GatewayComputeModel.lock_token == replica_model.lock_token, + ) + .values(**update_map) + .returning(GatewayComputeModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_after_processing(logger, item) + + +async def _process_submitted_item(item: GatewayReplicaPipelineItem): + replica_model = await _load_gateway_replica( + item, + replica_fields=_REPLICA_FIELDS_MIN + + [ + GatewayComputeModel.backend_id, + GatewayComputeModel.configuration, + GatewayComputeModel.ssh_public_key, + ], + gateway_fields=_GATEWAY_FIELDS_MIN + + [ + GatewayModel.configuration, + GatewayModel.region, + GatewayModel.wildcard_domain, + ], + load_backends=True, + load_gateway_backend_type=True, + ) + if replica_model is None: + return + gateway_model = _get_loaded_gateway_model(replica_model) + if gateway_model is None: + await _commit_update(item, replica_model, update_map={}) + return + if update_map := _mark_terminating_if_gateway_terminating(gateway_model, replica_model): + await _commit_update(item, replica_model, update_map=update_map) + return + update_map = await _provision_gateway_replica(gateway_model, replica_model) + await _commit_update(item, replica_model, update_map) + + +async def _provision_gateway_replica( + gateway_model: GatewayModel, + replica_model: GatewayComputeModel, +) -> _GatewayReplicaUpdateMap: + try: + if replica_model.backend_id is None: # unexpected + raise BackendNotAvailable() + (_, backend) = await backends_services.get_project_backend_with_model_by_id_or_error( + project=gateway_model.project, backend_id=replica_model.backend_id + ) + except BackendNotAvailable: + logger.warning( + "%s replica %d: backend not available", + fmt(gateway_model), + replica_model.replica_num, + ) + return _GatewayReplicaUpdateMap( + status=GatewayReplicaStatus.TERMINATED, + active=False, + deleted=True, + ) + + compute = backend.compute() + assert isinstance(compute, ComputeWithGatewaySupport) + compute_configuration = get_gateway_compute_configuration(replica_model, gateway_model) + + logger.debug( + "%s replica %d: creating gateway compute", + fmt(gateway_model), + replica_model.replica_num, + ) + try: + gpd = await run_async(compute.create_gateway, compute_configuration) + except BackendError as e: + status_message = f"Backend error: {repr(e)}" + if len(e.args) > 0: + status_message = str(e.args[0]) + logger.warning( + "%s replica %d: failed to create gateway compute: %s", + fmt(gateway_model), + replica_model.replica_num, + status_message, + ) + return _GatewayReplicaUpdateMap( + status=GatewayReplicaStatus.TERMINATED, + status_message=status_message, + active=False, + deleted=True, + ) + except Exception: + logger.exception( + "%s replica %d: unexpected error when creating gateway compute", + fmt(gateway_model), + replica_model.replica_num, + ) + return _GatewayReplicaUpdateMap( + status=GatewayReplicaStatus.TERMINATED, + status_message="Unexpected error", + active=False, + deleted=True, + ) + + logger.info( + "%s replica %d: gateway compute created", + fmt(gateway_model), + replica_model.replica_num, + ) + return _GatewayReplicaUpdateMap( + status=GatewayReplicaStatus.PROVISIONING, + active=True, + instance_id=gpd.instance_id, + ip_address=gpd.ip_address, + region=gpd.region, + hostname=gpd.hostname, + backend_data=gpd.backend_data, + ) + + +async def _process_provisioning_item(item: GatewayReplicaPipelineItem): + replica_model = await _load_gateway_replica( + item, + replica_fields=_REPLICA_FIELDS_MIN + + [ + GatewayComputeModel.ip_address, + GatewayComputeModel.ssh_private_key, + ], + gateway_fields=_GATEWAY_FIELDS_MIN, + ) + if replica_model is None: + return + gateway_model = _get_loaded_gateway_model(replica_model) + if gateway_model is None: + await _commit_update(item, replica_model, update_map={}) + return + if update_map := _mark_terminating_if_gateway_terminating(gateway_model, replica_model): + await _commit_update(item, replica_model, update_map=update_map) + return + error = await _connect_and_configure_gateway_replica(gateway_model, replica_model) + if error is None: + logger.info( + "%s replica %d: running", + fmt(gateway_model), + replica_model.replica_num, + ) + update_map = _GatewayReplicaUpdateMap(status=GatewayReplicaStatus.RUNNING, active=True) + else: + logger.warning( + "%s replica %d: provisioning failed: %s", + fmt(gateway_model), + replica_model.replica_num, + error, + ) + update_map = _GatewayReplicaUpdateMap( + status=GatewayReplicaStatus.TERMINATING, status_message=error, active=False + ) + await _commit_update(item, replica_model, update_map) + + +async def _connect_and_configure_gateway_replica( + gateway_model: GatewayModel, + gateway_compute: GatewayComputeModel, +) -> Optional[str]: + """Returns an error message on failure, None on success.""" + logger.debug( + "%s replica %d: connecting to gateway compute", + fmt(gateway_model), + gateway_compute.replica_num, + ) + # TODO: do only one connection/configuration attempt per pipeline tick. + # Blocking on connect_to_gateway_with_retry and configure_gateway now has these cons: + # - cannot terminate the gateway replica before it is provisioned because the DB model is locked + # - connection retry counter is reset on server restart + # - only one server replica is processing the gateway replica + connection = await gateways_services.connect_to_gateway_with_retry(gateway_compute) + if connection is None: + logger.warning( + "%s replica %d: failed to connect to gateway compute", + fmt(gateway_model), + gateway_compute.replica_num, + ) + return "Failed to connect to gateway" + try: + await gateways_services.configure_gateway(connection) + except Exception: + logger.exception( + "%s replica %d: failed to configure gateway", + fmt(gateway_model), + gateway_compute.replica_num, + ) + return "Failed to configure gateway" + logger.info( + "%s replica %d: gateway compute connected and configured", + fmt(gateway_model), + gateway_compute.replica_num, + ) + return None + + +async def _process_running_item(item: GatewayReplicaPipelineItem): + replica_model = await _load_gateway_replica( + item, + replica_fields=_REPLICA_FIELDS_MIN, + gateway_fields=_GATEWAY_FIELDS_MIN, + ) + if replica_model is None: + return + gateway_model = _get_loaded_gateway_model(replica_model) + if gateway_model is None: + await _commit_update(item, replica_model, update_map={}) + return + if update_map := _mark_terminating_if_gateway_terminating(gateway_model, replica_model): + await _commit_update(item, replica_model, update_map=update_map) + return + logger.warning( + "%s replica %d: nothing to do in this pipeline tick", + fmt(gateway_model), + replica_model.replica_num, + ) + await _commit_update(item, replica_model, update_map={}) + + +async def _process_terminating_item(item: GatewayReplicaPipelineItem): + replica_model = await _load_gateway_replica( + item, + replica_fields=_REPLICA_FIELDS_MIN + + [ + GatewayComputeModel.instance_id, + GatewayComputeModel.ip_address, + GatewayComputeModel.backend_id, + GatewayComputeModel.configuration, + GatewayComputeModel.backend_data, + GatewayComputeModel.ssh_public_key, + ], + gateway_fields=_GATEWAY_FIELDS_MIN + + [ + GatewayModel.configuration, + GatewayModel.region, + GatewayModel.wildcard_domain, + ], + load_backends=True, + load_gateway_backend_type=True, + ) + if replica_model is None: + return + gateway_model = _get_loaded_gateway_model(replica_model) + if gateway_model is None: + await _commit_update(item, replica_model, update_map={}) + return + mark_terminated_update_map = _GatewayReplicaUpdateMap( + status=GatewayReplicaStatus.TERMINATED, active=False, deleted=True + ) + try: + if replica_model.backend_id is None: # unexpected + raise BackendNotAvailable() + (_, backend) = await backends_services.get_project_backend_with_model_by_id_or_error( + project=gateway_model.project, + backend_id=replica_model.backend_id, + ) + except BackendNotAvailable: + logger.error( + "%s replica %d: backend not available, cannot terminate. Marking TERMINATED without termination", + fmt(gateway_model), + replica_model.replica_num, + ) + await _commit_update(item, replica_model, mark_terminated_update_map) + return + compute = backend.compute() + assert isinstance(compute, ComputeWithGatewaySupport) + compute_configuration = get_gateway_compute_configuration(replica_model, gateway_model) + if replica_model.instance_id is None: + logger.warning( + "%s replica %d: instance_id is None, skipping gateway replica termination", + fmt(gateway_model), + replica_model.replica_num, + ) + await _commit_update(item, replica_model, mark_terminated_update_map) + return + + logger.debug( + "%s replica %d: terminating gateway compute", + fmt(gateway_model), + replica_model.replica_num, + ) + try: + await run_async( + compute.terminate_gateway, + replica_model.instance_id, + compute_configuration, + replica_model.backend_data, + ) + except Exception: + logger.exception( + "%s replica %d: error when terminating gateway compute", + fmt(gateway_model), + replica_model.replica_num, + ) + await _commit_update(item, replica_model, update_map={}) + return + + logger.info( + "%s replica %d: gateway compute terminated", + fmt(gateway_model), + replica_model.replica_num, + ) + + if replica_model.ip_address is not None: + await gateway_connections_pool.remove(replica_model.ip_address) + + await _commit_update(item, replica_model, mark_terminated_update_map) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index 5c834c852a..8e71568ccc 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -2,14 +2,17 @@ import uuid from dataclasses import dataclass, field from datetime import timedelta -from typing import Optional, Sequence, TypedDict +from typing import Sequence from sqlalchemy import delete, or_, select, update from sqlalchemy.orm import joinedload, load_only, selectinload -from dstack._internal.core.backends.base.compute import ComputeWithGatewaySupport -from dstack._internal.core.errors import BackendError, BackendNotAvailable -from dstack._internal.core.models.gateways import GATEWAY_REPLICAS_DEFAULT, GatewayStatus +from dstack._internal.core.errors import BackendNotAvailable +from dstack._internal.core.models.gateways import ( + GATEWAY_REPLICAS_DEFAULT, + GatewayReplicaStatus, + GatewayStatus, +) from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, @@ -37,12 +40,11 @@ emit_gateway_status_change_event, get_gateway_compute_models, ) -from dstack._internal.server.services.gateways.pool import gateway_connections_pool from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt from dstack._internal.server.services.pipelines import PipelineHinterProtocol from dstack._internal.server.utils import sentry_utils -from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -257,7 +259,6 @@ async def _process_submitted_item(item: GatewayPipelineItem): updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: log_lock_token_changed_after_processing(logger, item) - # TODO: Clean up gateway_compute_models. return emit_gateway_status_change_event( session=session, @@ -273,11 +274,6 @@ class _GatewayUpdateMap(ItemUpdateMap, total=False): status_message: str -class _GatewayComputeUpdateMap(TypedDict, total=False): - active: bool - deleted: bool - - @dataclass class _SubmittedResult: update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap) @@ -285,12 +281,11 @@ class _SubmittedResult: async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedResult: - logger.info("%s: started gateway provisioning", fmt(gateway_model)) configuration = gateways_services.get_gateway_configuration(gateway_model) try: ( backend_model, - backend, + _, ) = await backends_services.get_project_backend_with_model_by_type_or_error( project=gateway_model.project, backend_type=configuration.backend ) @@ -301,49 +296,30 @@ async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedR "status_message": "Backend not available", } ) + # NOTE: On a later stage of #3959, the SUBMITTED status may also be responsible for + # setting up the load balancer (e.g., AWS ALB) before replicas are created. replicas = ( configuration.replicas if configuration.replicas is not None else GATEWAY_REPLICAS_DEFAULT ) gateway_compute_models = [] - try: - for replica_num in range(replicas): - logger.debug( - "%s replica %d: creating gateway compute", fmt(gateway_model), replica_num - ) - gateway_compute_model = await gateways_services.create_gateway_compute( - backend_compute=backend.compute(), - project_name=gateway_model.project.name, - configuration=configuration, - replica_num=replica_num, - gateway_id=gateway_model.id, - backend_id=backend_model.id, - ) - logger.info("%s replica %d: gateway compute created", fmt(gateway_model), replica_num) - gateway_compute_models.append(gateway_compute_model) - return _SubmittedResult( - update_map={"status": GatewayStatus.PROVISIONING}, - gateway_compute_models=gateway_compute_models, - ) - except BackendError as e: - status_message = f"Backend error: {repr(e)}" - if len(e.args) > 0: - status_message = str(e.args[0]) - return _SubmittedResult( - update_map={ - "status": GatewayStatus.FAILED, - "status_message": status_message, - }, - gateway_compute_models=gateway_compute_models, - ) - except Exception as e: - logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model)) - return _SubmittedResult( - update_map={ - "status": GatewayStatus.FAILED, - "status_message": f"Unexpected error: {repr(e)}", - }, - gateway_compute_models=gateway_compute_models, + for replica_num in range(replicas): + gateway_compute_model = gateways_services.create_gateway_compute_model( + project_name=gateway_model.project.name, + configuration=configuration, + replica_num=replica_num, + gateway_id=gateway_model.id, + backend_id=backend_model.id, ) + gateway_compute_models.append(gateway_compute_model) + logger.info( + "%s: created %d replica record(s) in submitted state", + fmt(gateway_model), + len(gateway_compute_models), + ) + return _SubmittedResult( + update_map={"status": GatewayStatus.PROVISIONING}, + gateway_compute_models=gateway_compute_models, + ) async def _process_provisioning_item(item: GatewayPipelineItem): @@ -355,14 +331,18 @@ async def _process_provisioning_item(item: GatewayPipelineItem): GatewayModel.lock_token == item.lock_token, ) .options(joinedload(GatewayModel.gateway_compute)) - .options(selectinload(GatewayModel.gateway_computes)) + .options( + selectinload(GatewayModel.gateway_computes).load_only( + GatewayComputeModel.id, GatewayComputeModel.status + ) + ) ) gateway_model = res.unique().scalar_one_or_none() if gateway_model is None: log_lock_token_mismatch(logger, item) return - result = await _process_provisioning_gateway(gateway_model) + result = _process_provisioning_gateway(gateway_model) gateway_update_map = result.gateway_update_map set_processed_update_map_fields(gateway_update_map) set_unlock_update_map_fields(gateway_update_map) @@ -390,97 +370,35 @@ async def _process_provisioning_item(item: GatewayPipelineItem): new_status=gateway_update_map.get("status", gateway_model.status), status_message=gateway_update_map.get("status_message", gateway_model.status_message), ) - if result.all_computes_update_map: - res = await session.execute( - update(GatewayComputeModel) - .where( - or_( - GatewayComputeModel.gateway_id == gateway_model.id, - GatewayComputeModel.id == gateway_model.gateway_compute_id, - ) - ) - .values(**result.all_computes_update_map) - .returning(GatewayComputeModel.id) - ) - updated_ids = list(res.scalars().all()) - if len(updated_ids) < len(get_gateway_compute_models(gateway_model)): - logger.error( - "Failed to update compute models for gateway %s." - " This is unexpected and may happen only if the compute model was manually deleted.", - gateway_model.id, - ) @dataclass class _ProvisioningResult: gateway_update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap) - all_computes_update_map: _GatewayComputeUpdateMap = field( - default_factory=_GatewayComputeUpdateMap - ) -async def _process_provisioning_gateway(gateway_model: GatewayModel) -> _ProvisioningResult: +def _process_provisioning_gateway(gateway_model: GatewayModel) -> _ProvisioningResult: gateway_computes = get_gateway_compute_models(gateway_model) # Provisioning gateways must have compute. assert len(gateway_computes) > 0 - # TODO: do only one connection/configuration attempt per pipeline tick. - # Blocking on connect_to_gateway_with_retry and configure_gateway now has these cons: - # - cannot delete the gateway before it is provisioned because the DB model is locked - # - connection retry counter is reset on server restart - # - only one server replica is processing the gateway + statuses = {gc.status for gc in gateway_computes} - errors = await asyncio.gather( - *(_connect_and_configure_gateway_replica(gateway_model, gc) for gc in gateway_computes) - ) - if any(errors): + if statuses & {GatewayReplicaStatus.TERMINATING, GatewayReplicaStatus.TERMINATED}: return _ProvisioningResult( gateway_update_map={ "status": GatewayStatus.FAILED, - "status_message": next(e for e in errors if e), + "status_message": "Failed to provision gateway replica", }, - all_computes_update_map={"active": False}, ) - return _ProvisioningResult( - gateway_update_map={"status": GatewayStatus.RUNNING}, - ) - - -async def _connect_and_configure_gateway_replica( - gateway_model: GatewayModel, - gateway_compute: GatewayComputeModel, -) -> Optional[str]: - """Returns an error message on failure, None on success.""" - logger.debug( - "%s replica %d: connecting to gateway compute", - fmt(gateway_model), - gateway_compute.replica_num, - ) - connection = await gateways_services.connect_to_gateway_with_retry(gateway_compute) - if connection is None: - logger.warning( - "%s replica %d: failed to connect to gateway compute", - fmt(gateway_model), - gateway_compute.replica_num, - ) - return "Failed to connect to gateway" - try: - await gateways_services.configure_gateway(connection) - except Exception: - logger.exception( - "%s replica %d: failed to configure gateway", - fmt(gateway_model), - gateway_compute.replica_num, + if statuses == {GatewayReplicaStatus.RUNNING}: + return _ProvisioningResult( + gateway_update_map={"status": GatewayStatus.RUNNING}, ) - await gateway_connections_pool.remove(gateway_compute.ip_address) - return "Failed to configure gateway" - logger.info( - "%s replica %d: gateway compute connected and configured", - fmt(gateway_model), - gateway_compute.replica_num, - ) - return None + + # Replicas are still being provisioned + return _ProvisioningResult() async def _process_to_be_deleted_item(item: GatewayPipelineItem): @@ -491,39 +409,20 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): GatewayModel.id == item.id, GatewayModel.lock_token == item.lock_token, ) - .options(joinedload(GatewayModel.project).joinedload(ProjectModel.backends)) .options(joinedload(GatewayModel.gateway_compute)) - .options(selectinload(GatewayModel.gateway_computes)) - .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) + .options( + selectinload(GatewayModel.gateway_computes).load_only( + GatewayComputeModel.id, GatewayComputeModel.status + ) + ) ) gateway_model = res.unique().scalar_one_or_none() if gateway_model is None: log_lock_token_mismatch(logger, item) return - result = await _process_to_be_deleted_gateway(gateway_model) + result = _process_to_be_deleted_gateway(gateway_model) async with get_session_ctx() as session: - if result.all_computes_update_map: - res = await session.execute( - update(GatewayComputeModel) - .where( - or_( - GatewayComputeModel.gateway_id == gateway_model.id, - GatewayComputeModel.id == gateway_model.gateway_compute_id, - ) - ) - .values(**result.all_computes_update_map) - .returning(GatewayComputeModel.id) - ) - updated_ids = list(res.scalars().all()) - if len(updated_ids) < len(get_gateway_compute_models(gateway_model)): - logger.error( - "Failed to update compute models for gateway %s." - " This is unexpected and may happen only if the compute model was manually deleted.", - gateway_model.id, - ) - return - if result.delete_gateway: res = await session.execute( delete(GatewayModel) @@ -571,50 +470,9 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): @dataclass class _ProcessToBeDeletedResult: delete_gateway: bool - all_computes_update_map: _GatewayComputeUpdateMap = field( - default_factory=_GatewayComputeUpdateMap - ) - -async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _ProcessToBeDeletedResult: - backend = await backends_services.get_project_backend_by_type_or_error( - project=gateway_model.project, backend_type=gateway_model.backend.type - ) - compute = backend.compute() - assert isinstance(compute, ComputeWithGatewaySupport) - for gateway_compute in get_gateway_compute_models(gateway_model): - gateway_compute_configuration = gateways_services.get_gateway_compute_configuration( - gateway_compute=gateway_compute, - gateway_model=gateway_model, - ) - logger.debug( - "%s replica %d: terminating gateway compute", - fmt(gateway_model), - gateway_compute.replica_num, - ) - try: - await run_async( - compute.terminate_gateway, - gateway_compute.instance_id, - gateway_compute_configuration, - gateway_compute.backend_data, - ) - except Exception: - logger.exception( - "%s replica %d: error when terminating gateway compute", - fmt(gateway_model), - gateway_compute.replica_num, - ) - return _ProcessToBeDeletedResult(delete_gateway=False) - logger.info( - "%s replica %d: gateway compute terminated", - fmt(gateway_model), - gateway_compute.replica_num, - ) - await gateway_connections_pool.remove(gateway_compute.ip_address) - - return _ProcessToBeDeletedResult( - delete_gateway=True, - all_computes_update_map={"active": False, "deleted": True}, - ) +def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _ProcessToBeDeletedResult: + gateway_computes = get_gateway_compute_models(gateway_model) + all_terminated = all(gc.status == GatewayReplicaStatus.TERMINATED for gc in gateway_computes) + return _ProcessToBeDeletedResult(delete_gateway=all_terminated) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/gateways.py b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py index f71ecacc0d..97dcb81646 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/gateways.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py @@ -33,7 +33,7 @@ async def _remove_inactive_connections(): res = await session.execute( select(GatewayComputeModel.ip_address).where(GatewayComputeModel.active == True) ) - active_connection_ips = set(res.scalars().all()) + active_connection_ips = {ip for ip in res.scalars().all() if ip is not None} for conn in await gateway_connections_pool.all(): if conn.ip_address not in active_connection_ips: await gateway_connections_pool.remove(conn.ip_address) diff --git a/src/dstack/_internal/server/compatibility/gateways.py b/src/dstack/_internal/server/compatibility/gateways.py index 653f3d6e22..f9c008eca6 100644 --- a/src/dstack/_internal/server/compatibility/gateways.py +++ b/src/dstack/_internal/server/compatibility/gateways.py @@ -10,6 +10,14 @@ def patch_gateway(gateway: Gateway, client_version: Optional[Version]) -> None: return if client_version < Version("0.20.25"): gateway.instance_id = "" - gateway.ip_address = "\n".join(r.hostname for r in gateway.replicas) + gateway.ip_address = "\n".join(r.hostname for r in gateway.replicas if r.hostname) if gateway.hostname is None: gateway.hostname = gateway.ip_address + if client_version == Version("0.20.25"): + for replica in gateway.replicas: + if replica.hostname is None: + replica.hostname = "" + if replica.region is None: + replica.region = "" + if replica.backend is None: + replica.backend = gateway.configuration.backend diff --git a/src/dstack/_internal/server/migrations/versions/2026/06_19_0709_857d8fa7fcc5_add_gateway_replica_pipeline.py b/src/dstack/_internal/server/migrations/versions/2026/06_19_0709_857d8fa7fcc5_add_gateway_replica_pipeline.py new file mode 100644 index 0000000000..4ebe0d961f --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/06_19_0709_857d8fa7fcc5_add_gateway_replica_pipeline.py @@ -0,0 +1,120 @@ +"""Add gateway replica pipeline + +Revision ID: 857d8fa7fcc5 +Revises: b7609b94ea4d +Create Date: 2026-06-19 07:09:26.989255+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "857d8fa7fcc5" +down_revision = "b7609b94ea4d" +branch_labels = None +depends_on = None + +# partial definition for queries +gateway_computes = sa.table( + "gateway_computes", + sa.column("id"), + sa.column("gateway_id"), + sa.column("last_processed_at", sa.DateTime()), + sa.column("created_at", sa.DateTime()), + sa.column("status", sa.String(100)), + sa.column("deleted", sa.Boolean()), + sa.column("active", sa.Boolean()), + sa.column("region", sa.String(100)), + sa.column("ip_address", sa.String(100)), + sa.column("instance_id", sa.String(100)), +) +gateways = sa.table( + "gateways", + sa.column("id"), + sa.column("gateway_compute_id"), + sa.column("status", sa.String(100)), +) + + +def upgrade() -> None: + with op.batch_alter_table("gateway_computes", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "last_processed_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column(sa.Column("status", sa.String(length=100), nullable=True)) + batch_op.add_column(sa.Column("status_message", sa.Text(), nullable=True)) + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + batch_op.alter_column("instance_id", existing_type=sa.VARCHAR(length=100), nullable=True) + batch_op.alter_column("ip_address", existing_type=sa.VARCHAR(length=100), nullable=True) + batch_op.alter_column("region", existing_type=sa.VARCHAR(length=100), nullable=True) + + op.execute(sa.update(gateway_computes).values(last_processed_at=gateway_computes.c.created_at)) + + gateway_is_provisioning = sa.exists( + sa.select(sa.literal(1)) + .select_from(gateways) + .where( + sa.or_( + gateways.c.id == gateway_computes.c.gateway_id, + gateways.c.gateway_compute_id == gateway_computes.c.id, + ), + gateways.c.status == "PROVISIONING", + ) + ) + op.execute( + sa.update(gateway_computes).values( + status=sa.case( + (gateway_computes.c.deleted == True, "TERMINATED"), + (gateway_computes.c.active == False, "TERMINATING"), + (gateway_is_provisioning, "PROVISIONING"), + else_="RUNNING", + ) + ) + ) + + with op.batch_alter_table("gateway_computes", schema=None) as batch_op: + batch_op.alter_column( + "last_processed_at", + existing_type=dstack._internal.server.models.NaiveDateTime(), + nullable=False, + ) + batch_op.alter_column("status", existing_type=sa.String(100), nullable=False) + + +def downgrade() -> None: + op.execute( + sa.delete(gateway_computes).where( + sa.or_( + gateway_computes.c.status == "SUBMITTED", + gateway_computes.c.region.is_(None), + gateway_computes.c.ip_address.is_(None), + gateway_computes.c.instance_id.is_(None), + ) + ) + ) + with op.batch_alter_table("gateway_computes", schema=None) as batch_op: + batch_op.alter_column("region", existing_type=sa.VARCHAR(length=100), nullable=False) + batch_op.alter_column("ip_address", existing_type=sa.VARCHAR(length=100), nullable=False) + batch_op.alter_column("instance_id", existing_type=sa.VARCHAR(length=100), nullable=False) + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + batch_op.drop_column("status_message") + batch_op.drop_column("status") + batch_op.drop_column("last_processed_at") diff --git a/src/dstack/_internal/server/migrations/versions/2026/06_24_1626_e9c5e7e26c78_add_ix_gateway_computes_pipeline_fetch_q.py b/src/dstack/_internal/server/migrations/versions/2026/06_24_1626_e9c5e7e26c78_add_ix_gateway_computes_pipeline_fetch_q.py new file mode 100644 index 0000000000..4a44b11f2e --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/06_24_1626_e9c5e7e26c78_add_ix_gateway_computes_pipeline_fetch_q.py @@ -0,0 +1,45 @@ +"""Add ix_gateway_computes_pipeline_fetch_q + +Revision ID: e9c5e7e26c78 +Revises: 857d8fa7fcc5 +Create Date: 2026-06-24 16:26:22.834262+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "e9c5e7e26c78" +down_revision = "857d8fa7fcc5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.get_context().autocommit_block(): + op.drop_index( + "ix_gateway_computes_pipeline_fetch_q", + table_name="gateway_computes", + if_exists=True, + postgresql_concurrently=True, + ) + op.create_index( + "ix_gateway_computes_pipeline_fetch_q", + "gateway_computes", + [sa.literal_column("last_processed_at ASC")], + unique=False, + sqlite_where=sa.text("deleted = 0"), + postgresql_where=sa.text("deleted IS FALSE"), + postgresql_concurrently=True, + ) + + +def downgrade() -> None: + with op.get_context().autocommit_block(): + op.drop_index( + "ix_gateway_computes_pipeline_fetch_q", + table_name="gateway_computes", + if_exists=True, + postgresql_concurrently=True, + ) diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 8d6f3c512c..fc73010263 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -28,7 +28,7 @@ from dstack._internal.core.models.compute_groups import ComputeGroupStatus from dstack._internal.core.models.events import EventTargetType from dstack._internal.core.models.fleets import FleetStatus -from dstack._internal.core.models.gateways import GatewayStatus +from dstack._internal.core.models.gateways import GatewayReplicaStatus, GatewayStatus from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.profiles import ( @@ -630,7 +630,8 @@ class GatewayModel(PipelineModelMixin, BaseModel): ForeignKey("gateway_computes.id", ondelete="CASCADE") ) gateway_compute: Mapped[Optional["GatewayComputeModel"]] = relationship( - foreign_keys=[gateway_compute_id] + foreign_keys=[gateway_compute_id], + back_populates="legacy_gateway", ) """ Relationship with gateway computes for pre-0.20.25 gateways. @@ -652,7 +653,7 @@ class GatewayModel(PipelineModelMixin, BaseModel): # TODO: Add pipeline index ("ix_gateways_pipeline_fetch_q") if gateways become soft-deleted. -class GatewayComputeModel(BaseModel): +class GatewayComputeModel(PipelineModelMixin, BaseModel): """A single gateway replica. **TODO**: consider renaming to `GatewayReplicaModel`. """ @@ -663,9 +664,12 @@ class GatewayComputeModel(BaseModel): UUIDType(binary=False), primary_key=True, default=uuid.uuid4 ) created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + last_processed_at: Mapped[datetime] = mapped_column(NaiveDateTime) + status: Mapped[GatewayReplicaStatus] = mapped_column(EnumAsString(GatewayReplicaStatus, 100)) + status_message: Mapped[Optional[str]] = mapped_column(Text) replica_num: Mapped[int] = mapped_column(Integer, server_default="0") - instance_id: Mapped[str] = mapped_column(String(100)) - ip_address: Mapped[str] = mapped_column(String(100)) + instance_id: Mapped[Optional[str]] = mapped_column(String(100)) + ip_address: Mapped[Optional[str]] = mapped_column(String(100)) """Gateway replica IP address or domain name (e.g., k8s can use domain names). **TODO**: rename. """ @@ -678,7 +682,7 @@ class GatewayComputeModel(BaseModel): Use `get_gateway_compute_configuration` to construct `configuration` for old gateways. """ backend_data: Mapped[Optional[str]] = mapped_column(Text) - region: Mapped[str] = mapped_column(String(100)) + region: Mapped[Optional[str]] = mapped_column(String(100)) gateway_id: Mapped[Optional[uuid.UUID]] = mapped_column( ForeignKey( @@ -695,6 +699,15 @@ class GatewayComputeModel(BaseModel): Gateway. Can be None for pre-0.20.25 gateways, which use GatewayModel.gateway_compute_id to establish the relationship. """ + legacy_gateway: Mapped[Optional["GatewayModel"]] = relationship( + back_populates="gateway_compute", + foreign_keys="GatewayModel.gateway_compute_id", + viewonly=True, + ) + """ + Gateway for pre-0.20.25 gateways, where GatewayModel.gateway_compute_id points to this replica. + Use `gateway or legacy_gateway` to get the gateway regardless of version. + """ backend_id: Mapped[Optional[uuid.UUID]] = mapped_column( ForeignKey("backends.id", ondelete="CASCADE") @@ -710,6 +723,15 @@ class GatewayComputeModel(BaseModel): deleted: Mapped[bool] = mapped_column(Boolean, server_default=false()) app_updated_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + __table_args__ = ( + Index( + "ix_gateway_computes_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=deleted == false(), + sqlite_where=deleted == false(), + ), + ) + # TODO: Drop after the release without pools class PoolModel(BaseModel): diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index f118261796..5dd3d91ae1 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -451,6 +451,27 @@ async def get_project_backend_model_by_type_or_error( return backend_model +async def get_project_backend_with_model_by_id( + project: ProjectModel, backend_id: UUID +) -> Optional[BackendTuple]: + backends_with_models = await get_project_backends_with_models(project=project) + for backend_model, backend in backends_with_models: + if backend_model.id == backend_id: + return backend_model, backend + return None + + +async def get_project_backend_with_model_by_id_or_error( + project: ProjectModel, backend_id: UUID +) -> BackendTuple: + backend_with_model = await get_project_backend_with_model_by_id( + project=project, backend_id=backend_id + ) + if backend_with_model is None: + raise BackendNotAvailable() + return backend_with_model + + async def get_backend_offers( backends: List[Backend], requirements: Requirements, diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index e81dbf7044..ede03b836f 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -14,8 +14,6 @@ import dstack._internal.utils.random_names as random_names from dstack._internal.core.backends.base.compute import ( - Compute, - ComputeWithGatewaySupport, get_dstack_gateway_wheel, get_dstack_runner_version, ) @@ -38,6 +36,7 @@ GatewayComputeConfiguration, GatewayConfiguration, GatewayReplica, + GatewayReplicaStatus, GatewaySpec, GatewayStatus, LetsEncryptGatewayCertificate, @@ -73,8 +72,8 @@ from dstack._internal.server.utils.common import gather_map_async from dstack._internal.utils.common import ( get_current_datetime, + get_or_error, interpolate_gateway_domain, - run_async, ) from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes from dstack._internal.utils.logging import get_logger @@ -173,15 +172,13 @@ async def get_gateway_by_name( return gateway_model_to_gateway(gateway, default_gateway_id=project.default_gateway_id) -async def create_gateway_compute( +def create_gateway_compute_model( project_name: str, - backend_compute: Compute, configuration: GatewayConfiguration, replica_num: int, gateway_id: Optional[uuid.UUID] = None, backend_id: Optional[uuid.UUID] = None, ) -> GatewayComputeModel: - assert isinstance(backend_compute, ComputeWithGatewaySupport) assert configuration.name is not None private_bytes, public_bytes = generate_rsa_key_pair_bytes() @@ -201,23 +198,18 @@ async def create_gateway_compute( router=configuration.router, ) - gpd = await run_async( - backend_compute.create_gateway, - compute_configuration, - ) - + now = get_current_datetime() return GatewayComputeModel( gateway_id=gateway_id, backend_id=backend_id, replica_num=replica_num, - region=gpd.region, - ip_address=gpd.ip_address, - instance_id=gpd.instance_id, - hostname=gpd.hostname, configuration=compute_configuration.json(), - backend_data=gpd.backend_data, ssh_private_key=gateway_ssh_private_key, ssh_public_key=gateway_ssh_public_key, + status=GatewayReplicaStatus.SUBMITTED, + active=False, + created_at=now, + last_processed_at=now, ) @@ -307,6 +299,10 @@ async def connect_to_gateway_with_retry( the domain can be resolved. """ + if gateway_compute.ip_address is None: + logger.warning("Gateway replica %s has no ip_address, cannot connect", gateway_compute.id) + return None + connection = None for attempt in range(GATEWAY_CONNECT_ATTEMPTS): @@ -472,8 +468,16 @@ async def list_project_gateway_models( else: stmt = stmt.where(GatewayModel.project_id == project.id) if load_gateway_compute: - stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) - stmt = stmt.options(selectinload(GatewayModel.gateway_computes)) + stmt = stmt.options( + joinedload(GatewayModel.gateway_compute) + .joinedload(GatewayComputeModel.backend) + .load_only(BackendModel.type) + ) + stmt = stmt.options( + selectinload(GatewayModel.gateway_computes) + .joinedload(GatewayComputeModel.backend) + .load_only(BackendModel.type) + ) if load_backend_type: stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) res = await session.execute(stmt) @@ -501,8 +505,16 @@ async def get_project_gateway_model_by_reference( ) ) if load_gateway_compute: - stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) - stmt = stmt.options(selectinload(GatewayModel.gateway_computes)) + stmt = stmt.options( + joinedload(GatewayModel.gateway_compute) + .joinedload(GatewayComputeModel.backend) + .load_only(BackendModel.type) + ) + stmt = stmt.options( + selectinload(GatewayModel.gateway_computes) + .joinedload(GatewayComputeModel.backend) + .load_only(BackendModel.type) + ) if load_backend_type: stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) res = await session.execute(stmt) @@ -536,8 +548,16 @@ async def get_project_gateway_model_by_name_for_update( res = await session.execute( select(GatewayModel) .where(GatewayModel.id.in_([gateway_id]), *filters) - .options(joinedload(GatewayModel.gateway_compute)) - .options(selectinload(GatewayModel.gateway_computes)) + .options( + joinedload(GatewayModel.gateway_compute) + .joinedload(GatewayComputeModel.backend) + .load_only(BackendModel.type) + ) + .options( + selectinload(GatewayModel.gateway_computes) + .joinedload(GatewayComputeModel.backend) + .load_only(BackendModel.type) + ) .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) .with_for_update(key_share=True, of=GatewayModel) ) @@ -563,8 +583,16 @@ async def get_project_default_gateway_model( ), ) if load_gateway_compute: - stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) - stmt = stmt.options(selectinload(GatewayModel.gateway_computes)) + stmt = stmt.options( + joinedload(GatewayModel.gateway_compute) + .joinedload(GatewayComputeModel.backend) + .load_only(BackendModel.type) + ) + stmt = stmt.options( + selectinload(GatewayModel.gateway_computes) + .joinedload(GatewayComputeModel.backend) + .load_only(BackendModel.type) + ) if load_backend_type: stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) res = await session.execute(stmt) @@ -598,6 +626,9 @@ async def get_or_add_gateway_connections( raise GatewayError("Gateway compute not found") connections: List[GatewayConnection] = [] for compute in computes: + if compute.ip_address is None: + logger.warning("Gateway replica %s has no ip_address", compute.id) + raise GatewayError("Failed to connect to gateway") try: conn = await gateway_connections_pool.get_or_add( hostname=compute.ip_address, @@ -653,8 +684,7 @@ def _merge_per_window_stats(stats_per_gateway_replica: list[PerWindowStats]) -> async def init_gateways(session: AsyncSession): res = await session.execute( select(GatewayComputeModel).where( - # FIXME: should not include computes related to gateways in the `provisioning` status. - # Causes warnings and delays when restarting the server during gateway provisioning. + GatewayComputeModel.status == GatewayReplicaStatus.RUNNING, GatewayComputeModel.active == True, GatewayComputeModel.deleted == False, ) @@ -670,8 +700,10 @@ async def init_gateways(session: AsyncSession): resource="gateway_tunnels", ): for gateway, error in await gather_map_async( - gateway_computes, - lambda g: gateway_connections_pool.get_or_add(g.ip_address, g.ssh_private_key, True), + [g for g in gateway_computes if g.ip_address], + lambda g: gateway_connections_pool.get_or_add( + get_or_error(g.ip_address), g.ssh_private_key, True + ), return_exceptions=True, ): if isinstance(error, Exception): @@ -707,6 +739,11 @@ async def init_gateways(session: AsyncSession): async def _update_gateway(gateway_compute_model: GatewayComputeModel, build: str) -> bool: + if gateway_compute_model.ip_address is None: + logger.warning( + "Gateway replica %s has no ip_address, cannot update", gateway_compute_model.id + ) + return False if _recently_updated(gateway_compute_model): logger.debug( "Skipping gateway %s update. Gateway was recently updated.", @@ -812,11 +849,12 @@ def get_gateway_compute_configuration( if gateway_compute.configuration is not None: return GatewayComputeConfiguration.__response__.parse_raw(gateway_compute.configuration) # Handle gateways created before GatewayComputeConfiguration was introduced + gateway_configuration = get_gateway_configuration(gateway_model) return GatewayComputeConfiguration( project_name=gateway_model.project.name, - instance_name=gateway_compute.instance_id, - backend=gateway_model.backend.type, - region=gateway_compute.region, + instance_name=f"{gateway_model.name}-{gateway_compute.replica_num}", + backend=gateway_configuration.backend, + region=gateway_configuration.region, public_ip=True, ssh_key_pub=gateway_compute.ssh_public_key, certificate=LetsEncryptGatewayCertificate(), @@ -840,14 +878,15 @@ def gateway_model_to_gateway( gateway_hostname = None replicas = [] for compute in compute_models: - compute_configuration = get_gateway_compute_configuration(compute, gateway_model) replicas.append( GatewayReplica( hostname=compute.ip_address, replica_num=compute.replica_num, - backend=compute_configuration.backend, - region=compute_configuration.region, + backend=compute.backend.type if compute.backend else None, + region=compute.region, created_at=compute.created_at, + status=compute.status, + status_message=compute.status_message, ) ) gateway_hostname = compute.hostname diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 2c0a66be5a..293311e806 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -44,7 +44,11 @@ SSHHostParams, SSHParams, ) -from dstack._internal.core.models.gateways import GatewayComputeConfiguration, GatewayStatus +from dstack._internal.core.models.gateways import ( + GatewayComputeConfiguration, + GatewayReplicaStatus, + GatewayStatus, +) from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import ( Disk, @@ -662,10 +666,15 @@ async def create_gateway_compute( gateway_id: Optional[UUID] = None, backend_id: Optional[UUID] = None, ip_address: Optional[str] = "1.1.1.1", - region: str = "us", + region: Optional[str] = "us", instance_id: Optional[str] = "i-1234567890", ssh_private_key: str = "", ssh_public_key: str = "", + status: GatewayReplicaStatus = GatewayReplicaStatus.RUNNING, + last_processed_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + replica_num: int = 0, + active: bool = True, + configuration: Optional[str] = None, ) -> GatewayComputeModel: gateway_compute = GatewayComputeModel( gateway_id=gateway_id, @@ -675,6 +684,11 @@ async def create_gateway_compute( instance_id=instance_id, ssh_private_key=ssh_private_key, ssh_public_key=ssh_public_key, + status=status, + last_processed_at=last_processed_at, + replica_num=replica_num, + active=active, + configuration=configuration, ) session.add(gateway_compute) await session.commit() diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_gateway_replicas.py b/src/tests/_internal/server/background/pipeline_tasks/test_gateway_replicas.py new file mode 100644 index 0000000000..a26b3f104c --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_gateway_replicas.py @@ -0,0 +1,951 @@ +import asyncio +import uuid +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, Mock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import BackendError +from dstack._internal.core.models.gateways import ( + GatewayProvisioningData, + GatewayReplicaStatus, + GatewayStatus, +) +from dstack._internal.server.background.pipeline_tasks.gateway_replicas import ( + GatewayReplicaFetcher, + GatewayReplicaPipeline, + GatewayReplicaPipelineItem, + GatewayReplicaWorker, +) +from dstack._internal.server.models import GatewayComputeModel +from dstack._internal.server.testing.common import ( + AsyncContextManager, + ComputeMockSpec, + create_backend, + create_gateway, + create_gateway_compute, + create_project, + get_gateway_compute_configuration, +) +from dstack._internal.utils.common import get_current_datetime + + +@pytest.fixture +def worker() -> GatewayReplicaWorker: + return GatewayReplicaWorker(queue=Mock(), heartbeater=Mock(), pipeline_hinter=Mock()) + + +@pytest.fixture +def fetcher() -> GatewayReplicaFetcher: + return GatewayReplicaFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=15), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + +def _compute_to_pipeline_item( + compute: GatewayComputeModel, +) -> GatewayReplicaPipelineItem: + assert compute.lock_token is not None + assert compute.lock_expires_at is not None + return GatewayReplicaPipelineItem( + __tablename__=compute.__tablename__, + id=compute.id, + lock_token=compute.lock_token, + lock_expires_at=compute.lock_expires_at, + prev_lock_expired=False, + status=compute.status, + ) + + +def _lock_compute(compute: GatewayComputeModel) -> None: + compute.lock_token = uuid.uuid4() + compute.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGatewayReplicaFetcher: + async def test_fetch_selects_eligible_replicas_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: GatewayReplicaFetcher + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + submitted = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + ip_address=None, + instance_id=None, + region=None, + status=GatewayReplicaStatus.SUBMITTED, + last_processed_at=stale - timedelta(seconds=3), + ) + provisioning = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.PROVISIONING, + last_processed_at=stale - timedelta(seconds=2), + ) + terminating = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.TERMINATING, + active=False, + last_processed_at=stale - timedelta(seconds=1), + ) + running = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.RUNNING, + last_processed_at=stale, + ) + terminated = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.TERMINATED, + active=False, + last_processed_at=stale, + ) + recent = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.SUBMITTED, + ip_address=None, + instance_id=None, + region=None, + last_processed_at=now, + ) + recent.created_at = now - timedelta(minutes=2) + recent.last_processed_at = now + locked = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.SUBMITTED, + ip_address=None, + instance_id=None, + region=None, + last_processed_at=stale + timedelta(seconds=1), + ) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert {item.id for item in items} == {submitted.id, provisioning.id, terminating.id} + assert {(item.id, item.status) for item in items} == { + (submitted.id, GatewayReplicaStatus.SUBMITTED), + (provisioning.id, GatewayReplicaStatus.PROVISIONING), + (terminating.id, GatewayReplicaStatus.TERMINATING), + } + + for compute in [submitted, provisioning, terminating, running, terminated, recent, locked]: + await session.refresh(compute) + + fetched = [submitted, provisioning, terminating] + assert all(c.lock_owner == GatewayReplicaPipeline.__name__ for c in fetched) + assert all(c.lock_expires_at is not None for c in fetched) + assert all(c.lock_token is not None for c in fetched) + assert len({c.lock_token for c in fetched}) == 1 + + assert running.lock_owner is None + assert terminated.lock_owner is None + assert recent.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + @pytest.mark.parametrize( + "gateway_status,to_be_deleted", + [ + (GatewayStatus.FAILED, False), + (GatewayStatus.RUNNING, True), + ], + ) + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_fetch_includes_running_replica_needing_cleanup( + self, + test_db, + session: AsyncSession, + fetcher: GatewayReplicaFetcher, + gateway_status: GatewayStatus, + to_be_deleted: bool, + legacy_compute: bool, + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=gateway_status, + ) + gateway.to_be_deleted = to_be_deleted + stale = get_current_datetime() - timedelta(minutes=1) + if legacy_compute: + compute = await create_gateway_compute( + session=session, + status=GatewayReplicaStatus.RUNNING, + last_processed_at=stale, + ) + gateway.gateway_compute_id = compute.id + else: + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.RUNNING, + last_processed_at=stale, + ) + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert len(items) == 1 + assert items[0].id == compute.id + assert items[0].status == GatewayReplicaStatus.RUNNING + + async def test_fetch_includes_running_replica_with_hard_deleted_gateway( + self, + test_db, + session: AsyncSession, + fetcher: GatewayReplicaFetcher, + ): + # A compute whose gateway was hard-deleted (orphaned). The fetcher should + # pick it up so the worker can log the error. + stale = get_current_datetime() - timedelta(minutes=1) + compute = await create_gateway_compute( + session=session, + gateway_id=None, + status=GatewayReplicaStatus.RUNNING, + last_processed_at=stale, + ) + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert len(items) == 1 + assert items[0].id == compute.id + assert items[0].status == GatewayReplicaStatus.RUNNING + + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_fetch_excludes_running_replica_with_healthy_gateway( + self, + test_db, + session: AsyncSession, + fetcher: GatewayReplicaFetcher, + legacy_compute: bool, + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.RUNNING, + ) + stale = get_current_datetime() - timedelta(minutes=1) + if legacy_compute: + compute = await create_gateway_compute( + session=session, + status=GatewayReplicaStatus.RUNNING, + last_processed_at=stale, + ) + gateway.gateway_compute_id = compute.id + else: + await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.RUNNING, + last_processed_at=stale, + ) + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert len(items) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGatewayReplicaWorkerSubmitted: + async def test_submitted_to_provisioning( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + backend_id=backend.id, + ip_address=None, + instance_id=None, + region=None, + status=GatewayReplicaStatus.SUBMITTED, + configuration=get_gateway_compute_configuration().json(), + ) + _lock_compute(compute) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backends_with_models" + ) as m: + aws = Mock() + m.return_value = [(backend, aws)] + aws.compute.return_value = Mock(spec=ComputeMockSpec) + aws.compute.return_value.create_gateway.return_value = GatewayProvisioningData( + instance_id="i-1234567890", + ip_address="2.2.2.2", + region="us", + ) + await worker.process(_compute_to_pipeline_item(compute)) + aws.compute.return_value.create_gateway.assert_called_once() + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.PROVISIONING + assert compute.ip_address == "2.2.2.2" + assert compute.instance_id == "i-1234567890" + assert compute.region == "us" + + async def test_submitted_backend_error_marks_terminated( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + backend_id=backend.id, + ip_address=None, + instance_id=None, + region=None, + status=GatewayReplicaStatus.SUBMITTED, + configuration=get_gateway_compute_configuration().json(), + ) + _lock_compute(compute) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backends_with_models" + ) as m: + aws = Mock() + m.return_value = [(backend, aws)] + aws.compute.return_value = Mock(spec=ComputeMockSpec) + aws.compute.return_value.create_gateway.side_effect = BackendError("Some error") + await worker.process(_compute_to_pipeline_item(compute)) + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATED + assert compute.active is False + assert compute.deleted is True + + async def test_submitted_backend_not_available_marks_terminated( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + backend_id=backend.id, + ip_address=None, + instance_id=None, + region=None, + status=GatewayReplicaStatus.SUBMITTED, + ) + _lock_compute(compute) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backends_with_models" + ) as m: + m.return_value = [] + await worker.process(_compute_to_pipeline_item(compute)) + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATED + assert compute.active is False + assert compute.deleted is True + + async def test_submitted_skips_provisioning_if_gateway_to_be_deleted( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.RUNNING, + ) + gateway.to_be_deleted = True + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + backend_id=backend.id, + ip_address=None, + instance_id=None, + region=None, + status=GatewayReplicaStatus.SUBMITTED, + ) + _lock_compute(compute) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backends_with_models" + ) as m: + await worker.process(_compute_to_pipeline_item(compute)) + m.assert_not_called() + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATED + assert compute.active is False + assert compute.deleted is True + + async def test_submitted_skips_provisioning_if_gateway_failed( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.FAILED, + ) + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + backend_id=backend.id, + ip_address=None, + instance_id=None, + region=None, + status=GatewayReplicaStatus.SUBMITTED, + ) + _lock_compute(compute) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backends_with_models" + ) as m: + await worker.process(_compute_to_pipeline_item(compute)) + m.assert_not_called() + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATED + assert compute.active is False + assert compute.deleted is True + + async def test_submitted_unexpected_error_marks_terminated( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + backend_id=backend.id, + ip_address=None, + instance_id=None, + region=None, + status=GatewayReplicaStatus.SUBMITTED, + configuration=get_gateway_compute_configuration().json(), + ) + _lock_compute(compute) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backends_with_models" + ) as m: + aws = Mock() + m.return_value = [(backend, aws)] + aws.compute.return_value = Mock(spec=ComputeMockSpec) + aws.compute.return_value.create_gateway.side_effect = RuntimeError("Unexpected!") + await worker.process(_compute_to_pipeline_item(compute)) + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATED + assert compute.status_message == "Unexpected error" + assert compute.active is False + assert compute.deleted is True + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGatewayReplicaWorkerRunning: + @pytest.mark.parametrize( + "gateway_status,to_be_deleted", + [ + (GatewayStatus.FAILED, False), + (GatewayStatus.RUNNING, True), + ], + ) + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_running_to_terminating( + self, + test_db, + session: AsyncSession, + worker: GatewayReplicaWorker, + gateway_status: GatewayStatus, + to_be_deleted: bool, + legacy_compute: bool, + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=gateway_status, + ) + gateway.to_be_deleted = to_be_deleted + if legacy_compute: + compute = await create_gateway_compute( + session=session, + status=GatewayReplicaStatus.RUNNING, + active=True, + ) + gateway.gateway_compute_id = compute.id + else: + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.RUNNING, + active=True, + ) + _lock_compute(compute) + await session.commit() + + await worker.process(_compute_to_pipeline_item(compute)) + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATING + assert compute.active is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGatewayReplicaWorkerProvisioning: + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_provisioning_to_running( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker, legacy_compute: bool + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + if legacy_compute: + compute = await create_gateway_compute( + session=session, + status=GatewayReplicaStatus.PROVISIONING, + ) + gateway.gateway_compute_id = compute.id + else: + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.PROVISIONING, + ) + _lock_compute(compute) + await session.commit() + + with patch( + "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" + ) as pool_add: + pool_add.return_value = MagicMock() + pool_add.return_value.client.return_value = MagicMock(AsyncContextManager()) + await worker.process(_compute_to_pipeline_item(compute)) + pool_add.assert_called_once() + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.RUNNING + assert compute.active is True + + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_provisioning_to_terminating_if_connect_fails( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker, legacy_compute: bool + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + if legacy_compute: + compute = await create_gateway_compute( + session=session, + status=GatewayReplicaStatus.PROVISIONING, + ) + gateway.gateway_compute_id = compute.id + else: + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.PROVISIONING, + ) + _lock_compute(compute) + await session.commit() + + with patch( + "dstack._internal.server.services.gateways.connect_to_gateway_with_retry" + ) as connect_mock: + connect_mock.return_value = None + await worker.process(_compute_to_pipeline_item(compute)) + connect_mock.assert_called_once() + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATING + assert compute.active is False + assert compute.status_message == "Failed to connect to gateway" + + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_provisioning_to_terminating_if_configure_fails( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker, legacy_compute: bool + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + if legacy_compute: + compute = await create_gateway_compute( + session=session, + status=GatewayReplicaStatus.PROVISIONING, + ) + gateway.gateway_compute_id = compute.id + else: + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.PROVISIONING, + ) + _lock_compute(compute) + await session.commit() + + with ( + patch( + "dstack._internal.server.services.gateways.connect_to_gateway_with_retry" + ) as connect_mock, + patch("dstack._internal.server.services.gateways.configure_gateway") as configure_mock, + ): + connect_mock.return_value = MagicMock() + configure_mock.side_effect = Exception("Configure failed") + await worker.process(_compute_to_pipeline_item(compute)) + connect_mock.assert_called_once() + configure_mock.assert_called_once() + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATING + assert compute.active is False + assert compute.status_message == "Failed to configure gateway" + + @pytest.mark.parametrize( + "gateway_status,to_be_deleted", + [ + (GatewayStatus.FAILED, False), + (GatewayStatus.RUNNING, True), + ], + ) + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_provisioning_to_terminating_if_gateway_needs_cleanup( + self, + test_db, + session: AsyncSession, + worker: GatewayReplicaWorker, + gateway_status: GatewayStatus, + to_be_deleted: bool, + legacy_compute: bool, + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=gateway_status, + ) + gateway.to_be_deleted = to_be_deleted + if legacy_compute: + compute = await create_gateway_compute( + session=session, + status=GatewayReplicaStatus.PROVISIONING, + ) + gateway.gateway_compute_id = compute.id + else: + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.PROVISIONING, + ) + _lock_compute(compute) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.gateway_replicas._connect_and_configure_gateway_replica" + ) as connect_mock: + await worker.process(_compute_to_pipeline_item(compute)) + connect_mock.assert_not_called() + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATING + assert compute.active is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGatewayReplicaWorkerTerminating: + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_terminating_to_terminated( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker, legacy_compute: bool + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.FAILED, + ) + if legacy_compute: + compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + status=GatewayReplicaStatus.TERMINATING, + active=False, + ) + gateway.gateway_compute_id = compute.id + else: + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + backend_id=backend.id, + status=GatewayReplicaStatus.TERMINATING, + active=False, + ) + _lock_compute(compute) + await session.commit() + + with ( + patch( + "dstack._internal.server.services.backends.get_project_backends_with_models" + ) as get_backends_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.gateway_replicas.gateway_connections_pool.remove" + ) as remove_mock, + ): + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + get_backends_mock.return_value = [(backend, backend_mock)] + + await worker.process(_compute_to_pipeline_item(compute)) + + get_backends_mock.assert_called_once() + backend_mock.compute.return_value.terminate_gateway.assert_called_once() + remove_mock.assert_called_once_with(compute.ip_address) + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATED + assert compute.active is False + assert compute.deleted is True + + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_terminating_to_terminated_if_backend_not_available( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker, legacy_compute: bool + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.FAILED, + ) + if legacy_compute: + compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + status=GatewayReplicaStatus.TERMINATING, + active=False, + ) + gateway.gateway_compute_id = compute.id + else: + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + backend_id=backend.id, + status=GatewayReplicaStatus.TERMINATING, + active=False, + ) + _lock_compute(compute) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backends_with_models" + ) as get_backends_mock: + get_backends_mock.return_value = [] + await worker.process(_compute_to_pipeline_item(compute)) + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATED + assert compute.active is False + assert compute.deleted is True + + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_terminating_to_terminated_with_no_instance_id( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker, legacy_compute: bool + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.FAILED, + ) + if legacy_compute: + compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + instance_id=None, + status=GatewayReplicaStatus.TERMINATING, + active=False, + ) + gateway.gateway_compute_id = compute.id + else: + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + backend_id=backend.id, + instance_id=None, + status=GatewayReplicaStatus.TERMINATING, + active=False, + ) + _lock_compute(compute) + await session.commit() + + with ( + patch( + "dstack._internal.server.services.backends.get_project_backends_with_models" + ) as get_backends_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.gateway_replicas.gateway_connections_pool.remove" + ) as remove_mock, + ): + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + get_backends_mock.return_value = [(backend, backend_mock)] + + await worker.process(_compute_to_pipeline_item(compute)) + + backend_mock.compute.return_value.terminate_gateway.assert_not_called() + remove_mock.assert_not_called() + + await session.refresh(compute) + assert compute.status == GatewayReplicaStatus.TERMINATED + assert compute.active is False + assert compute.deleted is True + + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_terminating_retries_if_terminate_fails( + self, test_db, session: AsyncSession, worker: GatewayReplicaWorker, legacy_compute: bool + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.FAILED, + ) + if legacy_compute: + compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + status=GatewayReplicaStatus.TERMINATING, + active=False, + ) + gateway.gateway_compute_id = compute.id + else: + compute = await create_gateway_compute( + session=session, + gateway_id=gateway.id, + backend_id=backend.id, + status=GatewayReplicaStatus.TERMINATING, + active=False, + ) + _lock_compute(compute) + original_last_processed_at = compute.last_processed_at + await session.commit() + + with ( + patch( + "dstack._internal.server.services.backends.get_project_backends_with_models" + ) as get_backends_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.gateway_replicas.gateway_connections_pool.remove" + ) as remove_mock, + ): + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.terminate_gateway.side_effect = Exception( + "Terminate failed" + ) + get_backends_mock.return_value = [(backend, backend_mock)] + + await worker.process(_compute_to_pipeline_item(compute)) + + get_backends_mock.assert_called_once() + backend_mock.compute.return_value.terminate_gateway.assert_called_once() + remove_mock.assert_not_called() + + await session.refresh(compute) + # Not TERMINATED, should retry termination + assert compute.status == GatewayReplicaStatus.TERMINATING + assert compute.last_processed_at > original_last_processed_at + assert compute.lock_token is None + assert compute.lock_expires_at is None + assert compute.lock_owner is None diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py index 2759d8a236..4b33a08409 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py @@ -1,18 +1,18 @@ import asyncio import uuid from datetime import datetime, timedelta, timezone -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from dstack._internal.core.errors import BackendError +from dstack._internal.core.errors import BackendNotAvailable from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.gateways import ( GatewayConfiguration, - GatewayProvisioningData, + GatewayReplicaStatus, GatewayStatus, ) from dstack._internal.server.background.pipeline_tasks.gateways import ( @@ -23,8 +23,6 @@ ) from dstack._internal.server.models import GatewayModel from dstack._internal.server.testing.common import ( - AsyncContextManager, - ComputeMockSpec, create_backend, create_gateway, create_gateway_compute, @@ -230,83 +228,6 @@ async def test_fetch_returns_oldest_gateways_first_up_to_limit( class TestGatewayWorkerSubmitted: async def test_submitted_to_provisioning( self, test_db, session: AsyncSession, worker: GatewayWorker - ): - project = await create_project(session=session) - backend = await create_backend(session=session, project_id=project.id) - gateway = await create_gateway( - session=session, - project_id=project.id, - backend_id=backend.id, - status=GatewayStatus.SUBMITTED, - ) - gateway.lock_token = uuid.uuid4() - gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) - await session.commit() - - with patch( - "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" - ) as m: - aws = Mock() - m.return_value = (backend, aws) - aws.compute.return_value = Mock(spec=ComputeMockSpec) - aws.compute.return_value.create_gateway.return_value = GatewayProvisioningData( - instance_id="i-1234567890", - ip_address="2.2.2.2", - region="us", - ) - await worker.process(_gateway_to_pipeline_item(gateway)) - m.assert_called_once() - aws.compute.return_value.create_gateway.assert_called_once() - - await session.refresh(gateway) - res = await session.execute( - select(GatewayModel) - .where(GatewayModel.id == gateway.id) - .options(selectinload(GatewayModel.gateway_computes)) - ) - gateway = res.unique().scalar_one() - assert gateway.status == GatewayStatus.PROVISIONING - assert len(gateway.gateway_computes) > 0 - assert gateway.gateway_computes[0].ip_address == "2.2.2.2" - events = await list_events(session) - assert len(events) == 1 - assert events[0].message == "Gateway status changed SUBMITTED -> PROVISIONING" - - async def test_marks_gateway_as_failed_if_gateway_creation_errors( - self, test_db, session: AsyncSession, worker: GatewayWorker - ): - project = await create_project(session=session) - backend = await create_backend(session=session, project_id=project.id) - gateway = await create_gateway( - session=session, - project_id=project.id, - backend_id=backend.id, - status=GatewayStatus.SUBMITTED, - ) - gateway.lock_token = uuid.uuid4() - gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) - await session.commit() - - with patch( - "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" - ) as m: - aws = Mock() - m.return_value = (backend, aws) - aws.compute.return_value = Mock(spec=ComputeMockSpec) - aws.compute.return_value.create_gateway.side_effect = BackendError("Some error") - await worker.process(_gateway_to_pipeline_item(gateway)) - m.assert_called_once() - aws.compute.return_value.create_gateway.assert_called_once() - - await session.refresh(gateway) - assert gateway.status == GatewayStatus.FAILED - assert gateway.status_message == "Some error" - events = await list_events(session) - assert len(events) == 1 - assert events[0].message == "Gateway status changed SUBMITTED -> FAILED (Some error)" - - async def test_submitted_creates_multiple_computes_for_multi_replica( - self, test_db, session: AsyncSession, worker: GatewayWorker ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) @@ -330,15 +251,8 @@ async def test_submitted_creates_multiple_computes_for_multi_replica( with patch( "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" ) as m: - aws = Mock() - m.return_value = (backend, aws) - aws.compute.return_value = Mock(spec=ComputeMockSpec) - aws.compute.return_value.create_gateway.side_effect = [ - GatewayProvisioningData(instance_id="i-aaa", ip_address="2.2.2.2", region="us"), - GatewayProvisioningData(instance_id="i-bbb", ip_address="3.3.3.3", region="us"), - ] + m.return_value = (backend, Mock()) await worker.process(_gateway_to_pipeline_item(gateway)) - assert aws.compute.return_value.create_gateway.call_count == 2 await session.refresh(gateway) res = await session.execute( @@ -350,12 +264,13 @@ async def test_submitted_creates_multiple_computes_for_multi_replica( assert gateway.status == GatewayStatus.PROVISIONING computes = sorted(gateway.gateway_computes, key=lambda c: c.replica_num) assert len(computes) == 2 - assert computes[0].ip_address == "2.2.2.2" + assert computes[0].status == GatewayReplicaStatus.SUBMITTED assert computes[0].replica_num == 0 - assert computes[1].ip_address == "3.3.3.3" + assert computes[1].status == GatewayReplicaStatus.SUBMITTED assert computes[1].replica_num == 1 + assert all(c.ip_address is None for c in computes) - async def test_marks_gateway_as_failed_if_second_replica_creation_errors( + async def test_marks_gateway_as_failed_if_backend_not_available( self, test_db, session: AsyncSession, worker: GatewayWorker ): project = await create_project(session=session) @@ -366,13 +281,6 @@ async def test_marks_gateway_as_failed_if_second_replica_creation_errors( backend_id=backend.id, status=GatewayStatus.SUBMITTED, ) - config = GatewayConfiguration( - name=gateway.name, - backend=BackendType.AWS, - region=gateway.region, - replicas=2, - ) - gateway.configuration = config.json() gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) await session.commit() @@ -380,32 +288,18 @@ async def test_marks_gateway_as_failed_if_second_replica_creation_errors( with patch( "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" ) as m: - aws = Mock() - m.return_value = (backend, aws) - aws.compute.return_value = Mock(spec=ComputeMockSpec) - aws.compute.return_value.create_gateway.side_effect = [ - GatewayProvisioningData(instance_id="i-aaa", ip_address="2.2.2.2", region="us"), - BackendError("Some error"), - ] + m.side_effect = BackendNotAvailable() await worker.process(_gateway_to_pipeline_item(gateway)) - assert aws.compute.return_value.create_gateway.call_count == 2 await session.refresh(gateway) - res = await session.execute( - select(GatewayModel) - .where(GatewayModel.id == gateway.id) - .options(selectinload(GatewayModel.gateway_computes)) - ) - gateway = res.unique().scalar_one() assert gateway.status == GatewayStatus.FAILED - assert gateway.status_message == "Some error" - # The first replica's compute is saved even though the second failed - assert len(gateway.gateway_computes) == 1 - assert gateway.gateway_computes[0].ip_address == "2.2.2.2" - assert gateway.gateway_computes[0].replica_num == 0 + assert gateway.status_message == "Backend not available" events = await list_events(session) assert len(events) == 1 - assert events[0].message == "Gateway status changed SUBMITTED -> FAILED (Some error)" + assert ( + events[0].message + == "Gateway status changed SUBMITTED -> FAILED (Backend not available)" + ) @pytest.mark.asyncio @@ -424,21 +318,23 @@ async def test_provisioning_to_running( status=GatewayStatus.PROVISIONING, ) if legacy_compute: - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) - gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + status=GatewayReplicaStatus.RUNNING, + ) + gateway.gateway_compute_id = gateway_compute.id else: - await create_gateway_compute(session, gateway_id=gateway.id) + await create_gateway_compute( + session, + gateway_id=gateway.id, + status=GatewayReplicaStatus.RUNNING, + ) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) await session.commit() - with patch( - "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" - ) as pool_add: - pool_add.return_value = MagicMock() - pool_add.return_value.client.return_value = MagicMock(AsyncContextManager()) - await worker.process(_gateway_to_pipeline_item(gateway)) - pool_add.assert_called_once() + await worker.process(_gateway_to_pipeline_item(gateway)) await session.refresh(gateway) assert gateway.status == GatewayStatus.RUNNING @@ -457,22 +353,25 @@ async def test_provisioning_to_running_with_multiple_replicas( backend_id=backend.id, status=GatewayStatus.PROVISIONING, ) - await create_gateway_compute(session, gateway_id=gateway.id, ip_address="1.1.1.1") - compute1 = await create_gateway_compute( - session, gateway_id=gateway.id, ip_address="2.2.2.2" + await create_gateway_compute( + session, + gateway_id=gateway.id, + ip_address="1.1.1.1", + status=GatewayReplicaStatus.RUNNING, + replica_num=0, + ) + await create_gateway_compute( + session, + gateway_id=gateway.id, + ip_address="2.2.2.2", + status=GatewayReplicaStatus.RUNNING, + replica_num=1, ) - compute1.replica_num = 1 gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) await session.commit() - with patch( - "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" - ) as pool_add: - pool_add.return_value = MagicMock() - pool_add.return_value.client.return_value = MagicMock(AsyncContextManager()) - await worker.process(_gateway_to_pipeline_item(gateway)) - assert pool_add.call_count == 2 + await worker.process(_gateway_to_pipeline_item(gateway)) await session.refresh(gateway) assert gateway.status == GatewayStatus.RUNNING @@ -480,9 +379,55 @@ async def test_provisioning_to_running_with_multiple_replicas( assert len(events) == 1 assert events[0].message == "Gateway status changed PROVISIONING -> RUNNING" + async def test_still_provisioning_if_not_all_replicas_running( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + await create_gateway_compute( + session, + gateway_id=gateway.id, + ip_address="1.1.1.1", + status=GatewayReplicaStatus.RUNNING, + replica_num=0, + ) + await create_gateway_compute( + session, + gateway_id=gateway.id, + ip_address="2.2.2.2", + status=GatewayReplicaStatus.PROVISIONING, + replica_num=1, + ) + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + original_last_processed_at = gateway.last_processed_at + await session.commit() + + await worker.process(_gateway_to_pipeline_item(gateway)) + + await session.refresh(gateway) + assert gateway.status == GatewayStatus.PROVISIONING + assert gateway.last_processed_at > original_last_processed_at + events = await list_events(session) + assert len(events) == 0 + @pytest.mark.parametrize("legacy_compute", [False, True]) - async def test_marks_gateway_as_failed_if_fails_to_connect( - self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool + @pytest.mark.parametrize( + "replica_status", [GatewayReplicaStatus.TERMINATING, GatewayReplicaStatus.TERMINATED] + ) + async def test_marks_gateway_as_failed_if_replica_failed( + self, + test_db, + session: AsyncSession, + worker: GatewayWorker, + legacy_compute: bool, + replica_status: GatewayReplicaStatus, ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) @@ -493,34 +438,37 @@ async def test_marks_gateway_as_failed_if_fails_to_connect( status=GatewayStatus.PROVISIONING, ) if legacy_compute: - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) - gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + status=replica_status, + active=False, + ) + gateway.gateway_compute_id = gateway_compute.id else: - gateway_compute = await create_gateway_compute(session, gateway_id=gateway.id) + await create_gateway_compute( + session, + gateway_id=gateway.id, + status=replica_status, + active=False, + ) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) await session.commit() - with patch( - "dstack._internal.server.services.gateways.connect_to_gateway_with_retry" - ) as connect_to_gateway_with_retry_mock: - connect_to_gateway_with_retry_mock.return_value = None - await worker.process(_gateway_to_pipeline_item(gateway)) - connect_to_gateway_with_retry_mock.assert_called_once() + await worker.process(_gateway_to_pipeline_item(gateway)) await session.refresh(gateway) - await session.refresh(gateway_compute) assert gateway.status == GatewayStatus.FAILED - assert gateway.status_message == "Failed to connect to gateway" - assert gateway_compute.active is False + assert gateway.status_message == "Failed to provision gateway replica" events = await list_events(session) assert len(events) == 1 assert ( events[0].message - == "Gateway status changed PROVISIONING -> FAILED (Failed to connect to gateway)" + == "Gateway status changed PROVISIONING -> FAILED (Failed to provision gateway replica)" ) - async def test_marks_gateway_as_failed_if_any_replica_fails_to_connect( + async def test_still_provisioning_with_submitted_replica( self, test_db, session: AsyncSession, worker: GatewayWorker ): project = await create_project(session=session) @@ -531,47 +479,33 @@ async def test_marks_gateway_as_failed_if_any_replica_fails_to_connect( backend_id=backend.id, status=GatewayStatus.PROVISIONING, ) - compute0 = await create_gateway_compute( - session, gateway_id=gateway.id, ip_address="1.1.1.1" + await create_gateway_compute( + session, + gateway_id=gateway.id, + ip_address=None, + instance_id=None, + region=None, + status=GatewayReplicaStatus.SUBMITTED, ) - compute1 = await create_gateway_compute( - session, gateway_id=gateway.id, ip_address="2.2.2.2" - ) - compute1.replica_num = 1 gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + original_last_processed_at = gateway.last_processed_at await session.commit() - with patch( - "dstack._internal.server.services.gateways.connect_to_gateway_with_retry" - ) as connect_mock: - connect_mock.return_value = None - await worker.process(_gateway_to_pipeline_item(gateway)) - assert connect_mock.call_count == 2 + await worker.process(_gateway_to_pipeline_item(gateway)) await session.refresh(gateway) - assert gateway.status == GatewayStatus.FAILED - assert gateway.status_message == "Failed to connect to gateway" - - await session.refresh(compute0) - await session.refresh(compute1) - assert compute0.active is False - assert compute1.active is False - + assert gateway.status == GatewayStatus.PROVISIONING + assert gateway.last_processed_at > original_last_processed_at events = await list_events(session) - assert len(events) == 1 - assert ( - events[0].message - == "Gateway status changed PROVISIONING -> FAILED (Failed to connect to gateway)" - ) + assert len(events) == 0 @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestGatewayWorkerDeleted: - @pytest.mark.parametrize("legacy_compute", [False, True]) - async def test_deletes_gateway_and_marks_compute_deleted( - self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool + async def test_deletes_gateway_with_no_computes( + self, test_db, session: AsyncSession, worker: GatewayWorker ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) @@ -581,47 +515,21 @@ async def test_deletes_gateway_and_marks_compute_deleted( backend_id=backend.id, status=GatewayStatus.RUNNING, ) - if legacy_compute: - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) - gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style - else: - gateway_compute = await create_gateway_compute( - session=session, backend_id=backend.id, gateway_id=gateway.id - ) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) gateway.to_be_deleted = True await session.commit() - with ( - patch( - "dstack._internal.server.services.backends.get_project_backend_by_type_or_error" - ) as get_backend_mock, - patch( - "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove" - ) as remove_connection_mock, - ): - backend_mock = Mock() - backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) - get_backend_mock.return_value = backend_mock - - await worker.process(_gateway_to_pipeline_item(gateway)) - - get_backend_mock.assert_called_once() - backend_mock.compute.return_value.terminate_gateway.assert_called_once() - remove_connection_mock.assert_called_once_with(gateway_compute.ip_address) + await worker.process(_gateway_to_pipeline_item(gateway)) - await session.refresh(gateway_compute) res = await session.execute(select(GatewayModel.id).where(GatewayModel.id == gateway.id)) assert res.scalar_one_or_none() is None - assert gateway_compute.active is False - assert gateway_compute.deleted is True events = await list_events(session) assert len(events) == 1 assert events[0].message == "Gateway deleted" @pytest.mark.parametrize("legacy_compute", [False, True]) - async def test_keeps_gateway_if_terminate_fails( + async def test_deletes_gateway_when_all_replicas_terminated( self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool ): project = await create_project(session=session) @@ -633,54 +541,49 @@ async def test_keeps_gateway_if_terminate_fails( status=GatewayStatus.RUNNING, ) if legacy_compute: - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) - gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style - else: gateway_compute = await create_gateway_compute( - session=session, backend_id=backend.id, gateway_id=gateway.id + session=session, + backend_id=backend.id, + status=GatewayReplicaStatus.TERMINATED, + active=False, + ) + gateway.gateway_compute_id = gateway_compute.id + else: + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + status=GatewayReplicaStatus.TERMINATED, + active=False, ) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) - gateway.lock_owner = "GatewayPipeline" gateway.to_be_deleted = True - original_last_processed_at = gateway.last_processed_at await session.commit() - with ( - patch( - "dstack._internal.server.services.backends.get_project_backend_by_type_or_error" - ) as get_backend_mock, - patch( - "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove" - ) as remove_connection_mock, - ): - backend_mock = Mock() - backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) - backend_mock.compute.return_value.terminate_gateway.side_effect = BackendError( - "Terminate failed" - ) - get_backend_mock.return_value = backend_mock - - await worker.process(_gateway_to_pipeline_item(gateway)) - - get_backend_mock.assert_called_once() - backend_mock.compute.return_value.terminate_gateway.assert_called_once() - remove_connection_mock.assert_not_called() + await worker.process(_gateway_to_pipeline_item(gateway)) - await session.refresh(gateway) - await session.refresh(gateway_compute) - assert gateway.to_be_deleted is True - assert gateway.last_processed_at > original_last_processed_at - assert gateway.lock_token is None - assert gateway.lock_expires_at is None - assert gateway.lock_owner is None - assert gateway_compute.active is True - assert gateway_compute.deleted is False + res = await session.execute(select(GatewayModel.id).where(GatewayModel.id == gateway.id)) + assert res.scalar_one_or_none() is None events = await list_events(session) - assert len(events) == 0 + assert len(events) == 1 + assert events[0].message == "Gateway deleted" - async def test_deletes_gateway_with_multiple_replicas( - self, test_db, session: AsyncSession, worker: GatewayWorker + @pytest.mark.parametrize( + "replica_status", + [ + GatewayReplicaStatus.SUBMITTED, + GatewayReplicaStatus.PROVISIONING, + GatewayReplicaStatus.RUNNING, + GatewayReplicaStatus.TERMINATING, + ], + ) + async def test_waits_when_replicas_not_yet_terminated( + self, + test_db, + session: AsyncSession, + worker: GatewayWorker, + replica_status: GatewayReplicaStatus, ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) @@ -690,48 +593,34 @@ async def test_deletes_gateway_with_multiple_replicas( backend_id=backend.id, status=GatewayStatus.RUNNING, ) - compute0 = await create_gateway_compute( - session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="1.1.1.1" - ) - compute1 = await create_gateway_compute( - session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="2.2.2.2" + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + status=replica_status, + active=False, ) - compute1.replica_num = 1 gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + gateway.lock_owner = "GatewayPipeline" gateway.to_be_deleted = True + original_last_processed_at = gateway.last_processed_at await session.commit() - with ( - patch( - "dstack._internal.server.services.backends.get_project_backend_by_type_or_error" - ) as get_backend_mock, - patch( - "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove" - ) as remove_connection_mock, - ): - backend_mock = Mock() - backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) - get_backend_mock.return_value = backend_mock - - await worker.process(_gateway_to_pipeline_item(gateway)) - - assert backend_mock.compute.return_value.terminate_gateway.call_count == 2 - assert remove_connection_mock.call_count == 2 + await worker.process(_gateway_to_pipeline_item(gateway)) - await session.refresh(compute0) - await session.refresh(compute1) res = await session.execute(select(GatewayModel.id).where(GatewayModel.id == gateway.id)) - assert res.scalar_one_or_none() is None - assert compute0.active is False - assert compute0.deleted is True - assert compute1.active is False - assert compute1.deleted is True + assert res.scalar_one_or_none() is not None + await session.refresh(gateway) + assert gateway.to_be_deleted is True + assert gateway.last_processed_at > original_last_processed_at + assert gateway.lock_token is None + assert gateway.lock_expires_at is None + assert gateway.lock_owner is None events = await list_events(session) - assert len(events) == 1 - assert events[0].message == "Gateway deleted" + assert len(events) == 0 - async def test_keeps_gateway_if_second_replica_terminate_fails( + async def test_deletes_gateway_with_multiple_replicas_all_terminated( self, test_db, session: AsyncSession, worker: GatewayWorker ): project = await create_project(session=session) @@ -742,48 +631,33 @@ async def test_keeps_gateway_if_second_replica_terminate_fails( backend_id=backend.id, status=GatewayStatus.RUNNING, ) - compute0 = await create_gateway_compute( - session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="1.1.1.1" + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ip_address="1.1.1.1", + status=GatewayReplicaStatus.TERMINATED, + active=False, + replica_num=0, ) - compute1 = await create_gateway_compute( - session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="2.2.2.2" + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ip_address="2.2.2.2", + status=GatewayReplicaStatus.TERMINATED, + active=False, + replica_num=1, ) - compute1.replica_num = 1 gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) - gateway.lock_owner = "GatewayPipeline" gateway.to_be_deleted = True - original_last_processed_at = gateway.last_processed_at await session.commit() - with ( - patch( - "dstack._internal.server.services.backends.get_project_backend_by_type_or_error" - ) as get_backend_mock, - patch( - "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove" - ) as remove_connection_mock, - ): - backend_mock = Mock() - backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) - backend_mock.compute.return_value.terminate_gateway.side_effect = [ - None, - BackendError("Terminate failed"), - ] - get_backend_mock.return_value = backend_mock - - await worker.process(_gateway_to_pipeline_item(gateway)) - - assert backend_mock.compute.return_value.terminate_gateway.call_count == 2 - remove_connection_mock.assert_called_once_with(compute0.ip_address) + await worker.process(_gateway_to_pipeline_item(gateway)) - await session.refresh(gateway) - await session.refresh(compute0) - await session.refresh(compute1) - assert gateway.to_be_deleted is True - assert gateway.last_processed_at > original_last_processed_at - assert gateway.lock_token is None - assert gateway.lock_expires_at is None - assert gateway.lock_owner is None - assert compute0.deleted is False - assert compute1.deleted is False + res = await session.execute(select(GatewayModel.id).where(GatewayModel.id == gateway.id)) + assert res.scalar_one_or_none() is None + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway deleted" diff --git a/src/tests/_internal/server/compatibility/test_gateways.py b/src/tests/_internal/server/compatibility/test_gateways.py index 4313be0955..ed5e929c13 100644 --- a/src/tests/_internal/server/compatibility/test_gateways.py +++ b/src/tests/_internal/server/compatibility/test_gateways.py @@ -8,6 +8,7 @@ Gateway, GatewayConfiguration, GatewayReplica, + GatewayReplicaStatus, GatewayStatus, ) from dstack._internal.server.compatibility.gateways import patch_gateway @@ -24,6 +25,7 @@ def _make_gateway_replica(hostname: str = "1.2.3.4") -> GatewayReplica: backend=BackendType.AWS, region="us", created_at=get_current_datetime(), + status=GatewayReplicaStatus.RUNNING, ) diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 075d1f6d4a..583541327a 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -74,6 +74,8 @@ async def test_list( "backend": backend.type.value, "region": "us", "created_at": response.json()[0]["replicas"][0]["created_at"], + "status": "running", + "status_message": None, } ], "instance_id": None, @@ -145,6 +147,8 @@ async def test_get( "backend": backend.type.value, "region": "us", "created_at": response.json()["replicas"][0]["created_at"], + "status": "running", + "status_message": None, } ], "instance_id": None, @@ -843,6 +847,8 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: "backend": backend.type.value, "region": "us", "created_at": response.json()["replicas"][0]["created_at"], + "status": "running", + "status_message": None, } ], "instance_id": None, @@ -1222,6 +1228,8 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "backend": backend.type.value, "region": "us", "created_at": response.json()["replicas"][0]["created_at"], + "status": "running", + "status_message": None, } ], "instance_id": None,