Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 224 additions & 0 deletions src/sagemaker/remote_function/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,6 +1235,11 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):

local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies)

# Ensure sagemaker dependency is included to prevent version mismatch issues
# Resolves issue where computing hash for integrity check changed in 2.256.0
local_dependencies_path = _ensure_sagemaker_dependency(local_dependencies_path)
job_settings.dependencies = local_dependencies_path

if step_compilation_context:
with _tmpdir() as tmp_dir:
script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts(
Expand Down Expand Up @@ -1291,6 +1296,225 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
return input_data_config


def _decrement_version(version_str: str) -> str:
"""Decrement a version string by one minor or patch version.

Rules:
- If patch version is 0 (e.g., 2.256.0), decrement minor: 2.256.0 -> 2.255.0
- If patch version is not 0 (e.g., 2.254.2), decrement patch: 2.254.2 -> 2.254.1

Args:
version_str: Version string (e.g., "2.256.0")

Returns:
Decremented version string
"""
from packaging import version as pkg_version

try:
parsed = pkg_version.parse(version_str)
major = parsed.major
minor = parsed.minor
patch = parsed.micro

if patch == 0:
# Decrement minor version
minor = max(0, minor - 1)
else:
# Decrement patch version
patch = max(0, patch - 1)

return f"{major}.{minor}.{patch}"
except Exception:
return version_str


def _resolve_version_from_specifier(specifier_str: str) -> str:
"""Resolve the version to check based on upper bounds.

Upper bounds take priority. If upper bound is <3.0.0, it's safe (V2 only).
If no upper bound exists, it's safe (unbounded).
If the decremented upper bound is less than a lower bound, use the lower bound.

Args:
specifier_str: Version specifier string (e.g., ">=2.256.0", "<2.256.0", "==2.255.0")

Returns:
The resolved version string to check, or None if safe
"""
import re
from packaging import version as pkg_version

# Handle exact version pinning (==)
match = re.search(r"==\s*([\d.]+)", specifier_str)
if match:
return match.group(1)

# Extract lower bounds for comparison
lower_bounds = []
for match in re.finditer(r">=\s*([\d.]+)", specifier_str):
lower_bounds.append(match.group(1))

# Handle upper bounds - find the most restrictive one
upper_bounds = []

# Find all <= bounds
for match in re.finditer(r"<=\s*([\d.]+)", specifier_str):
upper_bounds.append(("<=", match.group(1)))

# Find all < bounds
for match in re.finditer(r"<\s*([\d.]+)", specifier_str):
upper_bounds.append(("<", match.group(1)))

if upper_bounds:
# Sort by version to find the most restrictive (lowest) upper bound
upper_bounds.sort(key=lambda x: pkg_version.parse(x[1]))
operator, version = upper_bounds[0]

# Special case: if upper bound is <3.0.0, it's safe (V2 only)
try:
parsed_upper = pkg_version.parse(version)
if (
operator == "<"
and parsed_upper.major == 3
and parsed_upper.minor == 0
and parsed_upper.micro == 0
):
# <3.0.0 means V2 only, which is safe
return None
except Exception:
pass

resolved_version = version
if operator == "<":
resolved_version = _decrement_version(version)

# If we have a lower bound and the resolved version is less than it, use the lower bound
if lower_bounds:
try:
resolved_parsed = pkg_version.parse(resolved_version)
for lower_bound_str in lower_bounds:
lower_parsed = pkg_version.parse(lower_bound_str)
if resolved_parsed < lower_parsed:
resolved_version = lower_bound_str
except Exception:
pass

return resolved_version

# For lower bounds only (>=, >), we don't check
return None


def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
"""Check if the sagemaker version requirement uses incompatible hashing.

Raises ValueError if the requirement would install a version that uses HMAC hashing
(which is incompatible with the current SHA256-based integrity checks).

Args:
sagemaker_requirement: The sagemaker requirement string (e.g., "sagemaker>=2.200.0")

Raises:
ValueError: If the requirement would install a version using HMAC hashing
"""
import re
from packaging import version as pkg_version

match = re.search(r"sagemaker\s*(.+)$", sagemaker_requirement.strip(), re.IGNORECASE)
if not match:
return

specifier_str = match.group(1).strip()

# Resolve the version that would be installed
resolved_version_str = _resolve_version_from_specifier(specifier_str)
if not resolved_version_str:
# No upper bound or exact version, so we can't determine if it's bad
return

try:
resolved_version = pkg_version.parse(resolved_version_str)
except Exception:
return

# Define HMAC thresholds for each major version
v2_hmac_threshold = pkg_version.parse("2.256.0")
v3_hmac_threshold = pkg_version.parse("3.2.0")

# Check if the resolved version uses HMAC hashing
uses_hmac = False
if resolved_version.major == 2 and resolved_version < v2_hmac_threshold:
uses_hmac = True
elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold:
uses_hmac = True

if uses_hmac:
raise ValueError(
f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) "
f"could install a version using HMAC-based integrity checks which are incompatible "
f"with the current SHA256-based integrity checks. Please update to "
f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)."
)


