def _decrement_version(version_str: str) -> str:
    """Return the release just below ``version_str`` (best effort).

    Used to turn an exclusive upper bound (``<X.Y.Z``) into the highest release that
    bound still admits, assuming releases are numbered contiguously.

    Rules:
    - Patch > 0: decrement the patch, e.g. 2.254.2 -> 2.254.1
    - Patch == 0 and minor > 0: decrement the minor, e.g. 2.256.0 -> 2.255.0
    - ``X.0.0``: returned unchanged, because the last release of the previous major
      line cannot be derived from the string alone.

    Args:
        version_str: Version string (e.g., "2.256.0")

    Returns:
        The decremented version formatted as ``major.minor.patch``, or ``version_str``
        unchanged when it is not a parseable version.
    """
    from packaging import version as pkg_version

    try:
        parsed = pkg_version.Version(version_str)
    except (pkg_version.InvalidVersion, TypeError):
        # Not a PEP 440 version: hand it back untouched and let the caller decide.
        return version_str

    major, minor, patch = parsed.major, parsed.minor, parsed.micro
    if patch > 0:
        patch -= 1
    elif minor > 0:
        minor -= 1
    return f"{major}.{minor}.{patch}"


def _resolve_version_from_specifier(specifier_str: str) -> "Optional[str]":
    """Estimate the highest version a specifier allows, i.e. what pip would install.

    Resolution rules:
    - An exact pin (``==X`` / ``===X``) resolves to ``X``.
    - Otherwise the most restrictive upper bound wins: ``<=X`` resolves to ``X`` and
      ``<X`` resolves to the release just below ``X``.
    - ``~=X.Y.Z`` and ``==X.Y.*`` are expanded to their equivalent
      ``>=``/``<`` pair before the rules above are applied.
    - An exclusive upper bound of exactly ``3.0.0`` means "latest V2" and is safe.
    - If the resolved version is below a ``>=`` lower bound, the highest such lower
      bound is used instead.
    - ``>`` and ``!=`` clauses do not take part in the resolution.

    Args:
        specifier_str: Version specifier string (e.g., ">=2.256.0", "<2.256.0",
            "==2.255.0")

    Returns:
        The resolved version string to check, or None when the specifier has no
        upper bound (or is otherwise considered safe).
    """
    import re
    from packaging import version as pkg_version

    # The numeric part never ends in a dot, so every captured version is parseable.
    clause_pattern = re.compile(r"(===|==|~=|!=|<=|>=|<|>)\s*(\d+(?:\.\d+)*)(\.\*)?")

    def _next_release(parts):
        """Bump the last component: [2, 255] -> "2.256"."""
        return ".".join(str(part) for part in parts[:-1] + [parts[-1] + 1])

    lower_bounds = []
    upper_bounds = []
    for operator, version, wildcard in clause_pattern.findall(specifier_str):
        parts = [int(part) for part in version.split(".")]
        if operator in ("==", "==="):
            if not wildcard:
                return version
            # "==2.255.*" is equivalent to ">=2.255,<2.256"
            lower_bounds.append(version)
            upper_bounds.append(("<", _next_release(parts)))
        elif operator == "~=":
            # "~=2.200.0" is equivalent to ">=2.200.0,<2.201"; needs >= 2 components
            if len(parts) >= 2:
                lower_bounds.append(version)
                upper_bounds.append(("<", _next_release(parts[:-1])))
        elif operator in ("<", "<="):
            upper_bounds.append((operator, version))
        elif operator == ">=":
            lower_bounds.append(version)

    if not upper_bounds:
        # For lower bounds only (>=, >), we don't check
        return None

    # Lowest bound is the most restrictive; on a tie "<" is tighter than "<=".
    operator, version = min(
        upper_bounds, key=lambda bound: (pkg_version.Version(bound[1]), bound[0] != "<")
    )
    upper = pkg_version.Version(version)
    if operator == "<" and (upper.major, upper.minor, upper.micro) == (3, 0, 0):
        # <3.0.0 means V2 only, which is safe
        return None

    resolved_version = _decrement_version(version) if operator == "<" else version

    resolved_parsed = pkg_version.Version(resolved_version)
    higher_lower_bounds = [
        bound for bound in lower_bounds if pkg_version.Version(bound) > resolved_parsed
    ]
    if higher_lower_bounds:
        resolved_version = max(higher_lower_bounds, key=pkg_version.Version)

    return resolved_version


