Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions sagemaker-serve/src/sagemaker/serve/deployment_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,11 @@ def _live_logging_deploy_done_with_progress(sagemaker_client, endpoint_name, pag
progress_tracker.log(f"✅ Created endpoint with name {endpoint_name}")
elif endpoint_status != "InService":
time.sleep(poll)
# Return immediately when endpoint is no longer creating.
# Log fetching below is best-effort and must not block completion.
return desc

# Fetch and route CloudWatch logs to progress tracker
# Fetch and route CloudWatch logs to progress tracker (only while Creating)
pages = paginator.paginate(
logGroupName=f"/aws/sagemaker/Endpoints/{endpoint_name}",
logStreamNamePrefix="AllTraffic/",
Expand All @@ -108,9 +111,6 @@ def _live_logging_deploy_done_with_progress(sagemaker_client, endpoint_name, pag
if progress_tracker:
progress_tracker.update_status(endpoint_status)

# Return desc if we should stop polling
if stop:
return desc
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException":
return None
Expand Down
182 changes: 121 additions & 61 deletions sagemaker-serve/src/sagemaker/serve/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
ModelLifeCycle,
DriftCheckBaselines,
InferenceComponentComputeResourceRequirements,
InferenceComponentDataCacheConfig,
InferenceComponentContainerSpecification,
)
from sagemaker.core.resources import (
ModelPackage,
Expand Down Expand Up @@ -2700,6 +2702,52 @@ def _wait_for_endpoint(

return desc

@staticmethod
def _apply_optional_ic_params(inference_component_spec, **kwargs):
    """Populate optional inference-component fields on a specification dict.

    Mutates ``inference_component_spec`` in place, wiring in any of
    ``data_cache_config``, ``base_inference_component_name``, and
    ``container`` found in ``kwargs``. Shared by ``_deploy_core_endpoint``
    and ``_update_inference_component`` so the wiring logic lives in one
    place.

    Args:
        inference_component_spec (dict): Specification dict to mutate in place.
        **kwargs: Optional keys ``data_cache_config``,
            ``base_inference_component_name``, and ``container``.
    """
    from sagemaker.serve.model_builder_utils import _ModelBuilderUtils

    cache_cfg = kwargs.get("data_cache_config")
    if cache_cfg is not None:
        resolved_cache = _ModelBuilderUtils._resolve_data_cache_config(None, cache_cfg)
        if resolved_cache is not None:
            inference_component_spec["DataCacheConfig"] = {
                "EnableCaching": resolved_cache.enable_caching
            }

    base_name = kwargs.get("base_inference_component_name")
    if base_name is not None:
        inference_component_spec["BaseInferenceComponentName"] = base_name

    container_cfg = kwargs.get("container")
    if container_cfg is not None:
        resolved_spec = _ModelBuilderUtils._resolve_container_spec(None, container_cfg)
        if resolved_spec is not None:
            # Only include keys whose resolved values are truthy, mirroring
            # the API's expectation that absent fields are omitted entirely.
            container_fields = {
                key: value
                for key, value in (
                    ("Image", resolved_spec.image),
                    ("ArtifactUrl", resolved_spec.artifact_url),
                    ("Environment", resolved_spec.environment),
                )
                if value
            }
            if container_fields:
                inference_component_spec["Container"] = container_fields

def _deploy_core_endpoint(self, **kwargs):
# Extract and update self parameters
initial_instance_count = kwargs.get(
Expand Down Expand Up @@ -2849,62 +2897,6 @@ def _deploy_core_endpoint(self, **kwargs):
if self.role_arn is None:
raise ValueError("Role can not be null for deploying a model")

routing_config = _resolve_routing_config(routing_config)

if (
inference_recommendation_id is not None
or self.inference_recommender_job_results is not None
):
instance_type, initial_instance_count = self._update_params(
instance_type=instance_type,
initial_instance_count=initial_instance_count,
accelerator_type=accelerator_type,
async_inference_config=async_inference_config,
serverless_inference_config=serverless_inference_config,
explainer_config=explainer_config,
inference_recommendation_id=inference_recommendation_id,
inference_recommender_job_results=self.inference_recommender_job_results,
)

is_async = async_inference_config is not None
if is_async and not isinstance(async_inference_config, AsyncInferenceConfig):
raise ValueError("async_inference_config needs to be a AsyncInferenceConfig object")

is_explainer_enabled = explainer_config is not None
if is_explainer_enabled and not isinstance(explainer_config, ExplainerConfig):
raise ValueError("explainer_config needs to be a ExplainerConfig object")

is_serverless = serverless_inference_config is not None
if not is_serverless and not (instance_type and initial_instance_count):
raise ValueError(
"Must specify instance type and instance count unless using serverless inference"
)

if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig):
raise ValueError(
"serverless_inference_config needs to be a ServerlessInferenceConfig object"
)

if self._is_sharded_model:
if endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
logger.warning(
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
)
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED

if self._enable_network_isolation:
raise ValueError(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access."
)

if resources and resources.num_cpus and resources.num_cpus > 0:
logger.warning(
"NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker "
"Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`."
)

if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
if update_endpoint:
raise ValueError(
Expand All @@ -2931,10 +2923,14 @@ def _deploy_core_endpoint(self, **kwargs):
else:
managed_instance_scaling_config["MinInstanceCount"] = initial_instance_count

# Use user-provided variant_name or default to "AllTraffic"
ic_variant_name = kwargs.get("variant_name", "AllTraffic")

if not self.sagemaker_session.endpoint_in_service_or_not(self.endpoint_name):
production_variant = session_helper.production_variant(
instance_type=instance_type,
initial_instance_count=initial_instance_count,
variant_name=ic_variant_name,
volume_size=volume_size,
model_data_download_timeout=model_data_download_timeout,
container_startup_health_check_timeout=container_startup_health_check_timeout,
Expand Down Expand Up @@ -2978,6 +2974,10 @@ def _deploy_core_endpoint(self, **kwargs):
"StartupParameters": startup_parameters,
"ComputeResourceRequirements": resources.get_compute_resource_requirements(),
}

# Wire optional IC-level parameters into the specification
self._apply_optional_ic_params(inference_component_spec, **kwargs)

runtime_config = {"CopyCount": resources.copy_count}
self.inference_component_name = (
inference_component_name
Expand All @@ -2989,7 +2989,7 @@ def _deploy_core_endpoint(self, **kwargs):
self.sagemaker_session.create_inference_component(
inference_component_name=self.inference_component_name,
endpoint_name=self.endpoint_name,
variant_name="AllTraffic", # default variant name
variant_name=ic_variant_name,
specification=inference_component_spec,
runtime_config=runtime_config,
tags=tags,
Expand Down Expand Up @@ -3168,6 +3168,10 @@ def _update_inference_component(
"StartupParameters": startup_parameters,
"ComputeResourceRequirements": compute_rr,
}

# Wire optional IC-level parameters into the update specification
self._apply_optional_ic_params(inference_component_spec, **kwargs)

runtime_config = {"CopyCount": resource_requirements.copy_count}

return self.sagemaker_session.update_inference_component(
Expand Down Expand Up @@ -4127,6 +4131,11 @@ def deploy(
] = None,
custom_orchestrator_instance_type: str = None,
custom_orchestrator_initial_instance_count: int = None,
inference_component_name: Optional[str] = None,
data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None,
base_inference_component_name: Optional[str] = None,
container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None,
variant_name: Optional[str] = None,
**kwargs,
) -> Union[Endpoint, LocalEndpoint, Transformer]:
"""Deploy the built model to an ``Endpoint``.
Expand Down Expand Up @@ -4160,6 +4169,26 @@ def deploy(
orchestrator deployment. (Default: None).
custom_orchestrator_initial_instance_count (int, optional): Initial instance count
for custom orchestrator deployment. (Default: None).
inference_component_name (str, optional): The name of the inference component
to create. Only used for inference-component-based endpoints. If not specified,
a unique name is generated from the model name. (Default: None).
data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional):
Data cache configuration for the inference component. Enables caching of model
artifacts and container images on instances for faster auto-scaling cold starts.
Can be a dict with 'enable_caching' key (e.g., {'enable_caching': True}) or an
InferenceComponentDataCacheConfig instance. (Default: None).
base_inference_component_name (str, optional): Name of the base inference component
for adapter deployments (e.g., LoRA adapters attached to a base model).
(Default: None).
container (Union[InferenceComponentContainerSpecification, dict], optional):
Custom container specification for the inference component, including image URI,
artifact URL, and environment variables. Can be a dict with keys 'image',
'artifact_url', 'environment' or an InferenceComponentContainerSpecification
instance. (Default: None).
variant_name (str, optional): The name of the production variant to deploy to.
If not provided (or explicitly ``None``), defaults to ``'AllTraffic'``.
(Default: None).

Returns:
Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint``
resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode,
Expand All @@ -4182,6 +4211,21 @@ def deploy(
if not hasattr(self, "built_model") and not hasattr(self, "_deployables"):
raise ValueError("Model needs to be built before deploying")

# Only forward variant_name when explicitly provided by the caller.
# Each downstream path has its own default:
# - _deploy_core_endpoint defaults to "AllTraffic"
# - _deploy_model_customization defaults to endpoint_name
if variant_name is not None:
kwargs["variant_name"] = variant_name
if inference_component_name is not None:
kwargs["inference_component_name"] = inference_component_name
if data_cache_config is not None:
kwargs["data_cache_config"] = data_cache_config
if base_inference_component_name is not None:
kwargs["base_inference_component_name"] = base_inference_component_name
if container is not None:
kwargs["container"] = container

# Handle model customization deployment
if self._is_model_customization():
logger.info("Deploying Model Customization model")
Expand Down Expand Up @@ -4338,8 +4382,13 @@ def _deploy_model_customization(
initial_instance_count: int = 1,
inference_component_name: Optional[str] = None,
inference_config: Optional[ResourceRequirements] = None,
variant_name: Optional[str] = None,
data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None,
**kwargs,
) -> Endpoint:
# NOTE: For backward compatibility, model customization deployments
# default variant_name to endpoint_name (not "AllTraffic") when the
# caller does not provide an explicit value.
"""Deploy a model customization (fine-tuned) model to an endpoint with inference components.

This method handles the special deployment flow for fine-tuned models, creating:
Expand Down Expand Up @@ -4379,6 +4428,15 @@ def _deploy_model_customization(
# Fetch model package
model_package = self._fetch_model_package()

# Resolve variant_name: preserve backward-compatible default of
# endpoint_name for model customization deployments.
effective_variant_name = variant_name or endpoint_name or "AllTraffic"

# Resolve data_cache_config if provided
resolved_data_cache_config = None
if data_cache_config is not None:
resolved_data_cache_config = self._resolve_data_cache_config(data_cache_config)

# Check if endpoint exists
is_existing_endpoint = self._does_endpoint_exist(endpoint_name)

Expand All @@ -4387,7 +4445,7 @@ def _deploy_model_customization(
endpoint_config_name=endpoint_name,
production_variants=[
ProductionVariant(
variant_name=endpoint_name,
variant_name=effective_variant_name,
instance_type=self.instance_type,
initial_instance_count=initial_instance_count or 1,
)
Expand Down Expand Up @@ -4428,6 +4486,7 @@ def _deploy_model_customization(

base_ic_spec = InferenceComponentSpecification(
model_name=self.built_model.model_name,
data_cache_config=resolved_data_cache_config,
)
if inference_config is not None:
base_ic_spec.compute_resource_requirements = (
Expand All @@ -4444,7 +4503,7 @@ def _deploy_model_customization(
InferenceComponent.create(
inference_component_name=base_ic_name,
endpoint_name=endpoint_name,
variant_name=endpoint_name,
variant_name=effective_variant_name,
specification=base_ic_spec,
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
tags=[{"key": "Base", "value": base_model_recipe_name}],
Expand Down Expand Up @@ -4486,7 +4545,8 @@ def _deploy_model_customization(
ic_spec = InferenceComponentSpecification(
container=InferenceComponentContainerSpecification(
image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars
)
),
data_cache_config=resolved_data_cache_config,
)

if inference_config is not None:
Expand All @@ -4504,7 +4564,7 @@ def _deploy_model_customization(
InferenceComponent.create(
inference_component_name=inference_component_name,
endpoint_name=endpoint_name,
variant_name=endpoint_name,
variant_name=effective_variant_name,
specification=ic_spec,
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
)
Expand Down
Loading
Loading