From 19f6195b7fde5153c25461c606b67a74dd7e6cda Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 14 Jan 2026 12:00:40 -0800 Subject: [PATCH 01/10] Add sagemaker dependency for remote function by default --- src/sagemaker/remote_function/job.py | 112 +++++++ .../test_sagemaker_dependency_injection.py | 127 +++++++ .../test_ensure_sagemaker_dependency.py | 316 ++++++++++++++++++ 3 files changed, 555 insertions(+) create mode 100644 tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py create mode 100644 tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 490f872861..468a4077be 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -1235,6 +1235,11 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies) + # Ensure sagemaker dependency is included to prevent version mismatch issues + # Resolves issue where computing hash for integrity check changed in 2.256.0 + local_dependencies_path = _ensure_sagemaker_dependency(local_dependencies_path) + job_settings.dependencies = local_dependencies_path + if step_compilation_context: with _tmpdir() as tmp_dir: script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts( @@ -1291,6 +1296,113 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): return input_data_config +def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None: + """Check if the sagemaker version requirement uses incompatible hashing. + + Raises ValueError if the requirement would install a version that uses HMAC hashing + (which is incompatible with the current SHA256-based integrity checks). + + Args: + sagemaker_requirement: The sagemaker requirement string (e.g., "sagemaker>=2.200.0") + + Raises: + ValueError: If the requirement would install a version using HMAC hashing + """ + import re + from packaging.specifiers import SpecifierSet + from packaging import version as pkg_version + + match = re.search(r'sagemaker\s*(.+)$', sagemaker_requirement.strip(), re.IGNORECASE) + if not match: + return + + specifier_str = match.group(1).strip() + + try: + specifier_set = SpecifierSet(specifier_str) + except Exception: + return + + # Test if any HMAC version would satisfy the specifier + # V2 HMAC versions: < 2.256.0 + v2_hmac_test_versions = ["2.0.0", "2.100.0", "2.200.0", "2.255.0", "2.255.1", "2.255.99"] + for test_version in v2_hmac_test_versions: + if test_version in specifier_set: + raise ValueError( + f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) " + f"could install a version using HMAC-based integrity checks which are incompatible " + f"with the current SHA256-based integrity checks. Please update to " + f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)." + ) + + # V3 HMAC versions: < 3.2.0 + v3_hmac_test_versions = ["3.0.0", "3.0.1", "3.1.0", "3.1.99"] + for test_version in v3_hmac_test_versions: + if test_version in specifier_set: + raise ValueError( + f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) " + f"could install a version using HMAC-based integrity checks which are incompatible " + f"with the current SHA256-based integrity checks. Please update to " + f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)." + ) + + +def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str: + """Ensure sagemaker>=2.256.0 is in the dependencies. + + This function ensures that the remote environment has a compatible version of sagemaker + that includes the fix for the HMAC key security issue. Versions < 2.256.0 use HMAC-based + integrity checks which require the REMOTE_FUNCTION_SECRET_KEY environment variable. + Versions >= 2.256.0 use SHA256-based integrity checks which are secure and don't require + the secret key. + + If no dependencies are provided, creates a temporary requirements.txt with sagemaker. + If dependencies are provided, appends sagemaker if not already present. + + Args: + local_dependencies_path: Path to user's dependencies file or None + + Returns: + Path to the dependencies file (created or modified) + + Raises: + ValueError: If user has pinned sagemaker to a version using HMAC hashing + """ + import tempfile + + SAGEMAKER_MIN_VERSION = "sagemaker>=2.256.0,<3.0.0" + + if local_dependencies_path is None: + # Create a temporary requirements.txt in the system temp directory + # This avoids overwriting any user files in their working directory + fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_") + os.close(fd) # Close the file descriptor, we'll write to it ourselves + + with open(req_file, "w") as f: + f.write(f"{SAGEMAKER_MIN_VERSION}\n") + logger.info("Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION) + return req_file + + # If dependencies provided, ensure sagemaker is included + if local_dependencies_path.endswith(".txt"): + with open(local_dependencies_path, "r") as f: + content = f.read() + + # Check if sagemaker is already specified + if "sagemaker" in content.lower(): + # Extract the sagemaker requirement line for compatibility check + for line in content.split('\n'): + if 'sagemaker' in line.lower(): + _check_sagemaker_version_compatibility(line.strip()) + break + else: + with open(local_dependencies_path, "a") as f: + f.write(f"\n{SAGEMAKER_MIN_VERSION}\n") + logger.info("Appended %s to requirements.txt", SAGEMAKER_MIN_VERSION) + + return local_dependencies_path + + def _prepare_dependencies_and_pre_execution_scripts( local_dependencies_path: str, pre_execution_commands: List[str], diff --git a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py new file mode 100644 index 0000000000..a4b7e99806 --- /dev/null +++ b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py @@ -0,0 +1,127 @@ +"""Integration tests for sagemaker dependency injection in remote functions. + +These tests verify that the sagemaker>=2.256.0 dependency is properly injected +into remote function jobs, preventing version mismatch issues. +""" + +import os +import sys +import tempfile +import pytest + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src')) + +from sagemaker.remote_function import remote + + +class TestRemoteFunctionDependencyInjection: + """Integration tests for dependency injection in remote functions.""" + + @pytest.mark.integ + def test_remote_function_without_dependencies(self): + """Test remote function execution without explicit dependencies. + + This test verifies that when no dependencies are provided, the remote + function still executes successfully because sagemaker>=2.256.0 is + automatically injected. + """ + @remote( + instance_type="ml.m5.large", + # No dependencies specified - sagemaker should be injected automatically + ) + def simple_add(x, y): + """Simple function that adds two numbers.""" + return x + y + + # Execute the function + result = simple_add(5, 3) + + # Verify result + assert result == 8, f"Expected 8, got {result}" + print("✓ Remote function without dependencies executed successfully") + + @pytest.mark.integ + def test_remote_function_with_user_dependencies_no_sagemaker(self): + """Test remote function with user dependencies but no sagemaker. + + This test verifies that when user provides dependencies without sagemaker, + sagemaker>=2.256.0 is automatically appended. + """ + # Create a temporary requirements.txt without sagemaker + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write("numpy>=1.20.0\npandas>=1.3.0\n") + req_file = f.name + + try: + @remote( + instance_type="ml.m5.large", + dependencies=req_file, + ) + def compute_with_numpy(x): + """Function that uses numpy.""" + import numpy as np + return np.array([x, x*2, x*3]).sum() + + # Execute the function + result = compute_with_numpy(5) + + # Verify result (5 + 10 + 15 = 30) + assert result == 30, f"Expected 30, got {result}" + print("✓ Remote function with user dependencies executed successfully") + finally: + os.remove(req_file) + + +class TestRemoteFunctionVersionCompatibility: + """Tests for version compatibility between local and remote environments.""" + + @pytest.mark.integ + def test_deserialization_with_injected_sagemaker(self): + """Test that deserialization works with injected sagemaker dependency. + + This test verifies that the remote environment can properly deserialize + functions when sagemaker>=2.256.0 is available. + """ + @remote( + instance_type="ml.m5.large", + ) + def complex_computation(data): + """Function that performs complex computation.""" + result = sum(data) * len(data) + return result + + # Execute with various data types + test_data = [1, 2, 3, 4, 5] + result = complex_computation(test_data) + + # Verify result (sum=15, len=5, 15*5=75) + assert result == 75, f"Expected 75, got {result}" + print("✓ Deserialization with injected sagemaker works correctly") + + @pytest.mark.integ + def test_multiple_remote_functions_with_dependencies(self): + """Test multiple remote functions with different dependency configurations. + + This test verifies that the dependency injection works correctly + when multiple remote functions are defined and executed. + """ + @remote(instance_type="ml.m5.large") + def func1(x): + return x + 1 + + @remote(instance_type="ml.m5.large") + def func2(x): + return x * 2 + + # Execute both functions + result1 = func1(5) + result2 = func2(5) + + assert result1 == 6, f"func1: Expected 6, got {result1}" + assert result2 == 10, f"func2: Expected 10, got {result2}" + print("✓ Multiple remote functions with dependencies executed successfully") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-m", "integ"]) diff --git a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py new file mode 100644 index 0000000000..f64ad654d5 --- /dev/null +++ b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py @@ -0,0 +1,316 @@ +"""Unit tests for _ensure_sagemaker_dependency function. + +Tests the logic that ensures sagemaker>=2.256.0 is included in remote function dependencies +to prevent version mismatch issues with HMAC key integrity checks. +""" + +import os +import tempfile +import unittest +from unittest.mock import patch, MagicMock + +# Add src to path +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src')) + +from sagemaker.remote_function.job import _ensure_sagemaker_dependency, _check_sagemaker_version_compatibility + + +class TestEnsureSagemakerDependency(unittest.TestCase): + """Test cases for _ensure_sagemaker_dependency function.""" + + def test_no_dependencies_creates_temp_requirements_file(self): + """Test that a temp requirements.txt is created when no dependencies provided.""" + result = _ensure_sagemaker_dependency(None) + + # Verify file was created + self.assertTrue(os.path.exists(result), f"Requirements file not created at {result}") + + # Verify it's in temp directory + self.assertIn(tempfile.gettempdir(), result) + + # Verify content + with open(result, "r") as f: + content = f.read() + self.assertIn("sagemaker>=2.256.0,<3.0.0", content) + + # Cleanup + os.remove(result) + + def test_no_dependencies_file_has_correct_format(self): + """Test that created requirements.txt has correct format.""" + result = _ensure_sagemaker_dependency(None) + + with open(result, "r") as f: + lines = f.readlines() + + # Should have exactly one line with sagemaker dependency + self.assertEqual(len(lines), 1) + self.assertEqual(lines[0].strip(), "sagemaker>=2.256.0,<3.0.0") + + # Cleanup + os.remove(result) + + def test_appends_sagemaker_to_existing_requirements(self): + """Test that sagemaker is appended to existing requirements.txt.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write("numpy>=1.20.0\npandas>=1.3.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + # Should return the same file + self.assertEqual(result, temp_file) + + # Verify content + with open(result, "r") as f: + content = f.read() + + self.assertIn("numpy>=1.20.0", content) + self.assertIn("pandas>=1.3.0", content) + self.assertIn("sagemaker>=2.256.0,<3.0.0", content) + finally: + os.remove(temp_file) + + def test_does_not_duplicate_sagemaker_if_already_present(self): + """Test that sagemaker is not duplicated if already in requirements.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write("numpy>=1.20.0\nsagemaker>=2.256.0,<3.0.0\npandas>=1.3.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + with open(result, "r") as f: + content = f.read() + + # Count occurrences of sagemaker + sagemaker_count = content.lower().count("sagemaker") + self.assertEqual(sagemaker_count, 1, "sagemaker should appear exactly once") + + # Verify user's version is preserved + self.assertIn("sagemaker>=2.256.0,<3.0.0", content) + finally: + os.remove(temp_file) + + def test_preserves_user_dependencies(self): + """Test that user's existing dependencies are preserved.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write("torch>=1.9.0\ntorchvision>=0.10.0\nscikit-learn>=0.24.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + with open(result, "r") as f: + content = f.read() + + # All user dependencies should be present + self.assertIn("torch>=1.9.0", content) + self.assertIn("torchvision>=0.10.0", content) + self.assertIn("scikit-learn>=0.24.0", content) + self.assertIn("sagemaker>=2.256.0,<3.0.0", content) + finally: + os.remove(temp_file) + + def test_handles_yml_files_gracefully(self): + """Test that yml files are returned unchanged.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f: + f.write("name: test-env\nchannels:\n - conda-forge\ndependencies:\n - numpy\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + # Should return the same file + self.assertEqual(result, temp_file) + + # Content should be unchanged (yml files are not modified) + with open(result, "r") as f: + content = f.read() + + self.assertNotIn("sagemaker", content.lower()) + finally: + os.remove(temp_file) + + def test_handles_yaml_files_gracefully(self): + """Test that yaml files are returned unchanged.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write("name: test-env\nchannels:\n - conda-forge\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + # Should return the same file + self.assertEqual(result, temp_file) + + # Content should be unchanged + with open(result, "r") as f: + content = f.read() + + self.assertNotIn("sagemaker", content.lower()) + finally: + os.remove(temp_file) + + def test_case_insensitive_sagemaker_detection(self): + """Test that sagemaker detection is case-insensitive.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write("numpy>=1.20.0\nSAGEMAKER>=2.256.0,<3.0.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + with open(result, "r") as f: + content = f.read() + + # Should not duplicate even with different case + sagemaker_count = content.lower().count("sagemaker") + self.assertEqual(sagemaker_count, 1) + finally: + os.remove(temp_file) + + def test_temp_file_location(self): + """Test that temp file is created in system temp directory.""" + result = _ensure_sagemaker_dependency(None) + + # Should be in system temp directory + temp_dir = tempfile.gettempdir() + self.assertTrue(result.startswith(temp_dir)) + + # Should have correct prefix + self.assertIn("sagemaker_requirements_", result) + + # Cleanup + os.remove(result) + + def test_version_constraint_format(self): + """Test that version constraint has correct format.""" + result = _ensure_sagemaker_dependency(None) + + with open(result, "r") as f: + content = f.read().strip() + + # Should have both lower and upper bounds + self.assertIn(">=2.256.0", content) + self.assertIn("<3.0.0", content) + + # Cleanup + os.remove(result) + + +class TestCheckSagemakerVersionCompatibility(unittest.TestCase): + """Test cases for _check_sagemaker_version_compatibility function.""" + + # ===== GOOD CASES (should NOT raise ValueError) ===== + + def test_v2_good_exact_version_256(self): + """Test V2 exact version 2.256.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker==2.256.0") + + def test_v2_good_exact_version_300(self): + """Test V2 exact version 2.300.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker==2.300.0") + + def test_v2_good_range_256_to_300(self): + """Test V2 range 2.256.0 to 2.300.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=2.256.0,<2.300.0") + + def test_v3_good_exact_version_32(self): + """Test V3 exact version 3.2.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker==3.2.0") + + def test_v3_good_greater_equal_32(self): + """Test V3 greater or equal 3.2.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=3.2.0") + + def test_v3_good_range_32_to_40(self): + """Test V3 range 3.2.0 to 4.0.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=3.2.0,<4.0.0") + + def test_unparseable_requirement_no_error(self): + """Test that unparseable requirements don't raise (let pip handle it).""" + # Should not raise - let pip handle invalid syntax + _check_sagemaker_version_compatibility("sagemaker") + _check_sagemaker_version_compatibility("invalid-requirement") + + def test_v2_bad_exact_version_255(self): + """Test V2 exact version 2.255.0 (bad - HMAC).""" + with self.assertRaises(ValueError) as context: + _check_sagemaker_version_compatibility("sagemaker==2.255.0") + self.assertIn("HMAC-based integrity checks", str(context.exception)) + + def test_v2_bad_exact_version_200(self): + """Test V2 exact version 2.200.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker==2.200.0") + + def test_v2_bad_less_than_256(self): + """Test V2 less than 2.256.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<2.256.0") + + def test_v2_bad_less_equal_255(self): + """Test V2 less or equal 2.255.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<=2.255.0") + + def test_v2_bad_greater_than_255_0(self): + """Test V2 greater than 2.255.0 (bad - could be 2.255.1 with HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker>2.255.0") + + def test_v2_bad_range_200_to_255(self): + """Test V2 range 2.200.0 to 2.255.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.256.0") + + def test_v3_bad_exact_version_31(self): + """Test V3 exact version 3.1.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker==3.1.0") + + def test_v3_bad_exact_version_300(self): + """Test V3 exact version 3.0.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker==3.0.0") + + def test_v3_bad_less_than_32(self): + """Test V3 less than 3.2.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<3.2.0") + + def test_v3_bad_less_equal_31(self): + """Test V3 less or equal 3.1.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<=3.1.0") + + def test_v3_bad_range_300_to_31(self): + """Test V3 range 3.0.0 to 3.1.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker>=3.0.0,<3.2.0") + + # ===== EDGE CASES ===== + + def test_multiple_version_specifiers_good(self): + """Test multiple version specifiers that are good.""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=2.256.0,<3.0.0") + + def test_multiple_version_specifiers_bad(self): + """Test multiple version specifiers that are bad.""" + # Should raise because lower bound is < 2.256.0 + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.300.0") + + +if __name__ == "__main__": + unittest.main() From 5d84038e7dfe22e2bb4c3d6f0f95141727754678 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 14 Jan 2026 12:33:36 -0800 Subject: [PATCH 02/10] Revise sagemaker compatibility check --- src/sagemaker/remote_function/job.py | 157 +++++++++++++++--- .../test_ensure_sagemaker_dependency.py | 15 +- 2 files changed, 142 insertions(+), 30 deletions(-) diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 468a4077be..43000287f4 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -1296,6 +1296,111 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): return input_data_config +def _decrement_version(version_str: str) -> str: + """Decrement a version string by one minor or patch version. + + Rules: + - If patch version is 0 (e.g., 2.256.0), decrement minor: 2.256.0 -> 2.255.0 + - If patch version is not 0 (e.g., 2.254.2), decrement patch: 2.254.2 -> 2.254.1 + + Args: + version_str: Version string (e.g., "2.256.0") + + Returns: + Decremented version string + """ + from packaging import version as pkg_version + + try: + parsed = pkg_version.parse(version_str) + major = parsed.major + minor = parsed.minor + patch = parsed.micro + + if patch == 0: + # Decrement minor version + minor = max(0, minor - 1) + else: + # Decrement patch version + patch = max(0, patch - 1) + + return f"{major}.{minor}.{patch}" + except Exception: + return version_str + + +def _resolve_version_from_specifier(specifier_str: str) -> str: + """Resolve the version to check based on upper bounds. + + Upper bounds take priority. If upper bound is <3.0.0, it's safe (V2 only). + If no upper bound exists, it's safe (unbounded). + If the decremented upper bound is less than a lower bound, use the lower bound. + + Args: + specifier_str: Version specifier string (e.g., ">=2.256.0", "<2.256.0", "==2.255.0") + + Returns: + The resolved version string to check, or None if safe + """ + import re + from packaging import version as pkg_version + + # Handle exact version pinning (==) + match = re.search(r'==\s*([\d.]+)', specifier_str) + if match: + return match.group(1) + + # Extract lower bounds for comparison + lower_bounds = [] + for match in re.finditer(r'>=\s*([\d.]+)', specifier_str): + lower_bounds.append(match.group(1)) + + # Handle upper bounds - find the most restrictive one + upper_bounds = [] + + # Find all <= bounds + for match in re.finditer(r'<=\s*([\d.]+)', specifier_str): + upper_bounds.append(('<=', match.group(1))) + + # Find all < bounds + for match in re.finditer(r'<\s*([\d.]+)', specifier_str): + upper_bounds.append(('<', match.group(1))) + + if upper_bounds: + # Sort by version to find the most restrictive (lowest) upper bound + upper_bounds.sort(key=lambda x: pkg_version.parse(x[1])) + operator, version = upper_bounds[0] + + # Special case: if upper bound is <3.0.0, it's safe (V2 only) + try: + parsed_upper = pkg_version.parse(version) + if operator == '<' and parsed_upper.major == 3 and parsed_upper.minor == 0 and parsed_upper.micro == 0: + # <3.0.0 means V2 only, which is safe + return None + except Exception: + pass + + resolved_version = version + if operator == '<': + resolved_version = _decrement_version(version) + + # If we have a lower bound and the resolved version is less than it, use the lower bound + if lower_bounds: + try: + resolved_parsed = pkg_version.parse(resolved_version) + for lower_bound_str in lower_bounds: + lower_parsed = pkg_version.parse(lower_bound_str) + if resolved_parsed < lower_parsed: + resolved_version = lower_bound_str + except Exception: + pass + + return resolved_version + + # For lower bounds only (>=, >), we don't check + return None + + def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None: """Check if the sagemaker version requirement uses incompatible hashing. @@ -1309,42 +1414,44 @@ def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None: ValueError: If the requirement would install a version using HMAC hashing """ import re - from packaging.specifiers import SpecifierSet from packaging import version as pkg_version - + match = re.search(r'sagemaker\s*(.+)$', sagemaker_requirement.strip(), re.IGNORECASE) if not match: return specifier_str = match.group(1).strip() + # Resolve the version that would be installed + resolved_version_str = _resolve_version_from_specifier(specifier_str) + if not resolved_version_str: + # No upper bound or exact version, so we can't determine if it's bad + return + try: - specifier_set = SpecifierSet(specifier_str) + resolved_version = pkg_version.parse(resolved_version_str) except Exception: return - - # Test if any HMAC version would satisfy the specifier - # V2 HMAC versions: < 2.256.0 - v2_hmac_test_versions = ["2.0.0", "2.100.0", "2.200.0", "2.255.0", "2.255.1", "2.255.99"] - for test_version in v2_hmac_test_versions: - if test_version in specifier_set: - raise ValueError( - f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) " - f"could install a version using HMAC-based integrity checks which are incompatible " - f"with the current SHA256-based integrity checks. Please update to " - f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)." - ) - # V3 HMAC versions: < 3.2.0 - v3_hmac_test_versions = ["3.0.0", "3.0.1", "3.1.0", "3.1.99"] - for test_version in v3_hmac_test_versions: - if test_version in specifier_set: - raise ValueError( - f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) " - f"could install a version using HMAC-based integrity checks which are incompatible " - f"with the current SHA256-based integrity checks. Please update to " - f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)." - ) + # Define HMAC thresholds for each major version + v2_hmac_threshold = pkg_version.parse("2.256.0") + v3_hmac_threshold = pkg_version.parse("3.2.0") + + # Check if the resolved version uses HMAC hashing + uses_hmac = False + if resolved_version.major == 2 and resolved_version < v2_hmac_threshold: + uses_hmac = True + elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold: + uses_hmac = True + + if uses_hmac: + raise ValueError( + f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) " + f"could install a version using HMAC-based integrity checks which are incompatible " + f"with the current SHA256-based integrity checks. Please update to " + f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)." + ) + def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str: diff --git a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py index f64ad654d5..93a4e718db 100644 --- a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py +++ b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py @@ -264,9 +264,9 @@ def test_v2_bad_less_equal_255(self): _check_sagemaker_version_compatibility("sagemaker<=2.255.0") def test_v2_bad_greater_than_255_0(self): - """Test V2 greater than 2.255.0 (bad - could be 2.255.1 with HMAC).""" - with self.assertRaises(ValueError): - _check_sagemaker_version_compatibility("sagemaker>2.255.0") + """Test V2 greater than 2.255.0 (not checked - treat as lower bound only).""" + # Should not raise - > is treated as a lower bound, we don't check those + _check_sagemaker_version_compatibility("sagemaker>2.255.0") def test_v2_bad_range_200_to_255(self): """Test V2 range 2.200.0 to 2.255.0 (bad - HMAC).""" @@ -305,11 +305,16 @@ def test_multiple_version_specifiers_good(self): # Should not raise _check_sagemaker_version_compatibility("sagemaker>=2.256.0,<3.0.0") + def test_multiple_version_specifiers_good_with_lower_bound(self): + """Test multiple version specifiers that are good (upper bound resolves to good version).""" + # Should not raise - <2.300.0 decrements to 2.299.0 which is >= 2.256.0 + _check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.300.0") + def test_multiple_version_specifiers_bad(self): """Test multiple version specifiers that are bad.""" - # Should raise because lower bound is < 2.256.0 + # Should raise - <2.256.0 decrements to 2.255.0 which is < 2.256.0 (HMAC) with self.assertRaises(ValueError): - _check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.300.0") + _check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.256.0") if __name__ == "__main__": From 2b870a38f4414a03c36dc2a2cf7a7ba8344839d1 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 14 Jan 2026 15:46:49 -0800 Subject: [PATCH 03/10] Fixing unit and itnegration tests --- .../test_sagemaker_dependency_injection.py | 10 +++ .../sagemaker/remote_function/test_job.py | 77 +++++++++++++------ 2 files changed, 65 insertions(+), 22 deletions(-) diff --git a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py index a4b7e99806..633f11a90d 100644 --- a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py +++ b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py @@ -9,6 +9,12 @@ import tempfile import pytest +# Skip decorator for AWS configuration +skip_if_no_aws_region = pytest.mark.skipif( + not os.environ.get('AWS_DEFAULT_REGION'), + reason="AWS credentials not configured" +) + # Add src to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src')) @@ -19,6 +25,7 @@ class TestRemoteFunctionDependencyInjection: """Integration tests for dependency injection in remote functions.""" @pytest.mark.integ + @skip_if_no_aws_region def test_remote_function_without_dependencies(self): """Test remote function execution without explicit dependencies. @@ -42,6 +49,7 @@ def simple_add(x, y): print("✓ Remote function without dependencies executed successfully") @pytest.mark.integ + @skip_if_no_aws_region def test_remote_function_with_user_dependencies_no_sagemaker(self): """Test remote function with user dependencies but no sagemaker. @@ -77,6 +85,7 @@ class TestRemoteFunctionVersionCompatibility: """Tests for version compatibility between local and remote environments.""" @pytest.mark.integ + @skip_if_no_aws_region def test_deserialization_with_injected_sagemaker(self): """Test that deserialization works with injected sagemaker dependency. @@ -100,6 +109,7 @@ def complex_computation(data): print("✓ Deserialization with injected sagemaker works correctly") @pytest.mark.integ + @skip_if_no_aws_region def test_multiple_remote_functions_with_dependencies(self): """Test multiple remote functions with different dependency configurations. diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index 62987395ee..f0aa1695e4 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -525,6 +525,7 @@ def test_sagemaker_config_job_settings_studio_image_uri(get_execution_role, sess monkeypatch.delenv("SAGEMAKER_INTERNAL_IMAGE_URI") +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -539,7 +540,10 @@ def test_start( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -562,7 +566,6 @@ def test_start( mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) - local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -576,7 +579,7 @@ def test_start( ) mock_dependency_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path="/tmp/sagemaker_requirements_test.txt", include_local_workdir=True, pre_execution_commands=None, pre_execution_script_local_path=None, @@ -629,7 +632,7 @@ def test_start( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--run_in_context", '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', ], @@ -647,6 +650,7 @@ def test_start( ) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -661,7 +665,10 @@ def test_start_with_checkpoint_location( mock_runtime_manager, mock_script_upload, mock_user_workspace_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -743,7 +750,7 @@ def test_start_with_checkpoint_location( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--run_in_context", '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', ], @@ -797,6 +804,7 @@ def test_start_with_checkpoint_location_failed_with_multiple_checkpoint_location ) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -810,10 +818,14 @@ def test_start_with_complete_job_settings( mock_runtime_manager, mock_bootstrap_script_upload, mock_user_workspace_upload, + mock_ensure_sagemaker, ): + # This test has explicit dependencies, so mock returns the same path + dependencies_path = "path/to/dependencies/req.txt" + mock_ensure_sagemaker.return_value = dependencies_path job_settings = _JobSettings( - dependencies="path/to/dependencies/req.txt", + dependencies=dependencies_path, pre_execution_script="path/to/script.sh", environment_variables={"AWS_DEFAULT_REGION": "us-east-2"}, image_uri=IMAGE, @@ -839,7 +851,6 @@ def test_start_with_complete_job_settings( s3_kms_key=KMS_KEY_ARN, ) - local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -853,7 +864,7 @@ def test_start_with_complete_job_settings( ) mock_user_workspace_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path=dependencies_path, include_local_workdir=False, pre_execution_commands=None, pre_execution_script_local_path="path/to/script.sh", @@ -932,6 +943,7 @@ def test_start_with_complete_job_settings( @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch( "sagemaker.remote_function.job._prepare_dependencies_and_pre_execution_scripts", return_value="some_s3_uri", @@ -950,7 +962,11 @@ def test_get_train_args_under_pipeline_context( mock_bootstrap_scripts_upload, mock_user_workspace_upload, mock_user_dependencies_upload, + mock_ensure_sagemaker, ): + # This test has explicit dependencies, so mock returns the same path + dependencies_path = "path/to/dependencies/req.txt" + mock_ensure_sagemaker.return_value = dependencies_path from sagemaker.workflow.parameters import ParameterInteger @@ -965,7 +981,7 @@ def test_get_train_args_under_pipeline_context( function_step._properties.OutputDataConfig.S3OutputPath = func_step_s3_output_prop job_settings = _JobSettings( - dependencies="path/to/dependencies/req.txt", + dependencies=dependencies_path, pre_execution_script="path/to/script.sh", environment_variables={"AWS_DEFAULT_REGION": "us-east-2"}, image_uri=IMAGE, @@ -1026,7 +1042,7 @@ def test_get_train_args_under_pipeline_context( ) mock_user_workspace_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path=dependencies_path, include_local_workdir=False, pre_execution_commands=None, pre_execution_script_local_path="path/to/script.sh", @@ -1148,6 +1164,7 @@ def test_get_train_args_under_pipeline_context( "sagemaker.remote_function.job._prepare_and_upload_spark_dependent_files", return_value=tuple(["jars_s3_uri", "py_files_s3_uri", "files_s3_uri", "config_file_s3_uri"]), ) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -1162,9 +1179,13 @@ def test_start_with_spark( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, mock_spark_dependency_upload, mock_get_default_spark_image, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" + spark_config = SparkConfig() job_settings = _JobSettings( spark_config=spark_config, @@ -1258,7 +1279,7 @@ def test_start_with_spark( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--run_in_context", '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', ], @@ -1804,6 +1825,7 @@ def test_extend_spark_config_to_request( ) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -1818,7 +1840,10 @@ def test_start_with_torchrun_single_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -1843,7 +1868,6 @@ def test_start_with_torchrun_single_node( mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) - local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -1857,7 +1881,7 @@ def test_start_with_torchrun_single_node( ) mock_dependency_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path="/tmp/sagemaker_requirements_test.txt", include_local_workdir=True, pre_execution_commands=None, pre_execution_script_local_path=None, @@ -1910,7 +1934,7 @@ def test_start_with_torchrun_single_node( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--distribution", "torchrun", "--run_in_context", @@ -1930,6 +1954,7 @@ def test_start_with_torchrun_single_node( ) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -1944,7 +1969,10 @@ def test_start_with_torchrun_multi_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -1970,7 +1998,6 @@ def test_start_with_torchrun_multi_node( mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) - local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -1984,7 +2011,7 @@ def test_start_with_torchrun_multi_node( ) mock_dependency_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path="/tmp/sagemaker_requirements_test.txt", include_local_workdir=True, pre_execution_commands=None, pre_execution_script_local_path=None, @@ -2039,7 +2066,7 @@ def test_start_with_torchrun_multi_node( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--distribution", "torchrun", "--run_in_context", @@ -2319,6 +2346,7 @@ def test_set_env_multi_node_multi_gpu_mpirun( assert env_file == expected_env +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -2333,7 +2361,10 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -2359,7 +2390,6 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) - local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -2373,7 +2403,7 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( ) mock_dependency_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path="/tmp/sagemaker_requirements_test.txt", include_local_workdir=True, pre_execution_commands=None, pre_execution_script_local_path=None, @@ -2426,7 +2456,7 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--distribution", "torchrun", "--user_nproc_per_node", @@ -2448,6 +2478,7 @@ def test_start_with_torchrun_single_node_with_nproc_per_node( ) +@patch("sagemaker.remote_function.job._ensure_sagemaker_dependency") @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) @patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") @patch( @@ -2462,7 +2493,10 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( mock_runtime_manager, mock_script_upload, mock_dependency_upload, + mock_ensure_sagemaker, ): + # Mock returns a fixed temp file path for tests without explicit dependencies + mock_ensure_sagemaker.return_value = "/tmp/sagemaker_requirements_test.txt" job_settings = _JobSettings( image_uri=IMAGE, @@ -2488,7 +2522,6 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) - local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() @@ -2502,7 +2535,7 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( ) mock_dependency_upload.assert_called_once_with( - local_dependencies_path=local_dependencies_path, + local_dependencies_path="/tmp/sagemaker_requirements_test.txt", include_local_workdir=True, pre_execution_commands=None, pre_execution_script_local_path=None, @@ -2555,7 +2588,7 @@ def test_start_with_mpirun_single_node_with_nproc_per_node( "--client_sagemaker_pysdk_version", mock_sagemaker_pysdk_version, "--dependency_settings", - '{"dependency_file": null}', + '{"dependency_file": "sagemaker_requirements_test.txt"}', "--distribution", "mpirun", "--user_nproc_per_node", From 947560310fad66d8174dcd855767538de2392dec Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 14 Jan 2026 16:42:48 -0800 Subject: [PATCH 04/10] Fix codestyle issues --- src/sagemaker/remote_function/job.py | 85 +++++++++--------- .../test_sagemaker_dependency_injection.py | 39 ++++---- .../test_ensure_sagemaker_dependency.py | 90 ++++++++++--------- 3 files changed, 113 insertions(+), 101 deletions(-) diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 43000287f4..0e456d3e91 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -1298,32 +1298,32 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): def _decrement_version(version_str: str) -> str: """Decrement a version string by one minor or patch version. - + Rules: - If patch version is 0 (e.g., 2.256.0), decrement minor: 2.256.0 -> 2.255.0 - If patch version is not 0 (e.g., 2.254.2), decrement patch: 2.254.2 -> 2.254.1 - + Args: version_str: Version string (e.g., "2.256.0") - + Returns: Decremented version string """ from packaging import version as pkg_version - + try: parsed = pkg_version.parse(version_str) major = parsed.major minor = parsed.minor patch = parsed.micro - + if patch == 0: # Decrement minor version minor = max(0, minor - 1) else: # Decrement patch version patch = max(0, patch - 1) - + return f"{major}.{minor}.{patch}" except Exception: return version_str @@ -1335,55 +1335,60 @@ def _resolve_version_from_specifier(specifier_str: str) -> str: Upper bounds take priority. If upper bound is <3.0.0, it's safe (V2 only). If no upper bound exists, it's safe (unbounded). If the decremented upper bound is less than a lower bound, use the lower bound. - + Args: specifier_str: Version specifier string (e.g., ">=2.256.0", "<2.256.0", "==2.255.0") - + Returns: The resolved version string to check, or None if safe """ import re from packaging import version as pkg_version - + # Handle exact version pinning (==) - match = re.search(r'==\s*([\d.]+)', specifier_str) + match = re.search(r"==\s*([\d.]+)", specifier_str) if match: return match.group(1) - + # Extract lower bounds for comparison lower_bounds = [] - for match in re.finditer(r'>=\s*([\d.]+)', specifier_str): + for match in re.finditer(r">=\s*([\d.]+)", specifier_str): lower_bounds.append(match.group(1)) - + # Handle upper bounds - find the most restrictive one upper_bounds = [] - + # Find all <= bounds - for match in re.finditer(r'<=\s*([\d.]+)', specifier_str): - upper_bounds.append(('<=', match.group(1))) - + for match in re.finditer(r"<=\s*([\d.]+)", specifier_str): + upper_bounds.append(("<=", match.group(1))) + # Find all < bounds - for match in re.finditer(r'<\s*([\d.]+)', specifier_str): - upper_bounds.append(('<', match.group(1))) - + for match in re.finditer(r"<\s*([\d.]+)", specifier_str): + upper_bounds.append(("<", match.group(1))) + if upper_bounds: # Sort by version to find the most restrictive (lowest) upper bound upper_bounds.sort(key=lambda x: pkg_version.parse(x[1])) operator, version = upper_bounds[0] - + # Special case: if upper bound is <3.0.0, it's safe (V2 only) try: parsed_upper = pkg_version.parse(version) - if operator == '<' and parsed_upper.major == 3 and parsed_upper.minor == 0 and parsed_upper.micro == 0: + if ( + operator == "<" + and parsed_upper.major == 3 + and parsed_upper.minor == 0 + and parsed_upper.micro == 0 + ): # <3.0.0 means V2 only, which is safe return None except Exception: pass - + resolved_version = version - if operator == '<': + if operator == "<": resolved_version = _decrement_version(version) - + # If we have a lower bound and the resolved version is less than it, use the lower bound if lower_bounds: try: @@ -1394,9 +1399,9 @@ def _resolve_version_from_specifier(specifier_str: str) -> str: resolved_version = lower_bound_str except Exception: pass - + return resolved_version - + # For lower bounds only (>=, >), we don't check return None @@ -1415,35 +1420,35 @@ def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None: """ import re from packaging import version as pkg_version - - match = re.search(r'sagemaker\s*(.+)$', sagemaker_requirement.strip(), re.IGNORECASE) + + match = re.search(r"sagemaker\s*(.+)$", sagemaker_requirement.strip(), re.IGNORECASE) if not match: return specifier_str = match.group(1).strip() - + # Resolve the version that would be installed resolved_version_str = _resolve_version_from_specifier(specifier_str) if not resolved_version_str: # No upper bound or exact version, so we can't determine if it's bad return - + try: resolved_version = pkg_version.parse(resolved_version_str) except Exception: return - + # Define HMAC thresholds for each major version v2_hmac_threshold = pkg_version.parse("2.256.0") v3_hmac_threshold = pkg_version.parse("3.2.0") - + # Check if the resolved version uses HMAC hashing uses_hmac = False if resolved_version.major == 2 and resolved_version < v2_hmac_threshold: uses_hmac = True elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold: uses_hmac = True - + if uses_hmac: raise ValueError( f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) " @@ -1453,7 +1458,6 @@ def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None: ) - def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str: """Ensure sagemaker>=2.256.0 is in the dependencies. @@ -1481,13 +1485,14 @@ def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str: if local_dependencies_path is None: # Create a temporary requirements.txt in the system temp directory - # This avoids overwriting any user files in their working directory fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_") - os.close(fd) # Close the file descriptor, we'll write to it ourselves + os.close(fd) with open(req_file, "w") as f: f.write(f"{SAGEMAKER_MIN_VERSION}\n") - logger.info("Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION) + logger.info( + "Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION + ) return req_file # If dependencies provided, ensure sagemaker is included @@ -1498,8 +1503,8 @@ def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str: # Check if sagemaker is already specified if "sagemaker" in content.lower(): # Extract the sagemaker requirement line for compatibility check - for line in content.split('\n'): - if 'sagemaker' in line.lower(): + for line in content.split("\n"): + if "sagemaker" in line.lower(): _check_sagemaker_version_compatibility(line.strip()) break else: diff --git a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py index 633f11a90d..62ee63d40a 100644 --- a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py +++ b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py @@ -11,12 +11,11 @@ # Skip decorator for AWS configuration skip_if_no_aws_region = pytest.mark.skipif( - not os.environ.get('AWS_DEFAULT_REGION'), - reason="AWS credentials not configured" + not os.environ.get('AWS_DEFAULT_REGION'), reason="AWS credentials not configured" ) # Add src to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src")) from sagemaker.remote_function import remote @@ -28,11 +27,12 @@ class TestRemoteFunctionDependencyInjection: @skip_if_no_aws_region def test_remote_function_without_dependencies(self): """Test remote function execution without explicit dependencies. - + This test verifies that when no dependencies are provided, the remote function still executes successfully because sagemaker>=2.256.0 is automatically injected. """ + @remote( instance_type="ml.m5.large", # No dependencies specified - sagemaker should be injected automatically @@ -40,10 +40,10 @@ def test_remote_function_without_dependencies(self): def simple_add(x, y): """Simple function that adds two numbers.""" return x + y - + # Execute the function result = simple_add(5, 3) - + # Verify result assert result == 8, f"Expected 8, got {result}" print("✓ Remote function without dependencies executed successfully") @@ -52,15 +52,15 @@ def simple_add(x, y): @skip_if_no_aws_region def test_remote_function_with_user_dependencies_no_sagemaker(self): """Test remote function with user dependencies but no sagemaker. - + This test verifies that when user provides dependencies without sagemaker, sagemaker>=2.256.0 is automatically appended. """ # Create a temporary requirements.txt without sagemaker - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: f.write("numpy>=1.20.0\npandas>=1.3.0\n") req_file = f.name - + try: @remote( instance_type="ml.m5.large", @@ -69,11 +69,12 @@ def test_remote_function_with_user_dependencies_no_sagemaker(self): def compute_with_numpy(x): """Function that uses numpy.""" import numpy as np + return np.array([x, x*2, x*3]).sum() - + # Execute the function result = compute_with_numpy(5) - + # Verify result (5 + 10 + 15 = 30) assert result == 30, f"Expected 30, got {result}" print("✓ Remote function with user dependencies executed successfully") @@ -88,10 +89,11 @@ class TestRemoteFunctionVersionCompatibility: @skip_if_no_aws_region def test_deserialization_with_injected_sagemaker(self): """Test that deserialization works with injected sagemaker dependency. - + This test verifies that the remote environment can properly deserialize functions when sagemaker>=2.256.0 is available. """ + @remote( instance_type="ml.m5.large", ) @@ -99,11 +101,11 @@ def complex_computation(data): """Function that performs complex computation.""" result = sum(data) * len(data) return result - + # Execute with various data types test_data = [1, 2, 3, 4, 5] result = complex_computation(test_data) - + # Verify result (sum=15, len=5, 15*5=75) assert result == 75, f"Expected 75, got {result}" print("✓ Deserialization with injected sagemaker works correctly") @@ -112,22 +114,23 @@ def complex_computation(data): @skip_if_no_aws_region def test_multiple_remote_functions_with_dependencies(self): """Test multiple remote functions with different dependency configurations. - + This test verifies that the dependency injection works correctly when multiple remote functions are defined and executed. """ + @remote(instance_type="ml.m5.large") def func1(x): return x + 1 - + @remote(instance_type="ml.m5.large") def func2(x): return x * 2 - + # Execute both functions result1 = func1(5) result2 = func2(5) - + assert result1 == 6, f"func1: Expected 6, got {result1}" assert result2 == 10, f"func2: Expected 10, got {result2}" print("✓ Multiple remote functions with dependencies executed successfully") diff --git a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py index 93a4e718db..b06984903b 100644 --- a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py +++ b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py @@ -11,9 +11,13 @@ # Add src to path import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src')) -from sagemaker.remote_function.job import _ensure_sagemaker_dependency, _check_sagemaker_version_compatibility +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src")) + +from sagemaker.remote_function.job import ( + _ensure_sagemaker_dependency, + _check_sagemaker_version_compatibility, +) class TestEnsureSagemakerDependency(unittest.TestCase): @@ -22,51 +26,51 @@ class TestEnsureSagemakerDependency(unittest.TestCase): def test_no_dependencies_creates_temp_requirements_file(self): """Test that a temp requirements.txt is created when no dependencies provided.""" result = _ensure_sagemaker_dependency(None) - + # Verify file was created self.assertTrue(os.path.exists(result), f"Requirements file not created at {result}") - + # Verify it's in temp directory self.assertIn(tempfile.gettempdir(), result) - + # Verify content with open(result, "r") as f: content = f.read() self.assertIn("sagemaker>=2.256.0,<3.0.0", content) - + # Cleanup os.remove(result) def test_no_dependencies_file_has_correct_format(self): """Test that created requirements.txt has correct format.""" result = _ensure_sagemaker_dependency(None) - + with open(result, "r") as f: lines = f.readlines() - + # Should have exactly one line with sagemaker dependency self.assertEqual(len(lines), 1) self.assertEqual(lines[0].strip(), "sagemaker>=2.256.0,<3.0.0") - + # Cleanup os.remove(result) def test_appends_sagemaker_to_existing_requirements(self): """Test that sagemaker is appended to existing requirements.txt.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: f.write("numpy>=1.20.0\npandas>=1.3.0\n") temp_file = f.name - + try: result = _ensure_sagemaker_dependency(temp_file) - + # Should return the same file self.assertEqual(result, temp_file) - + # Verify content with open(result, "r") as f: content = f.read() - + self.assertIn("numpy>=1.20.0", content) self.assertIn("pandas>=1.3.0", content) self.assertIn("sagemaker>=2.256.0,<3.0.0", content) @@ -75,20 +79,20 @@ def test_appends_sagemaker_to_existing_requirements(self): def test_does_not_duplicate_sagemaker_if_already_present(self): """Test that sagemaker is not duplicated if already in requirements.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: f.write("numpy>=1.20.0\nsagemaker>=2.256.0,<3.0.0\npandas>=1.3.0\n") temp_file = f.name - + try: result = _ensure_sagemaker_dependency(temp_file) - + with open(result, "r") as f: content = f.read() - + # Count occurrences of sagemaker sagemaker_count = content.lower().count("sagemaker") self.assertEqual(sagemaker_count, 1, "sagemaker should appear exactly once") - + # Verify user's version is preserved self.assertIn("sagemaker>=2.256.0,<3.0.0", content) finally: @@ -96,16 +100,16 @@ def test_does_not_duplicate_sagemaker_if_already_present(self): def test_preserves_user_dependencies(self): """Test that user's existing dependencies are preserved.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: f.write("torch>=1.9.0\ntorchvision>=0.10.0\nscikit-learn>=0.24.0\n") temp_file = f.name - + try: result = _ensure_sagemaker_dependency(temp_file) - + with open(result, "r") as f: content = f.read() - + # All user dependencies should be present self.assertIn("torch>=1.9.0", content) self.assertIn("torchvision>=0.10.0", content) @@ -116,56 +120,56 @@ def test_preserves_user_dependencies(self): def test_handles_yml_files_gracefully(self): """Test that yml files are returned unchanged.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f: f.write("name: test-env\nchannels:\n - conda-forge\ndependencies:\n - numpy\n") temp_file = f.name - + try: result = _ensure_sagemaker_dependency(temp_file) - + # Should return the same file self.assertEqual(result, temp_file) - + # Content should be unchanged (yml files are not modified) with open(result, "r") as f: content = f.read() - + self.assertNotIn("sagemaker", content.lower()) finally: os.remove(temp_file) def test_handles_yaml_files_gracefully(self): """Test that yaml files are returned unchanged.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: f.write("name: test-env\nchannels:\n - conda-forge\n") temp_file = f.name - + try: result = _ensure_sagemaker_dependency(temp_file) - + # Should return the same file self.assertEqual(result, temp_file) - + # Content should be unchanged with open(result, "r") as f: content = f.read() - + self.assertNotIn("sagemaker", content.lower()) finally: os.remove(temp_file) def test_case_insensitive_sagemaker_detection(self): """Test that sagemaker detection is case-insensitive.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: f.write("numpy>=1.20.0\nSAGEMAKER>=2.256.0,<3.0.0\n") temp_file = f.name - + try: result = _ensure_sagemaker_dependency(temp_file) - + with open(result, "r") as f: content = f.read() - + # Should not duplicate even with different case sagemaker_count = content.lower().count("sagemaker") self.assertEqual(sagemaker_count, 1) @@ -175,28 +179,28 @@ def test_case_insensitive_sagemaker_detection(self): def test_temp_file_location(self): """Test that temp file is created in system temp directory.""" result = _ensure_sagemaker_dependency(None) - + # Should be in system temp directory temp_dir = tempfile.gettempdir() self.assertTrue(result.startswith(temp_dir)) - + # Should have correct prefix self.assertIn("sagemaker_requirements_", result) - + # Cleanup os.remove(result) def test_version_constraint_format(self): """Test that version constraint has correct format.""" result = _ensure_sagemaker_dependency(None) - + with open(result, "r") as f: content = f.read().strip() - + # Should have both lower and upper bounds self.assertIn(">=2.256.0", content) self.assertIn("<3.0.0", content) - + # Cleanup os.remove(result) From bd8bf0eeb2e0ea06fd53413819a1463d9c189b2b Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 14 Jan 2026 16:58:50 -0800 Subject: [PATCH 05/10] More codestyle fixes --- src/sagemaker/remote_function/job.py | 8 ++++---- .../test_sagemaker_dependency_injection.py | 4 ++-- .../remote_function/test_ensure_sagemaker_dependency.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 0e456d3e91..e1d4018b76 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -1331,7 +1331,7 @@ def _decrement_version(version_str: str) -> str: def _resolve_version_from_specifier(specifier_str: str) -> str: """Resolve the version to check based on upper bounds. - + Upper bounds take priority. If upper bound is <3.0.0, it's safe (V2 only). If no upper bound exists, it's safe (unbounded). If the decremented upper bound is less than a lower bound, use the lower bound. @@ -1375,9 +1375,9 @@ def _resolve_version_from_specifier(specifier_str: str) -> str: try: parsed_upper = pkg_version.parse(version) if ( - operator == "<" - and parsed_upper.major == 3 - and parsed_upper.minor == 0 + operator == "<" + and parsed_upper.major == 3 + and parsed_upper.minor == 0 and parsed_upper.micro == 0 ): # <3.0.0 means V2 only, which is safe diff --git a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py index 62ee63d40a..677b7b7dd2 100644 --- a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py +++ b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py @@ -11,7 +11,7 @@ # Skip decorator for AWS configuration skip_if_no_aws_region = pytest.mark.skipif( - not os.environ.get('AWS_DEFAULT_REGION'), reason="AWS credentials not configured" + not os.environ.get("AWS_DEFAULT_REGION"), reason="AWS credentials not configured" ) # Add src to path @@ -70,7 +70,7 @@ def compute_with_numpy(x): """Function that uses numpy.""" import numpy as np - return np.array([x, x*2, x*3]).sum() + return np.array([x, x * 2, x * 3]).sum() # Execute the function result = compute_with_numpy(5) diff --git a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py index b06984903b..945ac7e0ea 100644 --- a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py +++ b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py @@ -15,7 +15,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src")) from sagemaker.remote_function.job import ( - _ensure_sagemaker_dependency, + _ensure_sagemaker_dependency, _check_sagemaker_version_compatibility, ) From 638ed5e749ecfbae064b60ec46f856b9dc820f37 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 14 Jan 2026 17:04:27 -0800 Subject: [PATCH 06/10] Fixing one more codestyle issue --- .../remote_function/test_sagemaker_dependency_injection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py index 677b7b7dd2..6178d7d57a 100644 --- a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py +++ b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py @@ -62,6 +62,7 @@ def test_remote_function_with_user_dependencies_no_sagemaker(self): req_file = f.name try: + @remote( instance_type="ml.m5.large", dependencies=req_file, From e10f974f92452e8980c2430a755f2f5d73cd7b98 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 14 Jan 2026 17:10:00 -0800 Subject: [PATCH 07/10] Fixing flake errors --- .../test_sagemaker_dependency_injection.py | 12 +++++++----- .../test_ensure_sagemaker_dependency.py | 5 ++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py index 6178d7d57a..284aea4c14 100644 --- a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py +++ b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py @@ -3,22 +3,24 @@ These tests verify that the sagemaker>=2.256.0 dependency is properly injected into remote function jobs, preventing version mismatch issues. """ +from __future__ import absolute_import import os import sys import tempfile -import pytest -# Skip decorator for AWS configuration -skip_if_no_aws_region = pytest.mark.skipif( - not os.environ.get("AWS_DEFAULT_REGION"), reason="AWS credentials not configured" -) +import pytest # Add src to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src")) from sagemaker.remote_function import remote +# Skip decorator for AWS configuration +skip_if_no_aws_region = pytest.mark.skipif( + not os.environ.get("AWS_DEFAULT_REGION"), reason="AWS credentials not configured" +) + class TestRemoteFunctionDependencyInjection: """Integration tests for dependency injection in remote functions.""" diff --git a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py index 945ac7e0ea..6317a620e8 100644 --- a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py +++ b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py @@ -3,15 +3,14 @@ Tests the logic that ensures sagemaker>=2.256.0 is included in remote function dependencies to prevent version mismatch issues with HMAC key integrity checks. """ +from __future__ import absolute_import import os +import sys import tempfile import unittest -from unittest.mock import patch, MagicMock # Add src to path -import sys - sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src")) from sagemaker.remote_function.job import ( From a36ee2c32fe1a11d7095a39ffa9277bfd73180c2 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 14 Jan 2026 17:18:24 -0800 Subject: [PATCH 08/10] More codestyle fixes --- .../remote_function/test_sagemaker_dependency_injection.py | 1 + .../remote_function/test_ensure_sagemaker_dependency.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py index 284aea4c14..a3ad1998cc 100644 --- a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py +++ b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py @@ -3,6 +3,7 @@ These tests verify that the sagemaker>=2.256.0 dependency is properly injected into remote function jobs, preventing version mismatch issues. """ + from __future__ import absolute_import import os diff --git a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py index 6317a620e8..ea36111649 100644 --- a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py +++ b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py @@ -3,6 +3,7 @@ Tests the logic that ensures sagemaker>=2.256.0 is included in remote function dependencies to prevent version mismatch issues with HMAC key integrity checks. """ + from __future__ import absolute_import import os From fe269f99d9c8adb9c141fe7ea848933d80a4c30e Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 14 Jan 2026 17:25:14 -0800 Subject: [PATCH 09/10] More flake test fixes --- .../remote_function/test_sagemaker_dependency_injection.py | 5 ++--- .../remote_function/test_ensure_sagemaker_dependency.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py index a3ad1998cc..6d9824c8b7 100644 --- a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py +++ b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py @@ -8,14 +8,13 @@ import os import sys -import tempfile import pytest -# Add src to path +# Add src to path before importing sagemaker sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src")) -from sagemaker.remote_function import remote +from sagemaker.remote_function import remote # noqa: E402 # Skip decorator for AWS configuration skip_if_no_aws_region = pytest.mark.skipif( diff --git a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py index ea36111649..ac7f11e124 100644 --- a/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py +++ b/tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py @@ -11,10 +11,10 @@ import tempfile import unittest -# Add src to path +# Add src to path before importing sagemaker sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src")) -from sagemaker.remote_function.job import ( +from sagemaker.remote_function.job import ( # noqa: E402 _ensure_sagemaker_dependency, _check_sagemaker_version_compatibility, ) From df9714fcb92010f239475aeb8bc9443e92149cbd Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 14 Jan 2026 17:30:44 -0800 Subject: [PATCH 10/10] Fixing one more flake error --- .../remote_function/test_sagemaker_dependency_injection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py index 6d9824c8b7..bc5e7870a9 100644 --- a/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py +++ b/tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py @@ -8,6 +8,7 @@ import os import sys +import tempfile import pytest