diff --git a/sagemaker-serve/src/sagemaker/serve/deployment_progress.py b/sagemaker-serve/src/sagemaker/serve/deployment_progress.py index d816d3f3fe..bf418a23d8 100644 --- a/sagemaker-serve/src/sagemaker/serve/deployment_progress.py +++ b/sagemaker-serve/src/sagemaker/serve/deployment_progress.py @@ -89,8 +89,11 @@ def _live_logging_deploy_done_with_progress(sagemaker_client, endpoint_name, pag progress_tracker.log(f"✅ Created endpoint with name {endpoint_name}") elif endpoint_status != "InService": time.sleep(poll) + # Return immediately when endpoint is no longer creating. + # Log fetching below is best-effort and must not block completion. + return desc - # Fetch and route CloudWatch logs to progress tracker + # Fetch and route CloudWatch logs to progress tracker (only while Creating) pages = paginator.paginate( logGroupName=f"/aws/sagemaker/Endpoints/{endpoint_name}", logStreamNamePrefix="AllTraffic/", @@ -108,9 +111,6 @@ def _live_logging_deploy_done_with_progress(sagemaker_client, endpoint_name, pag if progress_tracker: progress_tracker.update_status(endpoint_status) - # Return desc if we should stop polling - if stop: - return desc except ClientError as e: if e.response["Error"]["Code"] == "ResourceNotFoundException": return None diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 7c7af2defc..20062ae62e 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -45,6 +45,8 @@ ModelLifeCycle, DriftCheckBaselines, InferenceComponentComputeResourceRequirements, + InferenceComponentDataCacheConfig, + InferenceComponentContainerSpecification, ) from sagemaker.core.resources import ( ModelPackage, @@ -2700,6 +2702,52 @@ def _wait_for_endpoint( return desc + @staticmethod + def _apply_optional_ic_params(inference_component_spec, **kwargs): + """Apply optional IC-level parameters to an inference component spec dict. + + Wires data_cache_config, base_inference_component_name, and container + into the given inference_component_spec dict. Shared by + _deploy_core_endpoint and _update_inference_component to avoid + code duplication. + + Args: + inference_component_spec (dict): The spec dict to mutate in-place. + **kwargs: May contain data_cache_config, base_inference_component_name, + and container. + """ + from sagemaker.serve.model_builder_utils import _ModelBuilderUtils + + ic_data_cache_config = kwargs.get("data_cache_config") + if ic_data_cache_config is not None: + resolved = _ModelBuilderUtils._resolve_data_cache_config( + None, ic_data_cache_config + ) + if resolved is not None: + inference_component_spec["DataCacheConfig"] = { + "EnableCaching": resolved.enable_caching + } + + ic_base_component_name = kwargs.get("base_inference_component_name") + if ic_base_component_name is not None: + inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name + + ic_container = kwargs.get("container") + if ic_container is not None: + resolved_container = _ModelBuilderUtils._resolve_container_spec( + None, ic_container + ) + if resolved_container is not None: + container_dict = {} + if resolved_container.image: + container_dict["Image"] = resolved_container.image + if resolved_container.artifact_url: + container_dict["ArtifactUrl"] = resolved_container.artifact_url + if resolved_container.environment: + container_dict["Environment"] = resolved_container.environment + if container_dict: + inference_component_spec["Container"] = container_dict + def _deploy_core_endpoint(self, **kwargs): # Extract and update self parameters initial_instance_count = kwargs.get( @@ -2849,62 +2897,6 @@ def _deploy_core_endpoint(self, **kwargs): if self.role_arn is None: raise ValueError("Role can not be null for deploying a model") - routing_config = _resolve_routing_config(routing_config) - - if ( - inference_recommendation_id is not None - or self.inference_recommender_job_results is not None - ): - instance_type, initial_instance_count = self._update_params( - instance_type=instance_type, - initial_instance_count=initial_instance_count, - accelerator_type=accelerator_type, - async_inference_config=async_inference_config, - serverless_inference_config=serverless_inference_config, - explainer_config=explainer_config, - inference_recommendation_id=inference_recommendation_id, - inference_recommender_job_results=self.inference_recommender_job_results, - ) - - is_async = async_inference_config is not None - if is_async and not isinstance(async_inference_config, AsyncInferenceConfig): - raise ValueError("async_inference_config needs to be a AsyncInferenceConfig object") - - is_explainer_enabled = explainer_config is not None - if is_explainer_enabled and not isinstance(explainer_config, ExplainerConfig): - raise ValueError("explainer_config needs to be a ExplainerConfig object") - - is_serverless = serverless_inference_config is not None - if not is_serverless and not (instance_type and initial_instance_count): - raise ValueError( - "Must specify instance type and instance count unless using serverless inference" - ) - - if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig): - raise ValueError( - "serverless_inference_config needs to be a ServerlessInferenceConfig object" - ) - - if self._is_sharded_model: - if endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED: - logger.warning( - "Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - " - "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints." - ) - endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED - - if self._enable_network_isolation: - raise ValueError( - "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " - "Loading of model requires network access." - ) - - if resources and resources.num_cpus and resources.num_cpus > 0: - logger.warning( - "NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker " - "Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`." - ) - if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED: if update_endpoint: raise ValueError( @@ -2931,10 +2923,14 @@ def _deploy_core_endpoint(self, **kwargs): else: managed_instance_scaling_config["MinInstanceCount"] = initial_instance_count + # Use user-provided variant_name or default to "AllTraffic" + ic_variant_name = kwargs.get("variant_name", "AllTraffic") + if not self.sagemaker_session.endpoint_in_service_or_not(self.endpoint_name): production_variant = session_helper.production_variant( instance_type=instance_type, initial_instance_count=initial_instance_count, + variant_name=ic_variant_name, volume_size=volume_size, model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, @@ -2978,6 +2974,10 @@ def _deploy_core_endpoint(self, **kwargs): "StartupParameters": startup_parameters, "ComputeResourceRequirements": resources.get_compute_resource_requirements(), } + + # Wire optional IC-level parameters into the specification + self._apply_optional_ic_params(inference_component_spec, **kwargs) + runtime_config = {"CopyCount": resources.copy_count} self.inference_component_name = ( inference_component_name @@ -2989,7 +2989,7 @@ def _deploy_core_endpoint(self, **kwargs): self.sagemaker_session.create_inference_component( inference_component_name=self.inference_component_name, endpoint_name=self.endpoint_name, - variant_name="AllTraffic", # default variant name + variant_name=ic_variant_name, specification=inference_component_spec, runtime_config=runtime_config, tags=tags, @@ -3168,6 +3168,10 @@ def _update_inference_component( "StartupParameters": startup_parameters, "ComputeResourceRequirements": compute_rr, } + + # Wire optional IC-level parameters into the update specification + self._apply_optional_ic_params(inference_component_spec, **kwargs) + runtime_config = {"CopyCount": resource_requirements.copy_count} return self.sagemaker_session.update_inference_component( @@ -4127,6 +4131,11 @@ def deploy( ] = None, custom_orchestrator_instance_type: str = None, custom_orchestrator_initial_instance_count: int = None, + inference_component_name: Optional[str] = None, + data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None, + base_inference_component_name: Optional[str] = None, + container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None, + variant_name: Optional[str] = None, **kwargs, ) -> Union[Endpoint, LocalEndpoint, Transformer]: """Deploy the built model to an ``Endpoint``. @@ -4160,6 +4169,26 @@ def deploy( orchestrator deployment. (Default: None). custom_orchestrator_initial_instance_count (int, optional): Initial instance count for custom orchestrator deployment. (Default: None). + inference_component_name (str, optional): The name of the inference component + to create. Only used for inference-component-based endpoints. If not specified, + a unique name is generated from the model name. (Default: None). + data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional): + Data cache configuration for the inference component. Enables caching of model + artifacts and container images on instances for faster auto-scaling cold starts. + Can be a dict with 'enable_caching' key (e.g., {'enable_caching': True}) or an + InferenceComponentDataCacheConfig instance. (Default: None). + base_inference_component_name (str, optional): Name of the base inference component + for adapter deployments (e.g., LoRA adapters attached to a base model). + (Default: None). + container (Union[InferenceComponentContainerSpecification, dict], optional): + Custom container specification for the inference component, including image URI, + artifact URL, and environment variables. Can be a dict with keys 'image', + 'artifact_url', 'environment' or an InferenceComponentContainerSpecification + instance. (Default: None). + variant_name (str, optional): The name of the production variant to deploy to. + If not provided (or explicitly ``None``), defaults to ``'AllTraffic'``. + (Default: None). + Returns: Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint`` resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode, @@ -4182,6 +4211,21 @@ def deploy( if not hasattr(self, "built_model") and not hasattr(self, "_deployables"): raise ValueError("Model needs to be built before deploying") + # Only forward variant_name when explicitly provided by the caller. + # Each downstream path has its own default: + # - _deploy_core_endpoint defaults to "AllTraffic" + # - _deploy_model_customization defaults to endpoint_name + if variant_name is not None: + kwargs["variant_name"] = variant_name + if inference_component_name is not None: + kwargs["inference_component_name"] = inference_component_name + if data_cache_config is not None: + kwargs["data_cache_config"] = data_cache_config + if base_inference_component_name is not None: + kwargs["base_inference_component_name"] = base_inference_component_name + if container is not None: + kwargs["container"] = container + # Handle model customization deployment if self._is_model_customization(): logger.info("Deploying Model Customization model") @@ -4338,8 +4382,13 @@ def _deploy_model_customization( initial_instance_count: int = 1, inference_component_name: Optional[str] = None, inference_config: Optional[ResourceRequirements] = None, + variant_name: Optional[str] = None, + data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None, **kwargs, ) -> Endpoint: + # NOTE: For backward compatibility, model customization deployments + # default variant_name to endpoint_name (not "AllTraffic") when the + # caller does not provide an explicit value. """Deploy a model customization (fine-tuned) model to an endpoint with inference components. This method handles the special deployment flow for fine-tuned models, creating: @@ -4379,6 +4428,15 @@ def _deploy_model_customization( # Fetch model package model_package = self._fetch_model_package() + # Resolve variant_name: preserve backward-compatible default of + # endpoint_name for model customization deployments. + effective_variant_name = variant_name or endpoint_name or "AllTraffic" + + # Resolve data_cache_config if provided + resolved_data_cache_config = None + if data_cache_config is not None: + resolved_data_cache_config = self._resolve_data_cache_config(data_cache_config) + # Check if endpoint exists is_existing_endpoint = self._does_endpoint_exist(endpoint_name) @@ -4387,7 +4445,7 @@ def _deploy_model_customization( endpoint_config_name=endpoint_name, production_variants=[ ProductionVariant( - variant_name=endpoint_name, + variant_name=effective_variant_name, instance_type=self.instance_type, initial_instance_count=initial_instance_count or 1, ) @@ -4428,6 +4486,7 @@ def _deploy_model_customization( base_ic_spec = InferenceComponentSpecification( model_name=self.built_model.model_name, + data_cache_config=resolved_data_cache_config, ) if inference_config is not None: base_ic_spec.compute_resource_requirements = ( @@ -4444,7 +4503,7 @@ def _deploy_model_customization( InferenceComponent.create( inference_component_name=base_ic_name, endpoint_name=endpoint_name, - variant_name=endpoint_name, + variant_name=effective_variant_name, specification=base_ic_spec, runtime_config=InferenceComponentRuntimeConfig(copy_count=1), tags=[{"key": "Base", "value": base_model_recipe_name}], @@ -4486,7 +4545,8 @@ def _deploy_model_customization( ic_spec = InferenceComponentSpecification( container=InferenceComponentContainerSpecification( image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars - ) + ), + data_cache_config=resolved_data_cache_config, ) if inference_config is not None: @@ -4504,7 +4564,7 @@ def _deploy_model_customization( InferenceComponent.create( inference_component_name=inference_component_name, endpoint_name=endpoint_name, - variant_name=endpoint_name, + variant_name=effective_variant_name, specification=ic_spec, runtime_config=InferenceComponentRuntimeConfig(copy_count=1), ) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 56f3070346..4cf7d56095 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -78,6 +78,10 @@ def build(self): from sagemaker.serve.utils.hardware_detector import _total_inference_model_size_mib from sagemaker.serve.utils.types import ModelServer from sagemaker.core.resources import Model +from sagemaker.core.shapes import ( + InferenceComponentDataCacheConfig, + InferenceComponentContainerSpecification, +) # MLflow imports from sagemaker.serve.model_format.mlflow.constants import ( @@ -3369,6 +3373,80 @@ def _extract_speculative_draft_model_provider( return "auto" + def _resolve_data_cache_config( + self, + data_cache_config: Union[InferenceComponentDataCacheConfig, Dict[str, Any], None], + ) -> Optional[InferenceComponentDataCacheConfig]: + """Resolve data_cache_config to InferenceComponentDataCacheConfig. + + Args: + data_cache_config: Either a dict with 'enable_caching' key (and any future + fields supported by InferenceComponentDataCacheConfig), + an InferenceComponentDataCacheConfig instance, or None. + + Returns: + InferenceComponentDataCacheConfig or None. + + Raises: + ValueError: If data_cache_config is an unsupported type or dict + is missing the required 'enable_caching' key. + """ + if data_cache_config is None: + return None + + if isinstance(data_cache_config, InferenceComponentDataCacheConfig): + return data_cache_config + elif isinstance(data_cache_config, dict): + if "enable_caching" not in data_cache_config: + raise ValueError( + "data_cache_config dict must contain the required 'enable_caching' key. " + "Example: {'enable_caching': True}" + ) + # Pass only 'enable_caching' to avoid Pydantic validation errors + # if the model has extra='forbid'. As new fields are added to + # InferenceComponentDataCacheConfig, add them here. + return InferenceComponentDataCacheConfig( + enable_caching=data_cache_config["enable_caching"] + ) + else: + raise ValueError( + f"data_cache_config must be a dict with 'enable_caching' key or an " + f"InferenceComponentDataCacheConfig instance, got {type(data_cache_config)}" + ) + + def _resolve_container_spec( + self, + container: Union[InferenceComponentContainerSpecification, Dict[str, Any], None], + ) -> Optional[InferenceComponentContainerSpecification]: + """Resolve container to InferenceComponentContainerSpecification. + + Args: + container: Either a dict with container config keys (image, artifact_url, + environment), an InferenceComponentContainerSpecification instance, or None. + + Returns: + InferenceComponentContainerSpecification or None. + + Raises: + ValueError: If container is an unsupported type. + """ + if container is None: + return None + + if isinstance(container, InferenceComponentContainerSpecification): + return container + elif isinstance(container, dict): + # Only pass known keys to avoid Pydantic validation errors + # if the model has extra='forbid' configured + known_keys = {"image", "artifact_url", "environment"} + filtered = {k: v for k, v in container.items() if k in known_keys} + return InferenceComponentContainerSpecification(**filtered) + else: + raise ValueError( + f"container must be a dict or an InferenceComponentContainerSpecification " + f"instance, got {type(container)}" + ) + def get_huggingface_model_metadata( self, model_id: str, hf_hub_token: Optional[str] = None ) -> dict: diff --git a/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py b/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py new file mode 100644 index 0000000000..a6a7abc5a6 --- /dev/null +++ b/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py @@ -0,0 +1,199 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Integration test for IC-level deploy parameters (data_cache_config, variant_name). + +Uses ModelBuilder with a simple PyTorch model and ResourceRequirements to deploy +via the IC-based endpoint path, then verifies DataCacheConfig and custom VariantName +on the created InferenceComponent. +""" +from __future__ import absolute_import + +import json +import os +import tempfile +import uuid +import logging + +import boto3 +import pytest +import torch +import torch.nn as nn + +from sagemaker.serve.model_builder import ModelBuilder +from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.builder.schema_builder import SchemaBuilder +from sagemaker.serve.utils.types import ModelServer +from sagemaker.core.inference_config import ResourceRequirements +from sagemaker.core.resources import EndpointConfig + +logger = logging.getLogger(__name__) + + +class SimpleModel(nn.Module): + """Tiny PyTorch model for testing.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(4, 2) + + def forward(self, x): + return torch.softmax(self.linear(x), dim=1) + + +class SimpleInferenceSpec(InferenceSpec): + """InferenceSpec for the simple model.""" + + def load(self, model_dir: str): + model = SimpleModel() + model_path = os.path.join(model_dir, "model.pth") + if os.path.exists(model_path): + model = torch.jit.load(model_path, map_location="cpu") + model.eval() + return model + + def invoke(self, input_object, model): + input_tensor = torch.tensor(input_object, dtype=torch.float32) + with torch.no_grad(): + return model(input_tensor).tolist() + + +def _save_model(path): + """Save a traced PyTorch model to disk.""" + os.makedirs(path, exist_ok=True) + m = SimpleModel() + traced = torch.jit.trace(m, torch.randn(1, 4)) + torch.jit.save(traced, os.path.join(path, "model.pth")) + + +def _cleanup(endpoint_name, sagemaker_client): + """Best-effort cleanup.""" + try: + paginator = sagemaker_client.get_paginator("list_inference_components") + for page in paginator.paginate(EndpointNameEquals=endpoint_name): + for ic in page.get("InferenceComponents", []): + try: + sagemaker_client.delete_inference_component( + InferenceComponentName=ic["InferenceComponentName"] + ) + except Exception: + pass + except Exception: + pass + try: + sagemaker_client.delete_endpoint(EndpointName=endpoint_name) + except Exception: + pass + try: + sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_name) + except Exception: + pass + + +@pytest.mark.slow_test +def test_deploy_ic_with_data_cache_config_and_variant_name(): + """Deploy a simple model via ModelBuilder IC path with data_cache_config and variant_name. + + Uses a tiny PyTorch model on ml.m5.xlarge (CPU) to keep costs low and avoid + GPU capacity issues. Verifies the IC was created with the correct DataCacheConfig + and VariantName via boto3 describe. + """ + uid = uuid.uuid4().hex[:8] + endpoint_name = f"ic-params-ep-{uid}" + ic_name = f"ic-params-component-{uid}" + custom_variant = f"Variant-{uid}" + model_name = f"ic-params-model-{uid}" + + sagemaker_client = boto3.client("sagemaker", region_name="us-west-2") + model_path = tempfile.mkdtemp() + _save_model(model_path) + + try: + schema = SchemaBuilder( + sample_input=[[0.1, 0.2, 0.3, 0.4]], + sample_output=[[0.6, 0.4]], + ) + + # Use a PyTorch inference image that works across CI (py310) and local (py312). + # The container has its own Python, so py310 works regardless of host Python. + from sagemaker.core import image_uris + inference_image = image_uris.retrieve( + framework="pytorch", + region="us-west-2", + version="2.2.0", + py_version="py310", + instance_type="ml.m5.xlarge", + image_scope="inference", + ) + + model_builder = ModelBuilder( + inference_spec=SimpleInferenceSpec(), + model_path=model_path, + model_server=ModelServer.TORCHSERVE, + schema_builder=schema, + instance_type="ml.m5.xlarge", + image_uri=inference_image, + dependencies={"auto": False}, + ) + + model_builder.build(model_name=model_name) + logger.info("Model built: %s", model_name) + + resources = ResourceRequirements( + requests={"memory": 1024, "num_cpus": 1, "copies": 1} + ) + + endpoint = model_builder.deploy( + endpoint_name=endpoint_name, + initial_instance_count=1, + inference_config=resources, + inference_component_name=ic_name, + data_cache_config={"enable_caching": True}, + variant_name=custom_variant, + ) + logger.info("Endpoint deployed: %s", endpoint.endpoint_name) + + # Wait for the IC to be fully ready before describing it. + # deploy() creates the IC with wait=False, so it may still be Creating. + import time + for _ in range(40): + ic_status = sagemaker_client.describe_inference_component( + InferenceComponentName=ic_name + ).get("InferenceComponentStatus") + if ic_status == "InService": + break + logger.info("IC status: %s, waiting...", ic_status) + time.sleep(15) + logger.info("IC InService: %s", ic_name) + + # Verify the IC was created with correct params + ic_desc = sagemaker_client.describe_inference_component( + InferenceComponentName=ic_name + ) + + # Check DataCacheConfig + spec = ic_desc.get("Specification", {}) + data_cache = spec.get("DataCacheConfig", {}) + assert data_cache.get("EnableCaching") is True, ( + f"Expected DataCacheConfig.EnableCaching=True, got {data_cache}" + ) + + # Check VariantName + actual_variant = ic_desc.get("VariantName") + assert actual_variant == custom_variant, ( + f"Expected VariantName='{custom_variant}', got '{actual_variant}'" + ) + + logger.info("Test passed: IC has correct DataCacheConfig and VariantName") + + finally: + _cleanup(endpoint_name, sagemaker_client) diff --git a/sagemaker-serve/tests/unit/test_deployment_progress.py b/sagemaker-serve/tests/unit/test_deployment_progress.py index 16ec1acdd0..709341c5a0 100644 --- a/sagemaker-serve/tests/unit/test_deployment_progress.py +++ b/sagemaker-serve/tests/unit/test_deployment_progress.py @@ -157,29 +157,29 @@ def test_endpoint_inservice_status(self): self.assertEqual(result, expected_desc) def test_with_progress_tracker_and_logs(self): - """Test with progress tracker and CloudWatch logs.""" + """Test with progress tracker when endpoint is InService. + + When the endpoint is InService, the function returns immediately + after logging the success message — it does not fetch CloudWatch + logs or call update_status (those only happen while Creating). + """ mock_client = Mock() - mock_client.describe_endpoint.return_value = { - "EndpointStatus": "InService" - } + expected_desc = {"EndpointStatus": "InService"} + mock_client.describe_endpoint.return_value = expected_desc mock_paginator = Mock() - mock_paginator.paginate.return_value = [ - { - "events": [ - {"message": "Log line 1"}, - {"message": "Log line 2"} - ] - } - ] mock_tracker = Mock() - + result = _live_logging_deploy_done_with_progress( mock_client, "test-endpoint", mock_paginator, {}, 5, mock_tracker ) - - # Should log success message when InService - self.assertGreaterEqual(mock_tracker.log.call_count, 1) - mock_tracker.update_status.assert_called_once_with("InService") + + # Should return desc immediately when InService + self.assertEqual(result, expected_desc) + # Should log success message + mock_tracker.log.assert_called_once() + self.assertIn("Created endpoint", mock_tracker.log.call_args[0][0]) + # Paginator should NOT be called (logs are skipped when done) + mock_paginator.paginate.assert_not_called() def test_resource_not_found_exception(self): """Test ResourceNotFoundException during log fetching.""" @@ -200,10 +200,14 @@ def test_resource_not_found_exception(self): self.assertIsNone(result) def test_pagination_with_next_token(self): - """Test pagination with nextToken.""" + """Test pagination with nextToken during Creating status. + + Pagination and log fetching only happen while the endpoint is + Creating. When InService, the function returns immediately. + """ mock_client = Mock() mock_client.describe_endpoint.return_value = { - "EndpointStatus": "InService" + "EndpointStatus": "Creating" } mock_paginator = Mock() paginator_config = {} @@ -213,11 +217,14 @@ def test_pagination_with_next_token(self): "events": [{"message": "Log 1"}] } ] - + result = _live_logging_deploy_done_with_progress( mock_client, "test-endpoint", mock_paginator, paginator_config, 5 ) - + + # Should return None (still Creating) + self.assertIsNone(result) + # Paginator should have been called and token stored self.assertEqual(paginator_config.get("StartingToken"), "token123") diff --git a/tests/unit/sagemaker/serve/test_resolve_ic_params.py b/tests/unit/sagemaker/serve/test_resolve_ic_params.py new file mode 100644 index 0000000000..ceb697bac0 --- /dev/null +++ b/tests/unit/sagemaker/serve/test_resolve_ic_params.py @@ -0,0 +1,699 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for IC parameter resolvers and wiring logic.""" +from __future__ import absolute_import + +import pytest +from unittest.mock import MagicMock, patch, ANY + +from sagemaker.core.shapes import ( + InferenceComponentDataCacheConfig, + InferenceComponentContainerSpecification, +) +from sagemaker.serve.model_builder_utils import _ModelBuilderUtils + + +class ConcreteUtils(_ModelBuilderUtils): + """Concrete class to test mixin methods. + + _ModelBuilderUtils is a mixin that does not define __init__, + so this can be instantiated without arguments. + """ + pass + + +@pytest.fixture +def utils(): + return ConcreteUtils() + + +# ============================================================ +# Tests for _resolve_data_cache_config +# ============================================================ + +class TestResolveDataCacheConfig: + def test_none_returns_none(self, utils): + assert utils._resolve_data_cache_config(None) is None + + def test_already_typed_passthrough(self, utils): + config = InferenceComponentDataCacheConfig(enable_caching=True) + result = utils._resolve_data_cache_config(config) + assert result is config + assert result.enable_caching is True + + def test_dict_with_enable_caching_true(self, utils): + result = utils._resolve_data_cache_config({"enable_caching": True}) + assert isinstance(result, InferenceComponentDataCacheConfig) + assert result.enable_caching is True + + def test_dict_with_enable_caching_false(self, utils): + result = utils._resolve_data_cache_config({"enable_caching": False}) + assert isinstance(result, InferenceComponentDataCacheConfig) + assert result.enable_caching is False + + def test_dict_missing_enable_caching_raises(self, utils): + with pytest.raises(ValueError, match="must contain the required 'enable_caching' key"): + utils._resolve_data_cache_config({}) + + def test_dict_with_extra_keys_still_works(self, utils): + """Extra keys in the input dict are ignored. + + The resolver only extracts 'enable_caching' from the dict, so extra keys + do not cause Pydantic validation errors even if the model forbids extras. + We verify the result has enable_caching=True and does not expose extra_key. + """ + result = utils._resolve_data_cache_config( + {"enable_caching": True, "extra_key": "ignored"} + ) + assert isinstance(result, InferenceComponentDataCacheConfig) + assert result.enable_caching is True + # Verify extra_key is not present on the result object + assert not hasattr(result, "extra_key") or getattr(result, "extra_key", None) is None + + def test_invalid_type_raises(self, utils): + with pytest.raises(ValueError, match="data_cache_config must be a dict"): + utils._resolve_data_cache_config("invalid") + + def test_invalid_type_int_raises(self, utils): + with pytest.raises(ValueError, match="data_cache_config must be a dict"): + utils._resolve_data_cache_config(42) + + def test_invalid_type_list_raises(self, utils): + with pytest.raises(ValueError, match="data_cache_config must be a dict"): + utils._resolve_data_cache_config([True]) + + +# ============================================================ +# Tests for _resolve_container_spec +# ============================================================ + +class TestResolveContainerSpec: + def test_none_returns_none(self, utils): + assert utils._resolve_container_spec(None) is None + + def test_already_typed_passthrough(self, utils): + spec = InferenceComponentContainerSpecification( + image="my-image:latest", + artifact_url="s3://bucket/artifact", + environment={"KEY": "VALUE"}, + ) + result = utils._resolve_container_spec(spec) + assert result is spec + + def test_dict_full(self, utils): + result = utils._resolve_container_spec({ + "image": "my-image:latest", + "artifact_url": "s3://bucket/artifact", + "environment": {"KEY": "VALUE"}, + }) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.image == "my-image:latest" + assert result.artifact_url == "s3://bucket/artifact" + assert result.environment == {"KEY": "VALUE"} + + def test_dict_image_only(self, utils): + result = utils._resolve_container_spec({"image": "my-image:latest"}) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.image == "my-image:latest" + + def test_dict_artifact_url_only(self, utils): + result = utils._resolve_container_spec({"artifact_url": "s3://bucket/model.tar.gz"}) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.artifact_url == "s3://bucket/model.tar.gz" + + def test_dict_environment_only(self, utils): + result = utils._resolve_container_spec({"environment": {"A": "B"}}) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.environment == {"A": "B"} + + def test_dict_empty(self, utils): + """Empty dict creates a spec with no fields set.""" + result = utils._resolve_container_spec({}) + assert isinstance(result, InferenceComponentContainerSpecification) + + def test_dict_with_extra_keys(self, utils): + """Extra keys are filtered out before passing to the Pydantic constructor. + + This ensures compatibility even if InferenceComponentContainerSpecification + has extra='forbid' in its Pydantic model config. + """ + result = utils._resolve_container_spec({ + "image": "img", + "unknown_key": "ignored", + }) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.image == "img" + + def test_invalid_type_raises(self, utils): + with pytest.raises(ValueError, match="container must be a dict"): + utils._resolve_container_spec("invalid") + + def test_invalid_type_int_raises(self, utils): + with pytest.raises(ValueError, match="container must be a dict"): + utils._resolve_container_spec(123) + + def test_invalid_type_list_raises(self, utils): + with pytest.raises(ValueError, match="container must be a dict"): + utils._resolve_container_spec([{"image": "img"}]) + + +# ============================================================ +# Tests for _apply_optional_ic_params helper +# ============================================================ + +class TestApplyOptionalIcParams: + """Tests for the static helper that wires optional IC params into a spec dict.""" + + def test_no_params_no_mutation(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params(spec) + assert "DataCacheConfig" not in spec + assert "BaseInferenceComponentName" not in spec + assert "Container" not in spec + + def test_data_cache_config_dict(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params( + spec, data_cache_config={"enable_caching": True} + ) + assert spec["DataCacheConfig"] == {"EnableCaching": True} + + def test_data_cache_config_typed(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + cfg = InferenceComponentDataCacheConfig(enable_caching=False) + ModelBuilder._apply_optional_ic_params(spec, data_cache_config=cfg) + assert spec["DataCacheConfig"] == {"EnableCaching": False} + + def test_base_inference_component_name(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params( + spec, base_inference_component_name="base-ic" + ) + assert spec["BaseInferenceComponentName"] == "base-ic" + + def test_container_dict(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params( + spec, + container={ + "image": "img:latest", + "artifact_url": "s3://b/a", + "environment": {"K": "V"}, + }, + ) + assert spec["Container"] == { + "Image": "img:latest", + "ArtifactUrl": "s3://b/a", + "Environment": {"K": "V"}, + } + + def test_container_typed(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + c = InferenceComponentContainerSpecification(image="img") + ModelBuilder._apply_optional_ic_params(spec, container=c) + assert spec["Container"] == {"Image": "img"} + + def test_all_params_together(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params( + spec, + data_cache_config={"enable_caching": True}, + base_inference_component_name="base", + container={"image": "img"}, + ) + assert spec["DataCacheConfig"] == {"EnableCaching": True} + assert spec["BaseInferenceComponentName"] == "base" + assert spec["Container"] == {"Image": "img"} + + +# ============================================================ +# Tests for core wiring logic in _deploy_core_endpoint +# ============================================================ + +class TestDeployCoreEndpointWiring: + """Tests that new IC parameters are correctly wired through _deploy_core_endpoint.""" + + def _make_model_builder(self): + """Create a minimally-configured ModelBuilder for testing _deploy_core_endpoint.""" + from sagemaker.serve.model_builder import ModelBuilder + + mb = object.__new__(ModelBuilder) + # Set minimum required attributes + mb.model_name = "test-model" + mb.endpoint_name = None + mb.inference_component_name = None + mb.instance_type = "ml.g5.2xlarge" + mb.instance_count = 1 + mb.accelerator_type = None + mb._tags = None + mb.kms_key = None + mb.async_inference_config = None + mb.serverless_inference_config = None + mb.model_data_download_timeout = None + mb.resource_requirements = None + mb.container_startup_health_check_timeout = None + mb.inference_ami_version = None + mb._is_sharded_model = False + mb._enable_network_isolation = False + mb.role_arn = "arn:aws:iam::123456789012:role/SageMakerRole" + mb.vpc_config = None + mb.inference_recommender_job_results = None + mb.model_server = None + mb.mode = None + mb.region = "us-east-1" + + # Mock built_model + mb.built_model = MagicMock() + mb.built_model.model_name = "test-model" + + # Mock sagemaker_session + mb.sagemaker_session = MagicMock() + mb.sagemaker_session.endpoint_in_service_or_not.return_value = True + mb.sagemaker_session.boto_session = MagicMock() + mb.sagemaker_session.boto_region_name = "us-east-1" + + return mb + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_variant_name_defaults_to_all_traffic(self, mock_endpoint_cls): + """When variant_name is not provided, it defaults to 'AllTraffic'.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + wait=False, + ) + + # Verify create_inference_component was called with variant_name="AllTraffic" + mb.sagemaker_session.create_inference_component.assert_called_once() + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + assert call_kwargs.kwargs["variant_name"] == "AllTraffic" + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_variant_name_custom(self, mock_endpoint_cls): + """When variant_name is provided, it is used instead of 'AllTraffic'.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + variant_name="MyVariant", + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + assert call_kwargs.kwargs["variant_name"] == "MyVariant" + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_data_cache_config_wired_into_spec(self, mock_endpoint_cls): + """data_cache_config dict is resolved and added to inference_component_spec.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + data_cache_config={"enable_caching": True}, + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert "DataCacheConfig" in spec + assert spec["DataCacheConfig"]["EnableCaching"] is True + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_base_inference_component_name_wired_into_spec(self, mock_endpoint_cls): + """base_inference_component_name is added to inference_component_spec.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + base_inference_component_name="base-ic-name", + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert spec["BaseInferenceComponentName"] == "base-ic-name" + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_container_wired_into_spec(self, mock_endpoint_cls): + """container dict is resolved and added to inference_component_spec.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + container={ + "image": "my-image:latest", + "artifact_url": "s3://bucket/artifact", + "environment": {"KEY": "VALUE"}, + }, + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert "Container" in spec + assert spec["Container"]["Image"] == "my-image:latest" + assert spec["Container"]["ArtifactUrl"] == "s3://bucket/artifact" + assert spec["Container"]["Environment"] == {"KEY": "VALUE"} + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_no_optional_params_no_extra_keys_in_spec(self, mock_endpoint_cls): + """When no optional IC params are provided, spec has no extra keys.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert "DataCacheConfig" not in spec + assert "BaseInferenceComponentName" not in spec + assert "Container" not in spec + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_data_cache_config_typed_object_wired(self, mock_endpoint_cls): + """InferenceComponentDataCacheConfig object is correctly wired.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + config = InferenceComponentDataCacheConfig(enable_caching=True) + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + data_cache_config=config, + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert spec["DataCacheConfig"]["EnableCaching"] is True + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_variant_name_passed_to_production_variant_on_new_endpoint(self, mock_endpoint_cls): + """When creating a new endpoint, variant_name is passed to production_variant.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + # Simulate endpoint does NOT exist yet + mb.sagemaker_session.endpoint_in_service_or_not.return_value = False + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + with patch("sagemaker.serve.model_builder.session_helper.production_variant") as mock_pv: + mock_pv.return_value = {"VariantName": "CustomVariant"} + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + variant_name="CustomVariant", + wait=False, + ) + + # Verify production_variant was called with variant_name="CustomVariant" + mock_pv.assert_called_once() + pv_kwargs = mock_pv.call_args + assert pv_kwargs.kwargs.get("variant_name") == "CustomVariant" + + +# ============================================================ +# Tests for _update_inference_component wiring +# ============================================================ + +class TestUpdateInferenceComponentWiring: + """Tests that _update_inference_component correctly wires optional IC params.""" + + def _make_model_builder(self): + from sagemaker.serve.model_builder import ModelBuilder + + mb = object.__new__(ModelBuilder) + mb.model_name = "test-model" + mb.sagemaker_session = MagicMock() + return mb + + def test_update_ic_with_data_cache_config(self): + mb = self._make_model_builder() + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._update_inference_component( + "my-ic", resources, data_cache_config={"enable_caching": True} + ) + + call_kwargs = mb.sagemaker_session.update_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert spec["DataCacheConfig"] == {"EnableCaching": True} + + def test_update_ic_with_container(self): + mb = self._make_model_builder() + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._update_inference_component( + "my-ic", resources, container={"image": "img:v1"} + ) + + call_kwargs = mb.sagemaker_session.update_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert spec["Container"] == {"Image": "img:v1"} + + def test_update_ic_with_base_inference_component_name(self): + mb = self._make_model_builder() + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._update_inference_component( + "my-ic", resources, base_inference_component_name="base-ic" + ) + + call_kwargs = mb.sagemaker_session.update_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert spec["BaseInferenceComponentName"] == "base-ic" + + def test_update_ic_no_optional_params(self): + mb = self._make_model_builder() + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._update_inference_component("my-ic", resources) + + call_kwargs = mb.sagemaker_session.update_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert "DataCacheConfig" not in spec + assert "BaseInferenceComponentName" not in spec + assert "Container" not in spec + + +# ============================================================ +# Tests for deploy() parameter forwarding +# ============================================================ + +class TestDeployParameterForwarding: + """Tests that deploy() correctly forwards new IC params into kwargs.""" + + def test_deploy_forwards_variant_name_to_kwargs(self): + """deploy() should set kwargs['variant_name'] to the provided value.""" + from sagemaker.serve.model_builder import ModelBuilder + + mb = object.__new__(ModelBuilder) + mb.built_model = MagicMock() + mb._deployed = False + mb._is_sharded_model = False + mb.model_name = "test" + mb.instance_type = "ml.m5.large" + mb.endpoint_name = None + mb.mode = None + mb.model_server = None + + # Mock _is_model_customization to return False + mb._is_model_customization = MagicMock(return_value=False) + # Mock _deploy to capture kwargs + captured = {} + + def fake_deploy(**kw): + captured.update(kw) + return MagicMock() + + mb._deploy = fake_deploy + + mb.deploy( + endpoint_name="ep", + instance_type="ml.m5.large", + initial_instance_count=1, + variant_name="MyVariant", + data_cache_config={"enable_caching": True}, + base_inference_component_name="base-ic", + container={"image": "img"}, + ) + + assert captured["variant_name"] == "MyVariant" + assert captured["data_cache_config"] == {"enable_caching": True} + assert captured["base_inference_component_name"] == "base-ic" + assert captured["container"] == {"image": "img"} + + def test_deploy_does_not_set_variant_name_when_not_provided(self): + """deploy() should NOT set variant_name in kwargs when not provided. + + This allows downstream methods to use their own defaults: + - _deploy_core_endpoint defaults to 'AllTraffic' + - _deploy_model_customization defaults to endpoint_name + """ + from sagemaker.serve.model_builder import ModelBuilder + + mb = object.__new__(ModelBuilder) + mb.built_model = MagicMock() + mb._deployed = False + mb._is_sharded_model = False + mb.model_name = "test" + mb.instance_type = "ml.m5.large" + mb.endpoint_name = None + mb.mode = None + mb.model_server = None + mb._is_model_customization = MagicMock(return_value=False) + + captured = {} + + def fake_deploy(**kw): + captured.update(kw) + return MagicMock() + + mb._deploy = fake_deploy + + mb.deploy( + endpoint_name="ep", + instance_type="ml.m5.large", + initial_instance_count=1, + ) + + # variant_name should NOT be in kwargs when not explicitly provided + assert "variant_name" not in captured + # Optional params should not be in kwargs when not provided + assert "data_cache_config" not in captured + assert "base_inference_component_name" not in captured + assert "container" not in captured + + def test_deploy_forwards_variant_name_none_is_not_forwarded(self): + """deploy(variant_name=None) should NOT forward variant_name. + + None is the default, so it should behave the same as not providing it. + """ + from sagemaker.serve.model_builder import ModelBuilder + + mb = object.__new__(ModelBuilder) + mb.built_model = MagicMock() + mb._deployed = False + mb._is_sharded_model = False + mb.model_name = "test" + mb.instance_type = "ml.m5.large" + mb.endpoint_name = None + mb.mode = None + mb.model_server = None + mb._is_model_customization = MagicMock(return_value=False) + + captured = {} + + def fake_deploy(**kw): + captured.update(kw) + return MagicMock() + + mb._deploy = fake_deploy + + mb.deploy( + endpoint_name="ep", + instance_type="ml.m5.large", + initial_instance_count=1, + variant_name=None, + ) + + # variant_name=None should not be forwarded + assert "variant_name" not in captured