def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
"""Ensure sagemaker>=2.256.0 is in the dependencies.

This function ensures that the remote environment has a compatible version of sagemaker
that includes the fix for the HMAC key security issue. Versions < 2.256.0 use HMAC-based
integrity checks which require the REMOTE_FUNCTION_SECRET_KEY environment variable.
Versions >= 2.256.0 use SHA256-based integrity checks which are secure and don't require
the secret key.

If no dependencies are provided, creates a temporary requirements.txt with sagemaker.
If dependencies are provided, appends sagemaker if not already present.

Args:
local_dependencies_path: Path to user's dependencies file or None

Returns:
Path to the dependencies file (created or modified)

Raises:
ValueError: If user has pinned sagemaker to a version using HMAC hashing
"""
import tempfile

SAGEMAKER_MIN_VERSION = "sagemaker>=2.256.0,<3.0.0"

if local_dependencies_path is None:
# Create a temporary requirements.txt in the system temp directory
fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_")
os.close(fd)

with open(req_file, "w") as f:
f.write(f"{SAGEMAKER_MIN_VERSION}\n")
logger.info(
"Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION
)
return req_file

# If dependencies provided, ensure sagemaker is included
if local_dependencies_path.endswith(".txt"):
with open(local_dependencies_path, "r") as f:
content = f.read()

# Check if sagemaker is already specified
if "sagemaker" in content.lower():
# Extract the sagemaker requirement line for compatibility check
for line in content.split("\n"):
if "sagemaker" in line.lower():
_check_sagemaker_version_compatibility(line.strip())
break
else:
with open(local_dependencies_path, "a") as f:
f.write(f"\n{SAGEMAKER_MIN_VERSION}\n")
logger.info("Appended %s to requirements.txt", SAGEMAKER_MIN_VERSION)

return local_dependencies_path


def _prepare_dependencies_and_pre_execution_scripts(
local_dependencies_path: str,
pre_execution_commands: List[str],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Integration tests for sagemaker dependency injection in remote functions.

These tests verify that the sagemaker>=2.256.0 dependency is properly injected
into remote function jobs, preventing version mismatch issues.
"""

from __future__ import absolute_import

import os
import sys
import tempfile

import pytest

# Add src to path before importing sagemaker
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src"))

from sagemaker.remote_function import remote # noqa: E402

# Skip decorator for AWS configuration
skip_if_no_aws_region = pytest.mark.skipif(
not os.environ.get("AWS_DEFAULT_REGION"), reason="AWS credentials not configured"
)


class TestRemoteFunctionDependencyInjection:
"""Integration tests for dependency injection in remote functions."""

@pytest.mark.integ
@skip_if_no_aws_region
def test_remote_function_without_dependencies(self):
"""Test remote function execution without explicit dependencies.

This test verifies that when no dependencies are provided, the remote
function still executes successfully because sagemaker>=2.256.0 is
automatically injected.
"""

@remote(
instance_type="ml.m5.large",
# No dependencies specified - sagemaker should be injected automatically
)
def simple_add(x, y):
"""Simple function that adds two numbers."""
return x + y

# Execute the function
result = simple_add(5, 3)

# Verify result
assert result == 8, f"Expected 8, got {result}"
print("✓ Remote function without dependencies executed successfully")

@pytest.mark.integ
@skip_if_no_aws_region
def test_remote_function_with_user_dependencies_no_sagemaker(self):
"""Test remote function with user dependencies but no sagemaker.

This test verifies that when user provides dependencies without sagemaker,
sagemaker>=2.256.0 is automatically appended.
"""
# Create a temporary requirements.txt without sagemaker
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
f.write("numpy>=1.20.0\npandas>=1.3.0\n")
req_file = f.name

try:

@remote(
instance_type="ml.m5.large",
dependencies=req_file,
)
def compute_with_numpy(x):
"""Function that uses numpy."""
import numpy as np

return np.array([x, x * 2, x * 3]).sum()

# Execute the function
result = compute_with_numpy(5)

# Verify result (5 + 10 + 15 = 30)
assert result == 30, f"Expected 30, got {result}"
print("✓ Remote function with user dependencies executed successfully")
finally:
os.remove(req_file)


class TestRemoteFunctionVersionCompatibility:
"""Tests for version compatibility between local and remote environments."""

@pytest.mark.integ
@skip_if_no_aws_region
def test_deserialization_with_injected_sagemaker(self):
"""Test that deserialization works with injected sagemaker dependency.

This test verifies that the remote environment can properly deserialize
functions when sagemaker>=2.256.0 is available.
"""

@remote(
instance_type="ml.m5.large",
)
def complex_computation(data):
"""Function that performs complex computation."""
result = sum(data) * len(data)
return result

# Execute with various data types
test_data = [1, 2, 3, 4, 5]
result = complex_computation(test_data)

# Verify result (sum=15, len=5, 15*5=75)
assert result == 75, f"Expected 75, got {result}"
print("✓ Deserialization with injected sagemaker works correctly")

@pytest.mark.integ
@skip_if_no_aws_region
def test_multiple_remote_functions_with_dependencies(self):
"""Test multiple remote functions with different dependency configurations.

This test verifies that the dependency injection works correctly
when multiple remote functions are defined and executed.
"""

@remote(instance_type="ml.m5.large")
def func1(x):
return x + 1

@remote(instance_type="ml.m5.large")
def func2(x):
return x * 2

# Execute both functions
result1 = func1(5)
result2 = func2(5)

assert result1 == 6, f"func1: Expected 6, got {result1}"
assert result2 == 10, f"func2: Expected 10, got {result2}"
print("✓ Multiple remote functions with dependencies executed successfully")


if __name__ == "__main__":
pytest.main([__file__, "-v", "-m", "integ"])
Loading