def _sagemaker_requirement_specifier(requirement_line: str) -> "Optional[str]":
    """Classify one requirements.txt line with respect to the ``sagemaker`` package.

    Args:
        requirement_line: A single line of a requirements file.

    Returns:
        - The version specifier text (possibly ``""``) when the line is a named
          requirement on the ``sagemaker`` distribution itself; extras, environment
          markers and trailing comments are removed.
        - ``""`` when the line is a URL, path, archive or pip option that mentions
          sagemaker (treated as "provides sagemaker, version unknown").
        - None for comments, blank lines and other distributions such as
          ``sagemaker-core``.
    """
    import re

    # pip only treats "#" as a comment marker at line start or after whitespace.
    text = re.sub(r"(^|\s)#.*$", "", requirement_line).strip()
    if "sagemaker" not in text.lower():
        return None

    named = re.match(
        r"sagemaker(?![A-Za-z0-9._-])\s*(?:\[[^\]]*\])?\s*(?P<rest>.*)$", text, re.IGNORECASE
    )
    if named:
        rest = named.group("rest")
        if rest.startswith("@"):
            # Direct reference ("sagemaker @ <url>"): no version range to inspect.
            return ""
        return rest.split(";", 1)[0].strip()

    is_archive = text.lower().endswith((".whl", ".zip", ".tar.gz", ".tgz", ".tar.bz2"))
    other_named = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*\s*(?:\[|[<>=!~;@])", text)
    other_bare = re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9._-]*", text) and not is_archive
    if other_named or other_bare:
        # A different distribution whose name merely contains "sagemaker".
        return None

    return ""


def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
    """Check if the sagemaker version requirement uses incompatible hashing.

    Raises ValueError if the requirement would install a version that uses HMAC hashing
    (which is incompatible with the current SHA256-based integrity checks). Lines that
    are not a requirement on the ``sagemaker`` distribution, or that have no upper
    bound / exact pin, are accepted without further checks.

    Args:
        sagemaker_requirement: The sagemaker requirement string (e.g., "sagemaker>=2.200.0")

    Raises:
        ValueError: If the requirement would install a version using HMAC hashing
    """
    from packaging import version as pkg_version

    specifier_str = _sagemaker_requirement_specifier(sagemaker_requirement)
    if not specifier_str:
        return

    # Resolve the version that would be installed
    resolved_version_str = _resolve_version_from_specifier(specifier_str)
    if not resolved_version_str:
        # No upper bound or exact version, so we can't determine if it's bad
        return

    try:
        resolved_version = pkg_version.Version(resolved_version_str)
    except pkg_version.InvalidVersion:
        return

    # First release of each major line that uses SHA256 instead of HMAC
    sha256_thresholds = {
        2: pkg_version.Version("2.256.0"),
        3: pkg_version.Version("3.2.0"),
    }
    threshold = sha256_thresholds.get(resolved_version.major)
    if threshold is None or resolved_version >= threshold:
        return

    raise ValueError(
        f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) "
        f"could install a version using HMAC-based integrity checks which are incompatible "
        f"with the current SHA256-based integrity checks. Please update to "
        f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)."
    )


