Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sagemaker-train/src/sagemaker/ai_registry/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ def get_versions(self) -> List["DataSet"]:

return datasets

@classmethod
@classmethod
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DataSet.get_all")
def get_all(cls, max_results: Optional[int] = None, sagemaker_session=None):
Expand Down
10 changes: 9 additions & 1 deletion sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,10 +701,18 @@ def _get_base_template_context(
Returns:
dict: Base template context dictionary
"""
# Generate default mlflow_experiment_name if not provided
# This is required by AWS when ModelPackageGroupArn is not provided in training jobs
mlflow_experiment_name = self.mlflow_experiment_name
if not mlflow_experiment_name and self.mlflow_resource_arn:
# Use pipeline_name as default experiment name
mlflow_experiment_name = '{{ pipeline_name }}'
_logger.info("No mlflow_experiment_name provided, using pipeline_name as default")

return {
'role_arn': role_arn,
'mlflow_resource_arn': self.mlflow_resource_arn,
'mlflow_experiment_name': self.mlflow_experiment_name,
'mlflow_experiment_name': mlflow_experiment_name,
'mlflow_run_name': self.mlflow_run_name,
'model_package_group_arn': model_package_group_arn,
'source_model_package_arn': self._source_model_package_arn,
Expand Down
9 changes: 2 additions & 7 deletions sagemaker-train/tests/integ/ai_registry/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,8 @@ def cleanup_list():
"""Track resources for cleanup."""
resources = []
yield resources
for evaluator in resources:
for resource in resources:
try:
from sagemaker.ai_registry.air_hub import AIRHub
AIRHub.delete_hub_content(
hub_content_type=evaluator.hub_content_type,
hub_content_name=evaluator.name,
hub_content_version=evaluator.version
)
resource.delete()
except Exception:
pass
11 changes: 7 additions & 4 deletions sagemaker-train/tests/integ/ai_registry/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,10 @@ def test_create_dataset_from_s3_nova_eval(self, unique_name, test_bucket, cleanu
cleanup_list.append(dataset)
assert dataset.name == unique_name

def test_get_dataset(self, unique_name, sample_jsonl_file):
def test_get_dataset(self, unique_name, sample_jsonl_file, cleanup_list):
"""Test retrieving dataset by name."""
created = DataSet.create(name=unique_name, source=sample_jsonl_file, wait=False)
cleanup_list.append(created)
retrieved = DataSet.get(unique_name)
assert retrieved.name == created.name
assert retrieved.arn == created.arn
Expand All @@ -141,16 +142,18 @@ def test_get_all_datasets(self):
datasets = list(DataSet.get_all(max_results=5))
assert isinstance(datasets, list)

def test_dataset_refresh(self, unique_name, sample_jsonl_file):
def test_dataset_refresh(self, unique_name, sample_jsonl_file, cleanup_list):
"""Test refreshing dataset status."""
dataset = DataSet.create(name=unique_name, source=sample_jsonl_file, wait=False)
cleanup_list.append(dataset)
dataset.refresh()
time.sleep(3)
assert dataset.status in [HubContentStatus.IMPORTING.value, HubContentStatus.AVAILABLE.value]

def test_dataset_get_versions(self, unique_name, sample_jsonl_file):
def test_dataset_get_versions(self, unique_name, sample_jsonl_file, cleanup_list):
"""Test getting dataset versions."""
dataset = DataSet.create(name=unique_name, source=sample_jsonl_file, wait=False)
cleanup_list.append(dataset)
versions = dataset.get_versions()
assert len(versions) >= 1
assert all(isinstance(v, DataSet) for v in versions)
Expand Down Expand Up @@ -178,7 +181,7 @@ def test_create_dataset_version(self, unique_name, sample_jsonl_file, cleanup_li
"""Test creating new dataset version."""
dataset = DataSet.create(name=unique_name, source=sample_jsonl_file, wait=False)
result = dataset.create_version(sample_jsonl_file)
cleanup_list.append(cleanup_list)
cleanup_list.append(dataset)
assert result is True

def test_dataset_validation_invalid_extension(self, unique_name):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,23 @@ def test_base_model_evaluation_uses_correct_weights(self):
# Check that we have both base and custom inference steps
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []

logger.info(f"Pipeline steps: {step_names}")
logger.info(f"Pipeline steps ({len(step_names)}): {step_names}")

# Verify both inference steps exist
has_base_step = any("BaseInference" in name for name in step_names)
has_custom_step = any("CustomInference" in name for name in step_names)
# If no steps yet, wait a bit for pipeline to initialize
if not step_names:
logger.info("No steps found yet, waiting for pipeline initialization...")
import time
time.sleep(10)
execution.refresh()
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []
logger.info(f"Pipeline steps after wait ({len(step_names)}): {step_names}")

assert has_base_step, "Pipeline should have EvaluateBaseInferenceModel step"
assert has_custom_step, "Pipeline should have EvaluateCustomInferenceModel step"
# Verify both inference steps exist (case-insensitive, flexible matching)
has_base_step = any("base" in name.lower() and "inference" in name.lower() for name in step_names)
has_custom_step = any("custom" in name.lower() and "inference" in name.lower() for name in step_names)

assert has_base_step, f"Pipeline should have base inference step. Found steps: {step_names}"
assert has_custom_step, f"Pipeline should have custom inference step. Found steps: {step_names}"

logger.info(f"✓ Pipeline has both base and custom inference steps")
logger.info(f" Base model step: {'Found' if has_base_step else 'Missing'}")
Expand Down Expand Up @@ -206,14 +215,19 @@ def test_base_model_evaluation_uses_correct_weights(self):
if execution.status.failure_reason:
logger.error(f" Failure reason: {execution.status.failure_reason}")

# Log step failures
# Log step failures with detailed information
if execution.status.step_details:
logger.error("\nFailed steps:")
logger.error("\n" + "=" * 80)
logger.error("DETAILED STEP FAILURE INFORMATION:")
logger.error("=" * 80)
for step in execution.status.step_details:
if "failed" in step.status.lower():
logger.error(f" {step.name}: {step.status}")
if step.failure_reason:
logger.error(f" Reason: {step.failure_reason}")
logger.error(f"\nStep: {step.name}")
logger.error(f" Status: {step.status}")
logger.error(f" Start Time: {step.start_time}")
logger.error(f" End Time: {step.end_time}")
if step.failure_reason:
logger.error(f" ❌ FAILURE REASON: {step.failure_reason}")
logger.error("=" * 80)

# Re-raise to fail the test
raise
Expand Down Expand Up @@ -259,14 +273,23 @@ def test_base_model_false_still_works(self):
execution.refresh()
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []

logger.info(f"Pipeline steps: {step_names}")
logger.info(f"Pipeline steps ({len(step_names)}): {step_names}")

# If no steps yet, wait a bit for pipeline to initialize
if not step_names:
logger.info("No steps found yet, waiting for pipeline initialization...")
import time
time.sleep(10)
execution.refresh()
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []
logger.info(f"Pipeline steps after wait ({len(step_names)}): {step_names}")

# Should NOT have base inference step
has_base_step = any("BaseInference" in name for name in step_names)
has_custom_step = any("CustomInference" in name for name in step_names)
# Should NOT have base inference step (case-insensitive, flexible matching)
has_base_step = any("base" in name.lower() and "inference" in name.lower() for name in step_names)
has_custom_step = any("custom" in name.lower() and "inference" in name.lower() for name in step_names)

assert not has_base_step, "Pipeline should NOT have EvaluateBaseInferenceModel step when evaluate_base_model=False"
assert has_custom_step, "Pipeline should have EvaluateCustomInferenceModel step"
assert not has_base_step, f"Pipeline should NOT have base inference step when evaluate_base_model=False. Found steps: {step_names}"
assert has_custom_step, f"Pipeline should have custom inference step. Found steps: {step_names}"

logger.info(f"✓ Pipeline structure correct for evaluate_base_model=False")
logger.info(f" Base model step: {'Found (ERROR!)' if has_base_step else 'Not present (correct)'}")
Expand Down
Loading