def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
    """Ensure sagemaker>=2.256.0 is in the dependencies.

    This function ensures that the remote environment has a compatible version of sagemaker
    that includes the fix for the HMAC key security issue. Versions < 2.256.0 use HMAC-based
    integrity checks which require the REMOTE_FUNCTION_SECRET_KEY environment variable.
    Versions >= 2.256.0 use SHA256-based integrity checks which are secure and don't require
    the secret key.

    If no dependencies are provided, creates a temporary requirements.txt with sagemaker
    (the caller owns that file; it is not deleted here).
    If a ``.txt`` file is provided, every line that requires the ``sagemaker`` distribution
    is validated; if there is none, the sagemaker requirement is appended to the file
    in place. Any other file type (e.g. conda ``.yml``) is returned untouched.

    Args:
        local_dependencies_path: Path to user's dependencies file or None

    Returns:
        Path to the dependencies file (created or modified)

    Raises:
        ValueError: If user has pinned sagemaker to a version using HMAC hashing
    """
    import tempfile

    SAGEMAKER_MIN_VERSION = "sagemaker>=2.256.0,<3.0.0"

    if local_dependencies_path is None:
        # Create a temporary requirements.txt in the system temp directory
        fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(f"{SAGEMAKER_MIN_VERSION}\n")
        logger.info(
            "Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION
        )
        return req_file

    if not local_dependencies_path.endswith(".txt"):
        return local_dependencies_path

    # Undecodable bytes can only matter inside comments or unrelated lines, so replace
    # them instead of failing on files that are not UTF-8.
    with open(local_dependencies_path, "r", encoding="utf-8", errors="replace") as f:
        content = f.read()

    sagemaker_lines = [
        line.strip()
        for line in content.splitlines()
        if _sagemaker_requirement_specifier(line) is not None
    ]
    if sagemaker_lines:
        # Validate every occurrence, not just the first one
        for line in sagemaker_lines:
            _check_sagemaker_version_compatibility(line)
        return local_dependencies_path

    # Start on a fresh line without leaving an empty one behind
    separator = "" if not content or content.endswith("\n") else "\n"
    with open(local_dependencies_path, "a", encoding="utf-8") as f:
        f.write(f"{separator}{SAGEMAKER_MIN_VERSION}\n")
    logger.info("Appended %s to requirements.txt", SAGEMAKER_MIN_VERSION)

    return local_dependencies_path
local_dependencies_path: str, pre_execution_commands: List[str], diff --git a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py new file mode 100644 index 0000000000..bc5e7870a9 --- /dev/null +++ b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py @@ -0,0 +1,144 @@ +"""Integration tests for sagemaker dependency injection in remote functions. + +These tests verify that the sagemaker>=2.256.0 dependency is properly injected +into remote function jobs, preventing version mismatch issues. +""" + +from __future__ import absolute_import + +import os +import sys +import tempfile + +import pytest + +# Add src to path before importing sagemaker +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src")) + +from sagemaker.remote_function import remote # noqa: E402 + +# Skip decorator for AWS configuration +skip_if_no_aws_region = pytest.mark.skipif( + not os.environ.get("AWS_DEFAULT_REGION"), reason="AWS credentials not configured" +) + + +class TestRemoteFunctionDependencyInjection: + """Integration tests for dependency injection in remote functions.""" + + @pytest.mark.integ + @skip_if_no_aws_region + def test_remote_function_without_dependencies(self): + """Test remote function execution without explicit dependencies. + + This test verifies that when no dependencies are provided, the remote + function still executes successfully because sagemaker>=2.256.0 is + automatically injected. 
+ """ + + @remote( + instance_type="ml.m5.large", + # No dependencies specified - sagemaker should be injected automatically + ) + def simple_add(x, y): + """Simple function that adds two numbers.""" + return x + y + + # Execute the function + result = simple_add(5, 3) + + # Verify result + assert result == 8, f"Expected 8, got {result}" + print("✓ Remote function without dependencies executed successfully") + + @pytest.mark.integ + @skip_if_no_aws_region + def test_remote_function_with_user_dependencies_no_sagemaker(self): + """Test remote function with user dependencies but no sagemaker. + + This test verifies that when user provides dependencies without sagemaker, + sagemaker>=2.256.0 is automatically appended. + """ + # Create a temporary requirements.txt without sagemaker + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: + f.write("numpy>=1.20.0\npandas>=1.3.0\n") + req_file = f.name + + try: + + @remote( + instance_type="ml.m5.large", + dependencies=req_file, + ) + def compute_with_numpy(x): + """Function that uses numpy.""" + import numpy as np + + return np.array([x, x * 2, x * 3]).sum() + + # Execute the function + result = compute_with_numpy(5) + + # Verify result (5 + 10 + 15 = 30) + assert result == 30, f"Expected 30, got {result}" + print("✓ Remote function with user dependencies executed successfully") + finally: + os.remove(req_file) + + +class TestRemoteFunctionVersionCompatibility: + """Tests for version compatibility between local and remote environments.""" + + @pytest.mark.integ + @skip_if_no_aws_region + def test_deserialization_with_injected_sagemaker(self): + """Test that deserialization works with injected sagemaker dependency. + + This test verifies that the remote environment can properly deserialize + functions when sagemaker>=2.256.0 is available. 
+ """ + + @remote( + instance_type="ml.m5.large", + ) + def complex_computation(data): + """Function that performs complex computation.""" + result = sum(data) * len(data) + return result + + # Execute with various data types + test_data = [1, 2, 3, 4, 5] + result = complex_computation(test_data) + + # Verify result (sum=15, len=5, 15*5=75) + assert result == 75, f"Expected 75, got {result}" + print("✓ Deserialization with injected sagemaker works correctly") + + @pytest.mark.integ + @skip_if_no_aws_region + def test_multiple_remote_functions_with_dependencies(self): + """Test multiple remote functions with different dependency configurations. + + This test verifies that the dependency injection works correctly + when multiple remote functions are defined and executed. + """ + + @remote(instance_type="ml.m5.large") + def func1(x): + return x + 1 + + @remote(instance_type="ml.m5.large") + def func2(x): + return x * 2 + + # Execute both functions + result1 = func1(5) + result2 = func2(5) + + assert result1 == 6, f"func1: Expected 6, got {result1}" + assert result2 == 10, f"func2: Expected 10, got {result2}" + print("✓ Multiple remote functions with dependencies executed successfully") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-m", "integ"]) diff --git a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py new file mode 100644 index 0000000000..ac7f11e124 --- /dev/null +++ b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py @@ -0,0 +1,325 @@ +"""Unit tests for _ensure_sagemaker_dependency function. + +Tests the logic that ensures sagemaker>=2.256.0 is included in remote function dependencies +to prevent version mismatch issues with HMAC key integrity checks. 
+""" + +from __future__ import absolute_import + +import os +import sys +import tempfile +import unittest + +# Add src to path before importing sagemaker +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src")) + +from sagemaker.remote_function.job import ( # noqa: E402 + _ensure_sagemaker_dependency, + _check_sagemaker_version_compatibility, +) + + +class TestEnsureSagemakerDependency(unittest.TestCase): + """Test cases for _ensure_sagemaker_dependency function.""" + + def test_no_dependencies_creates_temp_requirements_file(self): + """Test that a temp requirements.txt is created when no dependencies provided.""" + result = _ensure_sagemaker_dependency(None) + + # Verify file was created + self.assertTrue(os.path.exists(result), f"Requirements file not created at {result}") + + # Verify it's in temp directory + self.assertIn(tempfile.gettempdir(), result) + + # Verify content + with open(result, "r") as f: + content = f.read() + self.assertIn("sagemaker>=2.256.0,<3.0.0", content) + + # Cleanup + os.remove(result) + + def test_no_dependencies_file_has_correct_format(self): + """Test that created requirements.txt has correct format.""" + result = _ensure_sagemaker_dependency(None) + + with open(result, "r") as f: + lines = f.readlines() + + # Should have exactly one line with sagemaker dependency + self.assertEqual(len(lines), 1) + self.assertEqual(lines[0].strip(), "sagemaker>=2.256.0,<3.0.0") + + # Cleanup + os.remove(result) + + def test_appends_sagemaker_to_existing_requirements(self): + """Test that sagemaker is appended to existing requirements.txt.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: + f.write("numpy>=1.20.0\npandas>=1.3.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + # Should return the same file + self.assertEqual(result, temp_file) + + # Verify content + with open(result, "r") as f: + content = f.read() + + self.assertIn("numpy>=1.20.0", 
content) + self.assertIn("pandas>=1.3.0", content) + self.assertIn("sagemaker>=2.256.0,<3.0.0", content) + finally: + os.remove(temp_file) + + def test_does_not_duplicate_sagemaker_if_already_present(self): + """Test that sagemaker is not duplicated if already in requirements.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: + f.write("numpy>=1.20.0\nsagemaker>=2.256.0,<3.0.0\npandas>=1.3.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + with open(result, "r") as f: + content = f.read() + + # Count occurrences of sagemaker + sagemaker_count = content.lower().count("sagemaker") + self.assertEqual(sagemaker_count, 1, "sagemaker should appear exactly once") + + # Verify user's version is preserved + self.assertIn("sagemaker>=2.256.0,<3.0.0", content) + finally: + os.remove(temp_file) + + def test_preserves_user_dependencies(self): + """Test that user's existing dependencies are preserved.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: + f.write("torch>=1.9.0\ntorchvision>=0.10.0\nscikit-learn>=0.24.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + with open(result, "r") as f: + content = f.read() + + # All user dependencies should be present + self.assertIn("torch>=1.9.0", content) + self.assertIn("torchvision>=0.10.0", content) + self.assertIn("scikit-learn>=0.24.0", content) + self.assertIn("sagemaker>=2.256.0,<3.0.0", content) + finally: + os.remove(temp_file) + + def test_handles_yml_files_gracefully(self): + """Test that yml files are returned unchanged.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f: + f.write("name: test-env\nchannels:\n - conda-forge\ndependencies:\n - numpy\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + # Should return the same file + self.assertEqual(result, temp_file) + + # Content should be unchanged (yml files are 
not modified) + with open(result, "r") as f: + content = f.read() + + self.assertNotIn("sagemaker", content.lower()) + finally: + os.remove(temp_file) + + def test_handles_yaml_files_gracefully(self): + """Test that yaml files are returned unchanged.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write("name: test-env\nchannels:\n - conda-forge\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + # Should return the same file + self.assertEqual(result, temp_file) + + # Content should be unchanged + with open(result, "r") as f: + content = f.read() + + self.assertNotIn("sagemaker", content.lower()) + finally: + os.remove(temp_file) + + def test_case_insensitive_sagemaker_detection(self): + """Test that sagemaker detection is case-insensitive.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: + f.write("numpy>=1.20.0\nSAGEMAKER>=2.256.0,<3.0.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + with open(result, "r") as f: + content = f.read() + + # Should not duplicate even with different case + sagemaker_count = content.lower().count("sagemaker") + self.assertEqual(sagemaker_count, 1) + finally: + os.remove(temp_file) + + def test_temp_file_location(self): + """Test that temp file is created in system temp directory.""" + result = _ensure_sagemaker_dependency(None) + + # Should be in system temp directory + temp_dir = tempfile.gettempdir() + self.assertTrue(result.startswith(temp_dir)) + + # Should have correct prefix + self.assertIn("sagemaker_requirements_", result) + + # Cleanup + os.remove(result) + + def test_version_constraint_format(self): + """Test that version constraint has correct format.""" + result = _ensure_sagemaker_dependency(None) + + with open(result, "r") as f: + content = f.read().strip() + + # Should have both lower and upper bounds + self.assertIn(">=2.256.0", content) + self.assertIn("<3.0.0", 
content) + + # Cleanup + os.remove(result) + + +class TestCheckSagemakerVersionCompatibility(unittest.TestCase): + """Test cases for _check_sagemaker_version_compatibility function.""" + + # ===== GOOD CASES (should NOT raise ValueError) ===== + + def test_v2_good_exact_version_256(self): + """Test V2 exact version 2.256.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker==2.256.0") + + def test_v2_good_exact_version_300(self): + """Test V2 exact version 2.300.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker==2.300.0") + + def test_v2_good_range_256_to_300(self): + """Test V2 range 2.256.0 to 2.300.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=2.256.0,<2.300.0") + + def test_v3_good_exact_version_32(self): + """Test V3 exact version 3.2.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker==3.2.0") + + def test_v3_good_greater_equal_32(self): + """Test V3 greater or equal 3.2.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=3.2.0") + + def test_v3_good_range_32_to_40(self): + """Test V3 range 3.2.0 to 4.0.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=3.2.0,<4.0.0") + + def test_unparseable_requirement_no_error(self): + """Test that unparseable requirements don't raise (let pip handle it).""" + # Should not raise - let pip handle invalid syntax + _check_sagemaker_version_compatibility("sagemaker") + _check_sagemaker_version_compatibility("invalid-requirement") + + def test_v2_bad_exact_version_255(self): + """Test V2 exact version 2.255.0 (bad - HMAC).""" + with self.assertRaises(ValueError) as context: + _check_sagemaker_version_compatibility("sagemaker==2.255.0") + self.assertIn("HMAC-based integrity checks", str(context.exception)) + + def test_v2_bad_exact_version_200(self): + """Test V2 
exact version 2.200.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker==2.200.0") + + def test_v2_bad_less_than_256(self): + """Test V2 less than 2.256.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<2.256.0") + + def test_v2_bad_less_equal_255(self): + """Test V2 less or equal 2.255.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<=2.255.0") + + def test_v2_bad_greater_than_255_0(self): + """Test V2 greater than 2.255.0 (not checked - treat as lower bound only).""" + # Should not raise - > is treated as a lower bound, we don't check those + _check_sagemaker_version_compatibility("sagemaker>2.255.0") + + def test_v2_bad_range_200_to_255(self): + """Test V2 range 2.200.0 to 2.255.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.256.0") + + def test_v3_bad_exact_version_31(self): + """Test V3 exact version 3.1.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker==3.1.0") + + def test_v3_bad_exact_version_300(self): + """Test V3 exact version 3.0.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker==3.0.0") + + def test_v3_bad_less_than_32(self): + """Test V3 less than 3.2.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<3.2.0") + + def test_v3_bad_less_equal_31(self): + """Test V3 less or equal 3.1.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<=3.1.0") + + def test_v3_bad_range_300_to_31(self): + """Test V3 range 3.0.0 to 3.1.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker>=3.0.0,<3.2.0") + + # ===== EDGE CASES ===== + + def 
test_multiple_version_specifiers_good(self): + """Test multiple version specifiers that are good.""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=2.256.0,<3.0.0") + + def test_multiple_version_specifiers_good_with_lower_bound(self): + """Test multiple version specifiers that are good (upper bound resolves to good version).""" + # Should not raise - <2.300.0 decrements to 2.299.0 which is >= 2.256.0 + _check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.300.0") + + def test_multiple_version_specifiers_bad(self): + """Test multiple version specifiers that are bad.""" + # Should raise - <2.256.0 decrements to 2.255.0 which is < 2.256.0 (HMAC) + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.256.0") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index 62987395ee..f0aa1695e4 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -525,6 +525,7 @@ def test_sagemaker_config_job_settings_studio_image_uri(get_execution_role, sess monkeypatch.delenv("SAGEMAKER_INTERNAL_IMAGE_URI") +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -539,7 +540,10 @@ def test_start( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -562,7 +566,6 @@ def test_start( mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) - 
local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -576,7 +579,7 @@ def test_start( ) mock_dependency_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path="/tmp/sagemaker_requirements_test.txt", include_local_workdir=True, pre_execution_commands=None, pre_execution_script_local_path=None, @@ -629,7 +632,7 @@ def test_start( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--run_in_context", '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', ], @@ -647,6 +650,7 @@ def test_start( ) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -661,7 +665,10 @@ def test_start_with_checkpoint_location( mock_runtime_manager, mock_script_upload, mock_user_workspace_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -743,7 +750,7 @@ def test_start_with_checkpoint_location( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--run_in_context", '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', ], @@ -797,6 +804,7 @@ def test_start_with_checkpoint_location_failed_with_multiple_checkpoint_location ) 
+@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -810,10 +818,14 @@ def test_start_with_complete_job_settings( mock_runtime_manager, mock_bootstrap_script_upload, mock_user_workspace_upload, + mock_ensure_sagemaker, ): + # This test has explicit dependencies, so mock returns the same path + dependencies_path = "path/to/dependencies/req.txt" + mock_ensure_sagemaker.return_value = dependencies_path job_settings = _JobSettings( - dependencies="path/to/dependencies/req.txt", + dependencies=dependencies_path, pre_execution_script="path/to/script.sh", environment_variables={"AWS_DEFAULT_REGION": "us-east-2"}, image_uri=IMAGE, @@ -839,7 +851,6 @@ def test_start_with_complete_job_settings( s3_kms_key=KMS_KEY_ARN, ) - local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -853,7 +864,7 @@ def test_start_with_complete_job_settings( ) mock_user_workspace_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path=dependencies_path, include_local_workdir=False, pre_execution_commands=None, pre_execution_script_local_path="path/to/script.sh", @@ -932,6 +943,7 @@ def test_start_with_complete_job_settings( @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch( "sagemaker.remote_function.job._prepare_dependencies_and_pre_execution_scripts", return_value="some_s3_uri", @@ -950,7 +962,11 @@ def test_get_train_args_under_pipeline_context( mock_bootstrap_scripts_upload, mock_user_workspace_upload, mock_user_dependencies_upload, + 
mock_ensure_sagemaker, ): + # This test has explicit dependencies, so mock returns the same path + dependencies_path = "path/to/dependencies/req.txt" + mock_ensure_sagemaker.return_value = dependencies_path from sagemaker.workflow.parameters import ParameterInteger @@ -965,7 +981,7 @@ def test_get_train_args_under_pipeline_context( function_step._properties.OutputDataConfig.S3OutputPath = func_step_s3_output_prop job_settings = _JobSettings( - dependencies="path/to/dependencies/req.txt", + dependencies=dependencies_path, pre_execution_script="path/to/script.sh", environment_variables={"AWS_DEFAULT_REGION": "us-east-2"}, image_uri=IMAGE, @@ -1026,7 +1042,7 @@ def test_get_train_args_under_pipeline_context( ) mock_user_workspace_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path=dependencies_path, include_local_workdir=False, pre_execution_commands=None, pre_execution_script_local_path="path/to/script.sh", @@ -1148,6 +1164,7 @@ def test_get_train_args_under_pipeline_context( "sagemaker.remote_function.job._prepare_and_upload_spark_dependent_files", return_value=tuple(["jars_s3_uri", "py_files_s3_uri", "files_s3_uri", "config_file_s3_uri"]), ) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -1162,9 +1179,13 @@ def test_start_with_spark( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, mock_spark_dependency_upload, mock_get_default_spark_image, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" + spark_config = SparkConfig() job_settings = _JobSettings( spark_config=spark_config, @@ -1258,7 +1279,7 @@ def test_start_with_spark( 
"--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--run_in_context", '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', ], @@ -1804,6 +1825,7 @@ def test_extend_spark_config_to_request( ) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -1818,7 +1840,10 @@ def test_start_with_torchrun_single_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -1843,7 +1868,6 @@ def test_start_with_torchrun_single_node( mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) - local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -1857,7 +1881,7 @@ def test_start_with_torchrun_single_node( ) mock_dependency_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path="/tmp/sagemaker_requirements_test.txt", include_local_workdir=True, pre_execution_commands=None, pre_execution_script_local_path=None, @@ -1910,7 +1934,7 @@ def test_start_with_torchrun_single_node( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--distribution", "torchrun", "--run_in_context", @@ -1930,6 +1954,7 @@ 
def test_start_with_torchrun_single_node( ) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -1944,7 +1969,10 @@ def test_start_with_torchrun_multi_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -1970,7 +1998,6 @@ def test_start_with_torchrun_multi_node( mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) - local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -1984,7 +2011,7 @@ def test_start_with_torchrun_multi_node( ) mock_dependency_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path="/tmp/sagemaker_requirements_test.txt", include_local_workdir=True, pre_execution_commands=None, pre_execution_script_local_path=None, @@ -2039,7 +2066,7 @@ def test_start_with_torchrun_multi_node( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--distribution", "torchrun", "--run_in_context", @@ -2319,6 +2346,7 @@ def test_set_env_multi_node_multi_gpu_mpirun( assert env_file == expected_env +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) 
@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -2333,7 +2361,10 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -2359,7 +2390,6 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) - local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -2373,7 +2403,7 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( ) mock_dependency_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path="/tmp/sagemaker_requirements_test.txt", include_local_workdir=True, pre_execution_commands=None, pre_execution_script_local_path=None, @@ -2426,7 +2456,7 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--distribution", "torchrun", "--user_nproc_per_node", @@ -2448,6 +2478,7 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( ) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -2462,7 +2493,10 @@ def 
test_start_with_mpirun_single_node_with_nproc_per_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -2488,7 +2522,6 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) - local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -2502,7 +2535,7 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( ) mock_dependency_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path="/tmp/sagemaker_requirements_test.txt", include_local_workdir=True, pre_execution_commands=None, pre_execution_script_local_path=None, @@ -2555,7 +2588,7 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--distribution", "mpirun", "--user_nproc_per_node",