diff --git a/.github/hooks/pre-commit b/.github/hooks/pre-commit new file mode 100755 index 00000000..a9a1af02 --- /dev/null +++ b/.github/hooks/pre-commit @@ -0,0 +1,9 @@ +#!/bin/sh + +if hatch fmt --check; then + echo "Hatch fmt check passed!" +else + hatch fmt + echo "Error: hatch fmt modified your files. Please re-stage and commit again." + exit 1 +fi \ No newline at end of file diff --git a/.github/scripts/lintcommit.py b/.github/scripts/lintcommit.py index f24ab886..255ea0ec 100644 --- a/.github/scripts/lintcommit.py +++ b/.github/scripts/lintcommit.py @@ -164,7 +164,8 @@ def lint_range(git_range: str, *, skip_dirty_check: bool = False) -> LintResult: status = subprocess.run( ["git", "status", "--porcelain"], capture_output=True, - text=True, check=False, + text=True, + check=False, ) if status.stdout.strip(): return LintResult( @@ -178,7 +179,8 @@ def lint_range(git_range: str, *, skip_dirty_check: bool = False) -> LintResult: result = subprocess.run( ["git", "log", "--no-merges", git_range, "-z", "--format=%H%n%B"], capture_output=True, - text=True, check=False, + text=True, + check=False, ) if result.returncode != 0: return LintResult(git_error=result.stderr.strip()) diff --git a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json index fb9ab785..fe8ccfd7 100644 --- a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json +++ b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json @@ -602,6 +602,17 @@ "ExecutionTimeout": 300 }, "path": "./src/parallel/parallel_with_named_branches.py" + }, + { + "name": "Plugin", + "description": "Test plugin", + "handler": "execution_with_plugin.handler", + "integration": true, + "durableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + }, + "path": "./src/plugin/execution_with_plugin.py" } ] } diff --git 
a/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_plugin.py b/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_plugin.py new file mode 100644 index 00000000..d8858baa --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_plugin.py @@ -0,0 +1,63 @@ +"""Demonstrates a durable execution handler instrumented with a plugin.""" + +import logging +from typing import Any + +from aws_durable_execution_sdk_python import StepContext +from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_step, + durable_with_child_context, +) +from aws_durable_execution_sdk_python.execution import durable_execution +from aws_durable_execution_sdk_python.plugin import ( + DurableInstrumentationPlugin, +) + + +class MyPlugin(DurableInstrumentationPlugin): + logger = logging.getLogger("MyPlugin") + + def on_operation_start(self, info): + self.logger.info(f"Operation started: {info}") + + def on_operation_end(self, info): + self.logger.info(f"Operation ended: {info}") + + def on_invocation_start(self, info): + self.logger.info(f"Invocation started: {info}") + + def on_invocation_end(self, info): + self.logger.info(f"Invocation ended: {info}") + + def on_user_function_start(self, info) -> None: + self.logger.info(f"User function started: {info}") + + def on_user_function_end(self, info) -> None: + self.logger.info(f"User function ended: {info}") + + +@durable_step +def add_numbers(_step_context: StepContext, a: int, b: int) -> int: + return a + b + + +@durable_with_child_context +def add_numbers_in_child(child_context: DurableContext, a: int, b: int): + result: int = child_context.step( + add_numbers(a, b), + name="add-a-and-b", + ) + return result + + +@durable_execution(plugins=[MyPlugin()]) +def handler(_event: Any, context: DurableContext) -> int: + result: int = context.run_in_child_context( + add_numbers_in_child(6, 4), + name="add-6-and-4", + ) + return context.step( + 
add_numbers(result, 2), + name="add-result-to-2", + ) diff --git a/packages/aws-durable-execution-sdk-python-examples/template.yaml b/packages/aws-durable-execution-sdk-python-examples/template.yaml index 2854e729..bf91637f 100644 --- a/packages/aws-durable-execution-sdk-python-examples/template.yaml +++ b/packages/aws-durable-execution-sdk-python-examples/template.yaml @@ -977,6 +977,24 @@ "ExecutionTimeout": 300 } } + }, + "ExecutionWithPlugin": { + "Type": "AWS::Serverless::Function", + "Properties": { + "CodeUri": "build/", + "Handler": "execution_with_plugin.handler", + "Description": "Test plugin", + "Role": { + "Fn::GetAtt": [ + "DurableFunctionRole", + "Arn" + ] + }, + "DurableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + } + } } } } \ No newline at end of file diff --git a/packages/aws-durable-execution-sdk-python-examples/test/plugin/test_plugin.py b/packages/aws-durable-execution-sdk-python-examples/test/plugin/test_plugin.py new file mode 100644 index 00000000..5e21ba6e --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-examples/test/plugin/test_plugin.py @@ -0,0 +1,24 @@ +"""Tests for the plugin example.""" + +import pytest +from aws_durable_execution_sdk_python.execution import InvocationStatus + +from src.plugin import execution_with_plugin +from test.conftest import deserialize_operation_payload + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=execution_with_plugin.handler, + lambda_function_name="Plugin", +) +def test_plugin(durable_runner): + """Test that the plugin example handler runs to completion and returns 12.""" + with durable_runner: + result = durable_runner.run(input="{}", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + assert deserialize_operation_payload(result.result) == 12 + + step_result = result.get_step("add-result-to-2") + assert deserialize_operation_payload(step_result.result) == 12 diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py 
b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py index 3a7ab136..61bdbb0d 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/concurrency/executor.py @@ -30,7 +30,7 @@ TimedSuspendExecution, ) from aws_durable_execution_sdk_python.identifier import OperationIdentifier -from aws_durable_execution_sdk_python.lambda_service import ErrorObject +from aws_durable_execution_sdk_python.lambda_service import ErrorObject, OperationType from aws_durable_execution_sdk_python.operation.child import child_handler @@ -428,9 +428,10 @@ def _execute_item_in_child_context( # For FLAT `child_handler` skips checkpoints, so not used. # Construct it unconditionally to keep the call simple. operation_identifier = OperationIdentifier( - operation_id, - executor_context._parent_id, # noqa: SLF001 - name, + operation_id=operation_id, + sub_type=self.sub_type_iteration, + parent_id=executor_context._parent_id, # noqa: SLF001 + name=name, ) def run_in_child_handler() -> ResultType: diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py index 6691f2ab..00e575d2 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py @@ -23,7 +23,10 @@ ValidationError, ) from aws_durable_execution_sdk_python.identifier import OperationIdentifier -from aws_durable_execution_sdk_python.lambda_service import OperationSubType +from aws_durable_execution_sdk_python.lambda_service import ( + OperationSubType, + OperationType, +) from aws_durable_execution_sdk_python.logger import Logger, LogInfo from 
aws_durable_execution_sdk_python.operation.callback import ( CallbackOperationExecutor, @@ -443,6 +446,7 @@ def create_callback( state=self.state, operation_identifier=OperationIdentifier( operation_id=operation_id, + sub_type=OperationSubType.CALLBACK, parent_id=self._parent_id, name=name, ), @@ -485,6 +489,7 @@ def invoke( state=self.state, operation_identifier=OperationIdentifier( operation_id=operation_id, + sub_type=OperationSubType.CHAINED_INVOKE, parent_id=self._parent_id, name=name, ), @@ -507,6 +512,7 @@ def map( operation_id = self._create_step_id() operation_identifier = OperationIdentifier( operation_id=operation_id, + sub_type=OperationSubType.MAP, parent_id=self._parent_id, name=map_name, ) @@ -553,7 +559,10 @@ def parallel( operation_id = self._create_step_id() parallel_context = self.create_child_context(operation_id=operation_id) operation_identifier = OperationIdentifier( - operation_id=operation_id, parent_id=self._parent_id, name=name + operation_id=operation_id, + sub_type=OperationSubType.PARALLEL, + parent_id=self._parent_id, + name=name, ) def parallel_in_child_context() -> BatchResult[T]: @@ -606,6 +615,11 @@ def run_in_child_context( step_name: str | None = self._resolve_step_name(name, func) # _create_step_id() is thread-safe. 
rest of method is safe, since using local copy of parent id operation_id = self._create_step_id() + sub_type = ( + config.sub_type + if config and config.sub_type + else OperationSubType.RUN_IN_CHILD_CONTEXT + ) is_virtual: bool = config.is_virtual if config else False @@ -621,6 +635,7 @@ def callable_with_child_context(): state=self.state, operation_identifier=OperationIdentifier( operation_id=operation_id, + sub_type=sub_type, parent_id=self._parent_id, name=step_name, ), @@ -646,6 +661,7 @@ def step( state=self.state, operation_identifier=OperationIdentifier( operation_id=operation_id, + sub_type=OperationSubType.STEP, parent_id=self._parent_id, name=step_name, ), @@ -673,6 +689,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None: state=self.state, operation_identifier=OperationIdentifier( operation_id=operation_id, + sub_type=OperationSubType.WAIT, parent_id=self._parent_id, name=name, ), @@ -728,6 +745,7 @@ def wait_for_condition( state=self.state, operation_identifier=OperationIdentifier( operation_id=operation_id, + sub_type=OperationSubType.WAIT_FOR_CONDITION, parent_id=self._parent_id, name=name, ), diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py index df535b41..afb710e9 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py @@ -4,11 +4,12 @@ import functools import json import logging +import warnings from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from enum import Enum from typing import TYPE_CHECKING, Any + from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.exceptions import ( BackgroundThreadError, @@ -26,6 +27,12 @@ Operation, OperationType, 
OperationUpdate, + InvocationStatus, + DurableExecutionInvocationOutput, +) +from aws_durable_execution_sdk_python.plugin import ( + DurableInstrumentationPlugin, + PluginExecutor, ) from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus @@ -149,62 +156,6 @@ def from_durable_execution_invocation_input( ) -class InvocationStatus(Enum): - SUCCEEDED = "SUCCEEDED" - FAILED = "FAILED" - PENDING = "PENDING" - - -@dataclass(frozen=True) -class DurableExecutionInvocationOutput: - """Representation the DurableExecutionInvocationOutput. This is what the Durable lambda handler returns. - - If the execution has been already completed via an update to the EXECUTION operation via CheckpointDurableExecution, - payload must be empty for SUCCEEDED/FAILED status. - """ - - status: InvocationStatus - result: str | None = None - error: ErrorObject | None = None - - @classmethod - def from_dict( - cls, data: MutableMapping[str, Any] - ) -> DurableExecutionInvocationOutput: - """Create an instance from a dictionary. - - Args: - data: Dictionary with camelCase keys matching the original structure - - Returns: - A DurableExecutionInvocationOutput instance - """ - status = InvocationStatus(data.get("Status")) - error = ErrorObject.from_dict(data["Error"]) if data.get("Error") else None - return cls(status=status, result=data.get("Result"), error=error) - - def to_dict(self) -> MutableMapping[str, Any]: - """Convert to a dictionary with the original field names. 
- - Returns: - Dictionary with the original camelCase keys - """ - result: MutableMapping[str, Any] = {"Status": self.status.value} - - if self.result is not None: - # large payloads return "", because checkpointed already - result["Result"] = self.result - if self.error: - result["Error"] = self.error.to_dict() - - return result - - @classmethod - def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput: - """Create a succeeded invocation output.""" - return cls(status=InvocationStatus.SUCCEEDED, result=result) - - # endregion Invocation models @@ -212,14 +163,36 @@ def durable_execution( func: Callable[[Any, DurableContext], Any] | None = None, *, boto3_client: Boto3LambdaClient | None = None, + plugins: list[DurableInstrumentationPlugin] | None = None, ) -> Callable[[Any, LambdaContext], Any]: + """ + Decorator to create a durable execution handler. + + Args: + func: The user function to decorate + boto3_client: Optional boto3 Lambda client to use + plugins: Optional list of plugins to use (EXPERIMENTAL: This + feature has known issues and this parameter may change or be removed.) 
+ """ # Decorator called with parameters if func is None: logger.debug("Decorator called with parameters") - return functools.partial(durable_execution, boto3_client=boto3_client) + return functools.partial( + durable_execution, boto3_client=boto3_client, plugins=plugins + ) logger.debug("Starting durable execution handler...") + if plugins: + warnings.warn( + "The 'plugins' parameter is provisional and may be altered or removed.", + category=FutureWarning, + stacklevel=2, # point the warning to the caller of durable_execution + ) + + plugin_executor = PluginExecutor(plugins) + + @plugin_executor.handle_durable_output def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: invocation_input: DurableExecutionInvocationInput service_client: DurableServiceClient @@ -255,6 +228,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: operations={}, service_client=service_client, replay_status=ReplayStatus.NEW, + plugin_executor=plugin_executor, ) try: @@ -306,6 +280,13 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: ) as executor, contextlib.closing(execution_state) as execution_state, ): + # execute the plugins + plugin_executor.on_invocation_start( + execution_arn=invocation_input.durable_execution_arn, + lambda_context=context, + execution_start_time=execution_state.get_execution_operation().start_timestamp, + is_first_invocation=not execution_state.is_replaying(), + ) # Thread 1: Run background checkpoint processing executor.submit(execution_state.checkpoint_batches_forever) diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/identifier.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/identifier.py index d273d097..89d07727 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/identifier.py +++ 
b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/identifier.py @@ -4,11 +4,21 @@ from dataclasses import dataclass +from aws_durable_execution_sdk_python.lambda_service import ( + OperationType, + OperationSubType, +) + @dataclass(frozen=True) class OperationIdentifier: """Container for operation id, parent id, and name.""" operation_id: str + sub_type: OperationSubType parent_id: str | None = None name: str | None = None + + @property + def type(self) -> OperationType: + return OperationType.from_sub_type(self.sub_type) diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/lambda_service.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/lambda_service.py index aa78e4e8..3fcbac9e 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/lambda_service.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/lambda_service.py @@ -73,6 +73,29 @@ class OperationType(Enum): CALLBACK = "CALLBACK" CHAINED_INVOKE = "CHAINED_INVOKE" + @classmethod + def from_sub_type(cls, sub_type: OperationSubType) -> OperationType: + match sub_type: + case OperationSubType.STEP | OperationSubType.WAIT_FOR_CONDITION: + return OperationType.STEP + case OperationSubType.WAIT: + return OperationType.WAIT + case OperationSubType.CHAINED_INVOKE: + return OperationType.CHAINED_INVOKE + case OperationSubType.CALLBACK: + return OperationType.CALLBACK + case ( + OperationSubType.WAIT_FOR_CALLBACK + | OperationSubType.RUN_IN_CHILD_CONTEXT + | OperationSubType.MAP + | OperationSubType.MAP_ITERATION + | OperationSubType.PARALLEL + | OperationSubType.PARALLEL_BRANCH + ): + return OperationType.CONTEXT + case _: + raise ValueError(f"Unknown operation sub-type {sub_type}") + class CallbackTimeoutType(Enum): TIMEOUT = "Callback.Timeout" @@ -105,6 +128,70 @@ class OperationSubType(Enum): CHAINED_INVOKE = "ChainedInvoke" +class 
InvocationStatus(Enum): + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + PENDING = "PENDING" + + # Used internally only: the invocation failed and the backend will retry + RETRY = "RETRY" + + +@dataclass(frozen=True) +class DurableExecutionInvocationOutput: + """Representation of the DurableExecutionInvocationOutput. This is what the Durable lambda handler returns. + + If the execution has already been completed via an update to the EXECUTION operation via CheckpointDurableExecution, + payload must be empty for SUCCEEDED/FAILED status. + """ + + status: InvocationStatus + result: str | None = None + error: ErrorObject | None = None + + @classmethod + def from_dict( + cls, data: MutableMapping[str, Any] + ) -> DurableExecutionInvocationOutput: + """Create an instance from a dictionary. + + Args: + data: Dictionary with camelCase keys matching the original structure + + Returns: + A DurableExecutionInvocationOutput instance + """ + status = InvocationStatus(data.get("Status")) + error = ErrorObject.from_dict(data["Error"]) if data.get("Error") else None + return cls(status=status, result=data.get("Result"), error=error) + + def to_dict(self) -> MutableMapping[str, Any]: + """Convert to a dictionary with the original field names. 
+ + Returns: + Dictionary with the original camelCase keys + """ + result: MutableMapping[str, Any] = {"Status": self.status.value} + + if self.result is not None: + # large payloads return "", because checkpointed already + result["Result"] = self.result + if self.error: + result["Error"] = self.error.to_dict() + + return result + + @classmethod + def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput: + """Create a succeeded invocation output.""" + return cls(status=InvocationStatus.SUCCEEDED, result=result) + + @classmethod + def create_retry(cls, error: ErrorObject) -> DurableExecutionInvocationOutput: + """Create an invocation output with RETRY status carrying the given error.""" + return cls(status=InvocationStatus.RETRY, error=error) + + @dataclass(frozen=True) class ExecutionDetails: input_payload: str | None = None diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/child.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/child.py index ecaf0f9d..780ab952 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/child.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/child.py @@ -154,9 +154,15 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: self.operation_identifier.operation_id, self.operation_identifier.name, ) - try: - raw_result: T = self.func() + # todo: fix attempt (checkpointed_result.is_existent is always True) + wrapped_user_func = self.state.wrap_user_function( + self.func, + self.operation_identifier, + checkpointed_result.is_replay_children(), + attempt=None if checkpointed_result.is_existent() else 1, + ) + raw_result: T = wrapped_user_func() if self.is_virtual: logger.debug( diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/step.py 
b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/step.py index 8a418fb3..35ae2d19 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/step.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/step.py @@ -13,16 +13,19 @@ ExecutionError, InvalidStateError, StepInterruptedError, + SuspendExecution, ) from aws_durable_execution_sdk_python.lambda_service import ( ErrorObject, OperationUpdate, + OperationType, ) from aws_durable_execution_sdk_python.logger import Logger, LogInfo from aws_durable_execution_sdk_python.operation.base import ( CheckResult, OperationExecutor, ) +from aws_durable_execution_sdk_python.plugin import UserFunctionStartInfo from aws_durable_execution_sdk_python.retries import RetryDecision, RetryPresets from aws_durable_execution_sdk_python.serdes import deserialize, serialize from aws_durable_execution_sdk_python.suspend import ( @@ -219,7 +222,14 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: try: # This is the actual code provided by the caller to execute durably inside the step - raw_result: T = self.func(step_context) + wrapped_user_func = self.state.wrap_user_function( + self.func, + self.operation_identifier, + False, + attempt, + ) + raw_result: T = wrapped_user_func(step_context) + serialized_result: str = serialize( serdes=self.config.serdes, value=raw_result, diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py index 5c4f1c4c..a53d8ef9 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py @@ -11,6 +11,8 @@ from 
aws_durable_execution_sdk_python.lambda_service import ( ErrorObject, OperationUpdate, + OperationType, + OperationSubType, ) from aws_durable_execution_sdk_python.logger import LogInfo from aws_durable_execution_sdk_python.operation.base import ( @@ -188,7 +190,13 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T: ) ) - new_state = self.check(current_state, check_context) + wrapped_user_func = self.state.wrap_user_function( + self.check, + self.operation_identifier, + False, + attempt, + ) + new_state = wrapped_user_func(current_state, check_context) # Check if condition is met with the wait strategy decision: WaitForConditionDecision = self.config.wait_strategy( diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py new file mode 100644 index 00000000..1a7ecd27 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py @@ -0,0 +1,408 @@ +import contextlib +import datetime +import functools +import logging +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, MutableMapping + +from aws_durable_execution_sdk_python.exceptions import SuspendExecution +from aws_durable_execution_sdk_python.identifier import OperationIdentifier +from aws_durable_execution_sdk_python.lambda_service import ( + OperationType, + OperationStatus, + OperationAction, + OperationSubType, + ErrorObject, + InvocationStatus, + Operation, + OperationUpdate, + DurableExecutionInvocationOutput, +) +from aws_durable_execution_sdk_python.types import LambdaContext + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class OperationInfo: + operation_id: str + operation_type: OperationType + sub_type: OperationSubType | None + name: str | None + parent_id: str | None + start_time: 
datetime.datetime | None + + +@dataclass(frozen=True) +class OperationStartInfo(OperationInfo): + pass + + +@dataclass(frozen=True) +class OperationEndInfo(OperationInfo): + status: OperationStatus + end_time: datetime.datetime | None + error: ErrorObject | None + + +class UserFunctionOutcome(Enum): + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + PENDING = "PENDING" + + @classmethod + def from_error(cls, error: ErrorObject | None) -> "UserFunctionOutcome": + if error is None: + return cls(cls.SUCCEEDED) + elif error.type == SuspendExecution.__name__: + return cls(cls.PENDING) + else: + return cls(cls.FAILED) + + +@dataclass(frozen=True) +class UserFunctionStartInfo(OperationInfo): + is_replay_children: bool = ( + False # True if user function is called to replay children (MAP/PARALLEL) + ) + attempt: int | None = ( + None # None for user function called more than once in CONTEXT + ) + + +@dataclass(frozen=True) +class UserFunctionEndInfo(OperationInfo): + is_replay_children: ( + bool # True if user function is called to replay children (MAP/PARALLEL) + ) + attempt: int | None # None for user function called more than once in CONTEXT + outcome: UserFunctionOutcome + end_time: datetime.datetime | None + error: ErrorObject | None + + @classmethod + def from_start_info( + cls, start_info: UserFunctionStartInfo, error: ErrorObject | None + ) -> "UserFunctionEndInfo": + return UserFunctionEndInfo( + operation_id=start_info.operation_id, + operation_type=start_info.operation_type, + sub_type=start_info.sub_type, + name=start_info.name, + parent_id=start_info.parent_id, + start_time=start_info.start_time, + is_replay_children=start_info.is_replay_children, + attempt=start_info.attempt, + outcome=UserFunctionOutcome.from_error(error), + end_time=datetime.datetime.now(datetime.UTC), + error=error, + ) + + +@dataclass(frozen=True) +class InvocationInfo: + request_id: str | None + execution_arn: str | None + start_time: datetime.datetime | None + is_first_invocation: bool + + 
+@dataclass(frozen=True) +class InvocationStartInfo(InvocationInfo): + pass + + +@dataclass(frozen=True) +class InvocationEndInfo(InvocationInfo): + status: InvocationStatus + end_time: datetime.datetime | None + error: ErrorObject | None + + @classmethod + def from_durable_execution_invocation_output( + cls, + invocation_start_info: InvocationStartInfo, + output: "DurableExecutionInvocationOutput", + ): + return InvocationEndInfo( + request_id=invocation_start_info.request_id, + execution_arn=invocation_start_info.execution_arn, + start_time=invocation_start_info.start_time, + is_first_invocation=invocation_start_info.is_first_invocation, + status=output.status, + end_time=datetime.datetime.now(datetime.UTC), + error=output.error, + ) + + +class DurableInstrumentationPlugin: + """Base class for plugins. Override only the methods you need.""" + + def on_invocation_start(self, info: InvocationStartInfo) -> None: + """Called when an invocation starts. This is called within the thread that runs user function handler. + + Args: + info: Information about the invocation. + """ + pass + + def on_invocation_end(self, info: InvocationEndInfo) -> None: + """Called when an invocation ends. This is called within the thread that runs user function handler. + + Args: + info: Information about the invocation. + """ + pass + + def on_operation_start(self, info: OperationStartInfo) -> None: + """ + Called when an operation checkpoints STARTED status. This is called NOT within the thread that runs operation. + + Args: + info: Information about the operation. + + """ + pass + + def on_operation_end(self, info: OperationEndInfo) -> None: + """ + Called when an operation checkpoints a terminal status. This is called NOT within the thread that runs operation. + + Args: + info: Information about the operation. + """ + pass + + def on_user_function_start(self, info: UserFunctionStartInfo) -> None: + """Called when an operation starts to execute user provided function. 
This is called within the thread that runs user provided function. + + Args: + info: Information about the operation attempt. + """ + pass + + def on_user_function_end(self, info: UserFunctionEndInfo) -> None: + """Called when an operation finishes executing user provided function. This is called within the thread that runs user provided function. + + Args: + info: Information about the operation attempt. + """ + pass + + # Todo: further discussions required to finalize the following interface + # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass + + +class PluginExecutor: + def __init__(self, plugins: list[DurableInstrumentationPlugin] | None): + self._plugins = plugins or [] + self._executor: ThreadPoolExecutor | None = None + self._invocation_status: InvocationStartInfo | None = None + + @contextlib.contextmanager + def run(self): + if self._plugins: + self._executor = ThreadPoolExecutor( + max_workers=1, + thread_name_prefix="plugin-executor", + ) + try: + yield + finally: + self._invocation_status = None + # Shut down the thread pool, waiting for pending tasks to complete. + if self._executor: + self._executor.shutdown(wait=True) + + @staticmethod + def _dispatch_plugin(plugin: DurableInstrumentationPlugin, info) -> None: + """Invoke the appropriate plugin callback. 
Runs inside the thread pool.""" + try: + match info: + case InvocationStartInfo(): + plugin.on_invocation_start(info) + case InvocationEndInfo(): + plugin.on_invocation_end(info) + case OperationStartInfo(): + plugin.on_operation_start(info) + case OperationEndInfo(): + plugin.on_operation_end(info) + case UserFunctionStartInfo(): + plugin.on_user_function_start(info) + case UserFunctionEndInfo(): + plugin.on_user_function_end(info) + case _: + raise RuntimeError(f"Unknown info type: {type(info)}") + except Exception: + # log and ignore the exception + logger.exception("Plugin %s exception ignored", plugin.__class__.__name__) + + def execute_plugins(self, info, sync): + if not self._executor: + return + for plugin in self._plugins: + if sync: + # this is called synchronously, so plugins will be able to manipulate thread local objects + self._dispatch_plugin(plugin, info) + else: + # this is called asynchronously, so plugins cannot manipulate thread local objects + self._executor.submit(self._dispatch_plugin, plugin, info) + + def on_invocation_start( + self, + execution_arn: str, + is_first_invocation: bool, + execution_start_time: datetime.datetime | None, + lambda_context: LambdaContext | None, + ) -> None: + aws_request_id = lambda_context.aws_request_id if lambda_context else None + invocation_start_time = ( + datetime.datetime.now(datetime.UTC) + if is_first_invocation + else execution_start_time + ) + self._invocation_status = InvocationStartInfo( + execution_arn=execution_arn, + request_id=aws_request_id, + is_first_invocation=is_first_invocation, + start_time=invocation_start_time, + ) + self.execute_plugins(self._invocation_status, sync=True) + + def on_invocation_end( + self, + output: "DurableExecutionInvocationOutput", + ) -> None: + if self._invocation_status is None: + # on_invocation_start not called, skip + return + + invocation_end_info = ( + InvocationEndInfo.from_durable_execution_invocation_output( + self._invocation_status, output + ) + ) + 
self.execute_plugins(invocation_end_info, sync=True) + + def on_user_function_start( + self, + operation_identifier: OperationIdentifier, + is_replay_children: bool = False, + attempt: int | None = None, + ) -> UserFunctionStartInfo: + """Execute any registered plugins for the operation when its user function starts to execute.""" + start_info = UserFunctionStartInfo( + operation_id=operation_identifier.operation_id, + operation_type=operation_identifier.type, + sub_type=operation_identifier.sub_type, + name=operation_identifier.name, + parent_id=operation_identifier.parent_id, + start_time=datetime.datetime.now(datetime.UTC), + is_replay_children=is_replay_children, + attempt=attempt, + ) + self.execute_plugins(start_info, sync=True) + return start_info + + def on_user_function_end(self, start_info: UserFunctionStartInfo, error) -> None: + """Execute any registered plugins for the operation when its user function finishes execution.""" + self.execute_plugins( + UserFunctionEndInfo.from_start_info(start_info, error), sync=True + ) + + def on_operation_action(self, update: OperationUpdate): + """Execute any registered plugins for a given operation when an update is checkpointed + + Args: + update: the operation update that is checkpointed + """ + if update.action is OperationAction.START: + # we handle only START action here because on_operation_update may not be able to see a STARTED update + # when START is checkpointed in batch with terminal status updates. + self.execute_plugins( + OperationStartInfo( + operation_id=update.operation_id, + operation_type=update.operation_type, + sub_type=update.sub_type, + name=update.name, + parent_id=update.parent_id, + start_time=datetime.datetime.now(datetime.UTC), + ), + sync=False, + ) + + def on_operation_update(self, operation: Operation | None): + """Execute any registered plugins for a given operation when it receives an update + + Updates such as STARTED might be omitted because START and completion action (e.g. 
SUCCEED/FAIL) may be + checkpointed in batch and the backend returns only the terminal status (e.g. SUCCEEDED/PENDING/FAILED). + + Note: the operation may not be up-to-date if the checkpoint is called asynchronously. + + Args: + operation: the operation is just checkpointed + """ + if operation and self._is_terminal_status(operation.status): + self.execute_plugins( + OperationEndInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_time=operation.start_timestamp, + end_time=operation.end_timestamp, + status=operation.status, + error=self._extract_error(operation), + ), + sync=False, + ) + + @staticmethod + def _extract_error(operation: Operation): + if operation.step_details and operation.step_details.error: + return operation.step_details.error + if operation.callback_details and operation.callback_details.error: + return operation.callback_details.error + if operation.chained_invoke_details and operation.chained_invoke_details.error: + return operation.chained_invoke_details.error + if operation.context_details and operation.context_details.error: + return operation.context_details.error + return None + + @staticmethod + def _is_terminal_status(status): + return status in [ + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.TIMED_OUT, + OperationStatus.CANCELLED, + OperationStatus.STOPPED, + ] + + @property + def handle_durable_output(self): + def decorator(func: Callable[[Any, LambdaContext], MutableMapping[str, Any]]): + @functools.wraps(func) + def wrapper(event: Any, context: LambdaContext): + with self.run(): + try: + output = func(event, context) + + self.on_invocation_end( + output=DurableExecutionInvocationOutput.from_dict(output), + ) + return output + except Exception as e: + self.on_invocation_end( + output=DurableExecutionInvocationOutput.create_retry( + ErrorObject.from_exception(e) + ), + ) + raise + 
+ return wrapper + + return decorator diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py index 83175503..7fcfadcc 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py @@ -2,6 +2,7 @@ from __future__ import annotations +import functools import json import logging import queue @@ -10,7 +11,7 @@ from dataclasses import dataclass from enum import Enum from threading import Lock -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Any from aws_durable_execution_sdk_python.exceptions import ( BackgroundThreadError, @@ -18,7 +19,9 @@ DurableExecutionsError, GetExecutionStateError, OrphanedChildException, + SuspendExecution, ) +from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( CheckpointOutput, DurableServiceClient, @@ -29,6 +32,11 @@ OperationType, OperationUpdate, StateOutput, + OperationSubType, +) +from aws_durable_execution_sdk_python.plugin import ( + PluginExecutor, + UserFunctionStartInfo, ) from aws_durable_execution_sdk_python.threading import CompletionEvent, OrderedLock @@ -236,6 +244,7 @@ def __init__( initial_checkpoint_token: str, operations: MutableMapping[str, Operation], service_client: DurableServiceClient, + plugin_executor: PluginExecutor, batcher_config: CheckpointBatcherConfig | None = None, replay_status: ReplayStatus = ReplayStatus.NEW, ): @@ -243,6 +252,7 @@ def __init__( self._current_checkpoint_token: str = initial_checkpoint_token self.operations: MutableMapping[str, Operation] = operations self._service_client: DurableServiceClient = service_client + self._plugin_executor: PluginExecutor = plugin_executor self._ordered_checkpoint_lock: OrderedLock = 
OrderedLock() self._operations_lock: Lock = Lock() @@ -274,7 +284,7 @@ def fetch_paginated_operations( initial_operations: list[Operation], checkpoint_token: str, next_marker: str | None, - ) -> None: + ) -> list[Operation]: """Add initial operations and fetch all paginated operations from the Durable Functions API. This method is thread_safe. The checkpoint_token is passed explicitly as a parameter rather than using the instance variable to ensure thread safety. @@ -283,6 +293,8 @@ def fetch_paginated_operations( initial_operations: initial operations to be added to ExecutionState checkpoint_token: checkpoint token used to call Durable Functions API. next_marker: a marker indicates that there are paginated operations. + Returns: + List of all operations fetched from the Durable Functions API Raises: GetExecutionStateError: If the API call fails. The error is logged @@ -315,6 +327,7 @@ def fetch_paginated_operations( self.operations.update( {op.operation_id: op for op in all_operations} ) + return all_operations def get_input_payload(self) -> str | None: # It is possible that backend will not provide an execution operation @@ -689,12 +702,18 @@ def checkpoint_batches_forever(self) -> None: current_checkpoint_token = output.checkpoint_token # Fetch new operations from the API before unblocking sync waiters - self.fetch_paginated_operations( + updated_operations = self.fetch_paginated_operations( output.new_execution_state.operations, output.checkpoint_token, output.new_execution_state.next_marker, ) + for update in updates: + self._plugin_executor.on_operation_action(update) + + for operation in updated_operations: + self._plugin_executor.on_operation_update(operation) + # Signal completion for any synchronous operations for queued_op in batch: if queued_op.completion_event is not None: @@ -903,3 +922,35 @@ def _calculate_operation_size(queued_op: QueuedOperation) -> int: def close(self): self.stop_checkpointing() + + def wrap_user_function( + self, + user_function: 
Callable, + operation_identifier: OperationIdentifier, + is_replay_children: bool = False, + attempt: int | None = None, + ): + @functools.wraps(user_function) + def wrapper(*args, **kwargs): + start_info = self._plugin_executor.on_user_function_start( + operation_identifier, is_replay_children, attempt + ) + try: + result = user_function(*args, **kwargs) + self._plugin_executor.on_user_function_end(start_info, None) + return result + except SuspendExecution as e: + self._plugin_executor.on_user_function_end( + start_info, + ErrorObject( + type=type(e).__name__, message=None, data=None, stack_trace=None + ), + ) + raise + except Exception as e: + self._plugin_executor.on_user_function_end( + start_info, ErrorObject.from_exception(e) + ) + raise + + return wrapper diff --git a/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py b/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py index 3d9ae270..ef7d7f57 100644 --- a/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/concurrency_test.py @@ -1134,7 +1134,9 @@ def failure_callable(): executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context result = executor.execute(execution_state, executor_context) @@ -1172,7 +1174,9 @@ def execute_item(self, child_context, executable): executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, 
**kwargs: child_context result = executor._execute_item_in_child_context( # noqa: SLF001 executor_context, executables[0] @@ -1262,7 +1266,9 @@ def execute_item(self, child_context, executable): executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context # Should raise TimedSuspendExecution since no other tasks running with pytest.raises(TimedSuspendExecution): @@ -1308,7 +1314,9 @@ def execute_item(self, child_context, executable): executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context # Should raise TimedSuspendExecution after Task B completes with pytest.raises(TimedSuspendExecution): @@ -1353,7 +1361,9 @@ def execute_item(self, child_context, executable): executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context # Should raise TimedSuspendExecution since single task suspends with pytest.raises(TimedSuspendExecution): @@ -1426,7 +1436,9 @@ def execute_item(self, child_context, executable): executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, 
**kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context # Should complete successfully after B resubmits and both tasks finish result = executor.execute(execution_state, executor_context) @@ -1569,7 +1581,9 @@ def failure_callable(): executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context result = executor.execute(execution_state, executor_context) @@ -1794,7 +1808,9 @@ def failure_callable(): executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context result = executor.execute(execution_state, executor_context) @@ -1898,7 +1914,9 @@ def suspend_callable(): executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context # Should raise SuspendExecution since single task suspends with pytest.raises(SuspendExecution): @@ -2838,7 +2856,9 @@ def execute_item(self, child_context, executable): execution_state.create_checkpoint = Mock() executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" 
# noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context # Should return (not hang) and batch should reflect one FAILED and one SUCCEEDED result = executor.execute(execution_state, executor_context) @@ -2876,7 +2896,9 @@ def task_func(ctx, item, idx, items): execution_state.create_checkpoint = Mock() executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context result = executor.execute(execution_state, executor_context) @@ -2937,6 +2959,7 @@ def slow_branch(): def create_child_context(op_id, *, is_virtual=False): child = Mock() child.state = execution_state + child.state.wrap_user_function = lambda func, *args, **kwargs: func return child executor_context.create_child_context = create_child_context @@ -3003,6 +3026,7 @@ def slow_branch(): execution_state = Mock() execution_state.create_checkpoint = Mock() + execution_state.wrap_user_function = lambda func, *args, **kwargs: func executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda idx: f"step_{idx}" # noqa: SLF001 executor_context._parent_id = "parent" # noqa: SLF001 @@ -3379,6 +3403,7 @@ def execute_item(self, child_context, executable): execution_state = Mock() execution_state.create_checkpoint = Mock() + execution_state.wrap_user_function = lambda func, *args, **kwargs: func # Mock out the checkpoint so the real child_handler reports "not # existent" (non-existent checkpoint -> normal execution path). 
@@ -3434,6 +3459,7 @@ def execute_item(self, child_context, executable): execution_state = Mock() execution_state.create_checkpoint = Mock() + execution_state.wrap_user_function = lambda func, *args, **kwargs: func mock_checkpoint = Mock() mock_checkpoint.is_succeeded.return_value = False diff --git a/packages/aws-durable-execution-sdk-python/tests/context_test.py b/packages/aws-durable-execution-sdk-python/tests/context_test.py index 0e2cf0e2..13d1dc64 100644 --- a/packages/aws-durable-execution-sdk-python/tests/context_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/context_test.py @@ -36,6 +36,7 @@ Operation, OperationStatus, OperationType, + OperationSubType, ) from aws_durable_execution_sdk_python.state import CheckpointedResult, ExecutionState from aws_durable_execution_sdk_python.waits import ( @@ -288,7 +289,9 @@ def test_create_callback_basic(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_operation_id, None, None), + operation_identifier=OperationIdentifier( + expected_operation_id, OperationSubType.CALLBACK, None, None + ), config=CallbackConfig(), ) mock_executor.process.assert_called_once() @@ -320,7 +323,9 @@ def test_create_callback_with_name_and_config(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_operation_id, None, None), + operation_identifier=OperationIdentifier( + expected_operation_id, OperationSubType.CALLBACK, None, None + ), config=config, ) mock_executor.process.assert_called_once() @@ -352,7 +357,9 @@ def test_create_callback_with_parent_id(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_operation_id, "parent123"), + operation_identifier=OperationIdentifier( + expected_operation_id, OperationSubType.CALLBACK, "parent123" + ), config=CallbackConfig(), ) @@ 
-417,7 +424,9 @@ def test_step_basic(mock_executor_class): assert result == "step_result" mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_operation_id, None, None), + operation_identifier=OperationIdentifier( + expected_operation_id, OperationSubType.STEP, None, None + ), config=ANY, # StepConfig() is created in context.step() func=mock_callable, context_logger=ANY, @@ -456,7 +465,9 @@ def test_step_with_name_and_config(mock_executor_class): assert result == "configured_result" mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_id, None, None), + operation_identifier=OperationIdentifier( + expected_id, OperationSubType.STEP, None, None + ), config=config, func=mock_callable, context_logger=ANY, @@ -493,7 +504,9 @@ def test_step_with_parent_id(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_id, "parent123"), + operation_identifier=OperationIdentifier( + expected_id, OperationSubType.STEP, "parent123" + ), config=ANY, func=mock_callable, context_logger=ANY, @@ -533,10 +546,10 @@ def test_step_increments_counter(mock_executor_class): assert context._step_counter.get_current() == 12 # noqa: SLF001 assert mock_executor_class.call_args_list[0][1][ "operation_identifier" - ] == OperationIdentifier(expected_id1, None, None) + ] == OperationIdentifier(expected_id1, OperationSubType.STEP, None, None) assert mock_executor_class.call_args_list[1][1][ "operation_identifier" - ] == OperationIdentifier(expected_id2, None, None) + ] == OperationIdentifier(expected_id2, OperationSubType.STEP, None, None) @patch("aws_durable_execution_sdk_python.context.StepOperationExecutor") @@ -564,7 +577,9 @@ def test_step_with_original_name(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_id, 
None, "override_name"), + operation_identifier=OperationIdentifier( + expected_id, OperationSubType.STEP, None, "override_name" + ), config=ANY, func=mock_callable, context_logger=ANY, @@ -599,7 +614,9 @@ def test_invoke_basic(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_operation_id, None, None), + operation_identifier=OperationIdentifier( + expected_operation_id, OperationSubType.CHAINED_INVOKE, None, None + ), function_name="test_function", payload="test_payload", config=ANY, # InvokeConfig() is created in context.invoke() @@ -636,7 +653,9 @@ def test_invoke_with_name_and_config(mock_executor_class): assert result == "configured_result" mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_id, None, "named_invoke"), + operation_identifier=OperationIdentifier( + expected_id, OperationSubType.CHAINED_INVOKE, None, "named_invoke" + ), function_name="test_function", payload={"key": "value"}, config=config, @@ -668,7 +687,9 @@ def test_invoke_with_parent_id(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_id, "parent123", None), + operation_identifier=OperationIdentifier( + expected_id, OperationSubType.CHAINED_INVOKE, "parent123", None + ), function_name="test_function", payload=None, config=ANY, @@ -703,10 +724,10 @@ def test_invoke_increments_counter(mock_executor_class): assert context._step_counter.get_current() == 12 # noqa: SLF001 assert mock_executor_class.call_args_list[0][1][ "operation_identifier" - ] == OperationIdentifier(expected_id1, None, None) + ] == OperationIdentifier(expected_id1, OperationSubType.CHAINED_INVOKE, None, None) assert mock_executor_class.call_args_list[1][1][ "operation_identifier" - ] == OperationIdentifier(expected_id2, None, None) + ] == OperationIdentifier(expected_id2, 
OperationSubType.CHAINED_INVOKE, None, None) @patch("aws_durable_execution_sdk_python.context.InvokeOperationExecutor") @@ -733,7 +754,9 @@ def test_invoke_with_none_payload(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_id, None, None), + operation_identifier=OperationIdentifier( + expected_id, OperationSubType.CHAINED_INVOKE, None, None + ), function_name="test_function", payload=None, config=ANY, @@ -778,7 +801,7 @@ def test_invoke_with_custom_serdes(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, operation_identifier=OperationIdentifier( - expected_id, None, "custom_serdes_invoke" + expected_id, OperationSubType.CHAINED_INVOKE, None, "custom_serdes_invoke" ), function_name="test_function", payload={"original": "data"}, @@ -811,7 +834,9 @@ def test_wait_basic(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_operation_id, None, None), + operation_identifier=OperationIdentifier( + expected_operation_id, OperationSubType.WAIT, None, None + ), seconds=30, ) mock_executor.process.assert_called_once() @@ -840,7 +865,9 @@ def test_wait_with_name(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_id, None, "test_wait"), + operation_identifier=OperationIdentifier( + expected_id, OperationSubType.WAIT, None, "test_wait" + ), seconds=60, ) mock_executor.process.assert_called_once() @@ -869,7 +896,9 @@ def test_wait_with_parent_id(mock_executor_class): mock_executor_class.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier(expected_id, "parent123"), + operation_identifier=OperationIdentifier( + expected_id, OperationSubType.WAIT, "parent123" + ), seconds=45, ) mock_executor.process.assert_called_once() @@ -901,10 +930,10 @@ def 
test_wait_increments_counter(mock_executor_class): assert context._step_counter.get_current() == 12 # noqa: SLF001 assert mock_executor_class.call_args_list[0][1][ "operation_identifier" - ] == OperationIdentifier(expected_id1, None, None) + ] == OperationIdentifier(expected_id1, OperationSubType.WAIT, None, None) assert mock_executor_class.call_args_list[1][1][ "operation_identifier" - ] == OperationIdentifier(expected_id2, None, None) + ] == OperationIdentifier(expected_id2, OperationSubType.WAIT, None, None) @patch("aws_durable_execution_sdk_python.context.WaitOperationExecutor") @@ -974,7 +1003,7 @@ def test_run_in_child_context_basic(mock_handler): call_args = mock_handler.call_args assert call_args[1]["state"] is mock_state assert call_args[1]["operation_identifier"] == OperationIdentifier( - expected_operation_id, None, None + expected_operation_id, OperationSubType.RUN_IN_CHILD_CONTEXT, None, None ) assert call_args[1]["config"] is None @@ -1004,7 +1033,7 @@ def test_run_in_child_context_with_name_and_config(mock_handler): assert result == "configured_child_result" call_args = mock_handler.call_args assert call_args[1]["operation_identifier"] == OperationIdentifier( - expected_id, None, "original_function" + expected_id, OperationSubType.RUN_IN_CHILD_CONTEXT, None, "original_function" ) assert call_args[1]["config"] is config @@ -1037,7 +1066,7 @@ def test_run_in_child_context_with_parent_id(mock_executor_class): call_args = mock_executor_class.call_args assert call_args[1]["operation_identifier"] == OperationIdentifier( - expected_id, "parent456", None + expected_id, OperationSubType.RUN_IN_CHILD_CONTEXT, "parent456", None ) @@ -1101,10 +1130,14 @@ def test_run_in_child_context_increments_counter(mock_executor_class): assert context._step_counter.get_current() == 7 # noqa: SLF001 assert mock_executor_class.call_args_list[0][1][ "operation_identifier" - ] == OperationIdentifier(expected_id1, None, None) + ] == OperationIdentifier( + expected_id1, 
OperationSubType.RUN_IN_CHILD_CONTEXT, None, None + ) assert mock_executor_class.call_args_list[1][1][ "operation_identifier" - ] == OperationIdentifier(expected_id2, None, None) + ] == OperationIdentifier( + expected_id2, OperationSubType.RUN_IN_CHILD_CONTEXT, None, None + ) @patch("aws_durable_execution_sdk_python.context.child_handler") @@ -1343,6 +1376,8 @@ def test_map_with_empty_inputs(mock_handler): def test_function(context, item, index, items): return item + mock_state.wrap_user_function = lambda func, *args, **kwargs: func + inputs = [] with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: @@ -1362,6 +1397,7 @@ def test_map_with_different_input_types(mock_handler): mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" ) + mock_state.wrap_user_function = lambda func, *args, **kwargs: func def test_function(context, item, index, items): return str(item) @@ -1507,6 +1543,7 @@ def test_parallel_with_empty_callables(mock_handler): mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" ) + mock_state.wrap_user_function = lambda func, *args, **kwargs: func callables = [] @@ -1527,6 +1564,7 @@ def test_parallel_with_single_callable(mock_handler): mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" ) + mock_state.wrap_user_function = lambda func, *args, **kwargs: func def single_task(context): return "single_result" @@ -1550,6 +1588,7 @@ def test_parallel_with_many_callables(mock_handler): mock_state.durable_execution_arn = ( "arn:aws:durable:us-east-1:123456789012:execution/test" ) + mock_state.wrap_user_function = lambda func, *args, **kwargs: func def create_task(i): def task(context): @@ -1664,6 +1703,7 @@ def test_function(context, item, index, items): # Create mock state and context state = Mock() state.durable_execution_arn = "test_arn" + state.wrap_user_function = lambda func, *args, **kwargs: func context = 
create_test_context(state=state) @@ -1704,6 +1744,7 @@ def test_callable_2(context): # Create mock state and context state = Mock() state.durable_execution_arn = "test_arn" + state.wrap_user_function = lambda func, *args, **kwargs: func context = create_test_context(state=state) diff --git a/packages/aws-durable-execution-sdk-python/tests/e2e/checkpoint_response_int_test.py b/packages/aws-durable-execution-sdk-python/tests/e2e/checkpoint_response_int_test.py index c0fd0f50..de168afc 100644 --- a/packages/aws-durable-execution-sdk-python/tests/e2e/checkpoint_response_int_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/e2e/checkpoint_response_int_test.py @@ -28,7 +28,7 @@ ) if TYPE_CHECKING: - from aws_durable_execution_sdk_python.types import StepContext + from aws_durable_execution_sdk_python.types import StepContext, LambdaContext def create_mock_checkpoint_with_operations(): @@ -101,7 +101,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -164,7 +164,7 @@ def my_handler(event, context: DurableContext) -> list[str]: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -220,7 +220,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -279,7 +279,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", 
"CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -388,7 +388,7 @@ def mock_checkpoint( mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -440,7 +440,7 @@ def my_handler(event, context: DurableContext): mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -499,7 +499,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -598,7 +598,7 @@ def mock_checkpoint( mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -665,7 +665,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -730,7 +730,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ diff --git a/packages/aws-durable-execution-sdk-python/tests/e2e/execution_int_test.py b/packages/aws-durable-execution-sdk-python/tests/e2e/execution_int_test.py index 5a884bff..ed774632 100644 --- 
a/packages/aws-durable-execution-sdk-python/tests/e2e/execution_int_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/e2e/execution_int_test.py @@ -135,7 +135,7 @@ def mock_checkpoint( # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -221,7 +221,7 @@ def mock_checkpoint( # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -262,7 +262,7 @@ def mock_checkpoint( 123, "str", extra={ - "executionArn": "test-arn", + "executionArn": "test-arn/execution-1", "operationName": "mystep", "attempt": 1, "operationId": operation_id, @@ -308,7 +308,7 @@ def my_handler(event, context): # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -409,7 +409,7 @@ def mock_checkpoint_failure( # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -463,7 +463,7 @@ def my_handler(event: Any, context: DurableContext): # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -560,7 +560,7 @@ def mock_checkpoint( mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ diff --git a/packages/aws-durable-execution-sdk-python/tests/e2e/map_with_concurrent_waits_int_test.py 
b/packages/aws-durable-execution-sdk-python/tests/e2e/map_with_concurrent_waits_int_test.py index 8ad812e4..62ad7c2b 100644 --- a/packages/aws-durable-execution-sdk-python/tests/e2e/map_with_concurrent_waits_int_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/e2e/map_with_concurrent_waits_int_test.py @@ -42,6 +42,7 @@ OperationUpdate, OperationType, ) +from aws_durable_execution_sdk_python.plugin import PluginExecutor from aws_durable_execution_sdk_python.state import ( CheckpointBatcherConfig, ExecutionState, @@ -68,6 +69,7 @@ def _make_state( operations={}, service_client=mock_client, batcher_config=config, + plugin_executor=PluginExecutor([]), ) diff --git a/packages/aws-durable-execution-sdk-python/tests/execution_test.py b/packages/aws-durable-execution-sdk-python/tests/execution_test.py index db13b5a9..ed79bedf 100644 --- a/packages/aws-durable-execution-sdk-python/tests/execution_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/execution_test.py @@ -23,7 +23,6 @@ from aws_durable_execution_sdk_python.execution import ( DurableExecutionInvocationInput, DurableExecutionInvocationInputWithClient, - DurableExecutionInvocationOutput, InitialExecutionState, InvocationStatus, durable_execution, @@ -46,7 +45,9 @@ StateOutput, StepDetails, WaitDetails, + DurableExecutionInvocationOutput, ) +from aws_durable_execution_sdk_python.plugin import DurableInstrumentationPlugin LARGE_RESULT = "large_success" * 1024 * 1024 @@ -56,7 +57,7 @@ def test_durable_execution_invocation_input_from_dict(): """Test that DurableExecutionInvocationInput.from_dict works correctly""" input_dict = { - "DurableExecutionArn": "9692ca80-399d-4f52-8d0a-41acc9cd0492", + "DurableExecutionArn": "9692ca80-399d-4f52-8d0a-41acc9cd0492/9692ca80-399d-4f52-8d0a-41acc9cd0492", "CheckpointToken": "9692ca80-399d-4f52-8d0a-41acc9cd0492", "InitialExecutionState": { "Operations": [ @@ -76,7 +77,10 @@ def test_durable_execution_invocation_input_from_dict(): result = 
DurableExecutionInvocationInput.from_dict(input_dict) - assert result.durable_execution_arn == "9692ca80-399d-4f52-8d0a-41acc9cd0492" + assert ( + result.durable_execution_arn + == "9692ca80-399d-4f52-8d0a-41acc9cd0492/9692ca80-399d-4f52-8d0a-41acc9cd0492" + ) assert result.checkpoint_token == "9692ca80-399d-4f52-8d0a-41acc9cd0492" # noqa: S105 assert isinstance(result.initial_execution_state, InitialExecutionState) assert len(result.initial_execution_state.operations) == 1 @@ -167,14 +171,14 @@ def test_durable_execution_invocation_input_to_dict(): ) invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) result = invocation_input.to_dict() expected = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_dict(), } @@ -186,14 +190,14 @@ def test_durable_execution_invocation_input_to_dict_not_local(): initial_state = InitialExecutionState(operations=[], next_marker="") invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) result = invocation_input.to_dict() expected = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_dict(), } @@ -207,7 +211,7 @@ def test_durable_execution_invocation_input_with_client_inheritance(): initial_state = InitialExecutionState(operations=[], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # 
noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -216,7 +220,7 @@ def test_durable_execution_invocation_input_with_client_inheritance(): # Should inherit to_dict from parent class result = invocation_input.to_dict() expected = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_dict(), } @@ -231,7 +235,7 @@ def test_durable_execution_invocation_input_with_client_from_parent(): initial_state = InitialExecutionState(operations=[], next_marker="") parent_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) @@ -360,7 +364,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: # Create regular event with LocalRunner=False event = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -412,7 +416,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: # Create regular event with LocalRunner=False event = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -469,7 +473,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -516,7 +520,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = 
InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -571,7 +575,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -617,7 +621,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -664,7 +668,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -702,7 +706,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 
initial_execution_state=initial_state, service_client=mock_client, @@ -748,7 +752,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: # Create regular event dict instead of DurableExecutionInvocationInputWithClient event = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -796,7 +800,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -835,7 +839,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -917,7 +921,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -957,7 +961,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + 
durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1007,7 +1011,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1056,7 +1060,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1104,7 +1108,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1154,7 +1158,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1198,7 +1202,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = 
InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1242,7 +1246,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1288,7 +1292,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1334,7 +1338,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1381,7 +1385,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 
initial_execution_state=initial_state, service_client=mock_client, @@ -1447,7 +1451,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1537,7 +1541,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1620,7 +1624,7 @@ def test_handler(event: Any, context: DurableContext) -> str: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1690,7 +1694,7 @@ def test_handler(event: Any, context: DurableContext) -> str: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1745,7 +1749,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = 
DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1805,7 +1809,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1862,7 +1866,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, @@ -1907,7 +1911,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: return {"result": "success"} event = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -2204,14 +2208,14 @@ def test_durable_execution_invocation_input_to_json_dict_minimal(): ) invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) result = invocation_input.to_json_dict() expected = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": initial_state.to_json_dict(), } @@ -2238,7 +2242,7 
@@ def test_durable_execution_invocation_input_to_json_dict_with_timestamps(): ) invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) @@ -2252,7 +2256,7 @@ def test_durable_execution_invocation_input_to_json_dict_with_timestamps(): assert operation_result["StartTimestamp"] == expected_start_ms assert operation_result["EndTimestamp"] == expected_end_ms - assert result["DurableExecutionArn"] == "arn:test:execution" + assert result["DurableExecutionArn"] == "arn:test:execution/exec1" assert result["CheckpointToken"] == "token123" @@ -2261,14 +2265,14 @@ def test_durable_execution_invocation_input_to_json_dict_empty_operations(): initial_state = InitialExecutionState(operations=[], next_marker="") invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) result = invocation_input.to_json_dict() expected = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": {"Operations": [], "NextMarker": ""}, } @@ -2279,7 +2283,7 @@ def test_durable_execution_invocation_input_to_json_dict_empty_operations(): def test_durable_execution_invocation_input_from_json_dict_minimal(): """Test DurableExecutionInvocationInput.from_json_dict with minimal data.""" data = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -2295,7 +2299,7 @@ def test_durable_execution_invocation_input_from_json_dict_minimal(): result = DurableExecutionInvocationInput.from_json_dict(data) - assert result.durable_execution_arn == 
"arn:test:execution" + assert result.durable_execution_arn == "arn:test:execution/exec1" assert result.checkpoint_token == "token123" # noqa: S105 assert isinstance(result.initial_execution_state, InitialExecutionState) assert len(result.initial_execution_state.operations) == 1 @@ -2309,7 +2313,7 @@ def test_durable_execution_invocation_input_from_json_dict_with_timestamps(): end_ms = 1672578000000 # 2023-01-01 13:00:00 UTC data = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", "InitialExecutionState": { "Operations": [ @@ -2340,13 +2344,13 @@ def test_durable_execution_invocation_input_from_json_dict_with_timestamps(): def test_durable_execution_invocation_input_from_json_dict_empty_initial_state(): """Test DurableExecutionInvocationInput.from_json_dict handles missing InitialExecutionState.""" data = { - "DurableExecutionArn": "arn:test:execution", + "DurableExecutionArn": "arn:test:execution/exec1", "CheckpointToken": "token123", } result = DurableExecutionInvocationInput.from_json_dict(data) - assert result.durable_execution_arn == "arn:test:execution" + assert result.durable_execution_arn == "arn:test:execution/exec1" assert result.checkpoint_token == "token123" # noqa: S105 assert isinstance(result.initial_execution_state, InitialExecutionState) assert len(result.initial_execution_state.operations) == 0 @@ -2486,7 +2490,7 @@ def test_durable_execution_invocation_input_json_dict_preserves_non_timestamp_fi ) invocation_input = DurableExecutionInvocationInput( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, ) @@ -2504,7 +2508,7 @@ def test_durable_execution_invocation_input_json_dict_preserves_non_timestamp_fi assert operation_result["CallbackDetails"]["CallbackId"] == "cb123" assert operation_result["CallbackDetails"]["Result"] == "callback_result" 
- assert result["DurableExecutionArn"] == "arn:test:execution" + assert result["DurableExecutionArn"] == "arn:test:execution/exec1" assert result["CheckpointToken"] == "token123" assert result["InitialExecutionState"]["NextMarker"] == "marker123" @@ -2666,7 +2670,7 @@ def _make_invocation_input(mock_client, next_marker=""): execution_details=ExecutionDetails(input_payload="{}"), ) return DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=InitialExecutionState( operations=[operation], next_marker=next_marker @@ -2711,7 +2715,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: assert result["Status"] == InvocationStatus.SUCCEEDED.value assert json.loads(result["Result"]) == {"is_replaying": True} mock_client.get_execution_state.assert_called_once_with( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", next_marker="page2", ) @@ -2827,3 +2831,293 @@ def test_handler(event: Any, context: DurableContext) -> dict: _make_invocation_input(mock_client, next_marker="next-page-marker"), _make_lambda_context(), ) + + +# region Plugin Integration Tests + + +class _RecordingPlugin(DurableInstrumentationPlugin): + """Plugin that records all hook calls for assertion.""" + + def __init__(self) -> None: + self.calls: list[str] = [] + + def on_execution_start(self, info): + self.calls.append("execution_start") + + def on_execution_end(self, info): + self.calls.append(f"execution_end:{info.status.value}") + + def on_invocation_start(self, info): + self.calls.append("invocation_start") + + def on_invocation_end(self, info): + self.calls.append(f"invocation_end:{info.status.value}") + + def on_operation_start(self, info): + self.calls.append(f"operation_start:{info.operation_id}") + + def on_operation_end(self, info): + 
self.calls.append(f"operation_end:{info.operation_id}") + + def on_operation_attempt_start(self, info): + self.calls.append(f"attempt_start:{info.operation_id}") + + def on_operation_attempt_end(self, info): + self.calls.append(f"attempt_end:{info.operation_id}") + + +class _FailingPlugin(DurableInstrumentationPlugin): + """Plugin that raises on every hook call.""" + + def on_execution_start(self, info): + raise RuntimeError("plugin boom") + + def on_execution_end(self, info): + raise RuntimeError("plugin boom") + + def on_invocation_start(self, info): + raise RuntimeError("plugin boom") + + def on_invocation_end(self, info): + raise RuntimeError("plugin boom") + + def on_operation_start(self, info): + raise RuntimeError("plugin boom") + + def on_operation_end(self, info): + raise RuntimeError("plugin boom") + + def on_operation_attempt_start(self, info): + raise RuntimeError("plugin boom") + + def on_operation_attempt_end(self, info): + raise RuntimeError("plugin boom") + + +def test_durable_execution_with_plugins_success(): + """Test that plugins receive invocation start/end and execution end on success.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + # ExecutionStartInfo dispatches to on_invocation_start in the match block + assert "invocation_start" in plugin.calls + assert "invocation_end:SUCCEEDED" in plugin.calls + + +def test_durable_execution_with_plugins_failure(): + """Test that plugins receive invocation end and execution end on user 
error.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + msg = "user error" + raise ValueError(msg) + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.FAILED.value + assert "invocation_start" in plugin.calls + assert "invocation_end:FAILED" in plugin.calls + + +def test_durable_execution_with_plugins_pending(): + """Test that plugins receive invocation end with PENDING status on suspend.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + raise SuspendExecution("test") + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.PENDING.value + assert "invocation_start" in plugin.calls + assert "invocation_end:PENDING" in plugin.calls + # Execution end should NOT be fired for PENDING + execution_end_calls = [c for c in plugin.calls if c.startswith("execution_end")] + assert len(execution_end_calls) == 0 + + +def test_durable_execution_with_plugins_retryable_error(): + """Test that plugins receive invocation end with RETRY status on retryable error.""" + mock_client = Mock(spec=DurableServiceClient) + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + msg = 
"Retriable error" + raise InvocationError(msg) + + with pytest.raises(InvocationError): + test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert "invocation_start" in plugin.calls + assert "invocation_end:RETRY" in plugin.calls + + +def test_durable_execution_with_multiple_plugins(): + """Test that multiple plugins all receive callbacks.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin1 = _RecordingPlugin() + plugin2 = _RecordingPlugin() + + @durable_execution(plugins=[plugin1, plugin2]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + assert "invocation_start" in plugin1.calls + assert "invocation_start" in plugin2.calls + assert "invocation_end:SUCCEEDED" in plugin1.calls + assert "invocation_end:SUCCEEDED" in plugin2.calls + + +def test_durable_execution_with_failing_plugin_does_not_break_execution(): + """Test that a failing plugin does not prevent the handler from completing.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + failing_plugin = _FailingPlugin() + recording_plugin = _RecordingPlugin() + + @durable_execution(plugins=[failing_plugin, recording_plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + # Execution should still succeed despite the failing plugin + assert 
result["Status"] == InvocationStatus.SUCCEEDED.value + # The recording plugin should still have been called + assert "invocation_start" in recording_plugin.calls + assert "invocation_end:SUCCEEDED" in recording_plugin.calls + + +def test_durable_execution_with_no_plugins(): + """Test that passing no plugins (None) works correctly.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + @durable_execution(plugins=None) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + + +def test_durable_execution_with_empty_plugins_list(): + """Test that passing an empty plugins list works correctly.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + @durable_execution(plugins=[]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + + +def test_durable_execution_decorator_with_plugins_and_boto3_client(): + """Test that plugins parameter works alongside boto3_client parameter.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + # When using 
DurableExecutionInvocationInputWithClient, boto3_client is ignored + # but we verify the decorator accepts both parameters + @durable_execution(boto3_client=None, plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + assert "invocation_start" in plugin.calls + + +# endregion Plugin Integration Tests diff --git a/packages/aws-durable-execution-sdk-python/tests/lambda_service_test.py b/packages/aws-durable-execution-sdk-python/tests/lambda_service_test.py index 626be292..4d4fd5a4 100644 --- a/packages/aws-durable-execution-sdk-python/tests/lambda_service_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/lambda_service_test.py @@ -585,7 +585,9 @@ def test_invoke_options_to_dict_complete(): def test_operation_update_create_invoke_start(): """Test OperationUpdate.create_invoke_start method to cover line 545.""" - identifier = OperationIdentifier("test-id", "parent-id") + identifier = OperationIdentifier( + "test-id", OperationSubType.CHAINED_INVOKE, "parent-id" + ) invoke_options = ChainedInvokeOptions("test-func") update = OperationUpdate.create_invoke_start(identifier, "payload", invoke_options) assert update.operation_id == "test-id" @@ -686,7 +688,8 @@ def test_operation_update_create_callback(): """Test OperationUpdate.create_callback factory method.""" callback_options = CallbackOptions(timeout_seconds=300) update = OperationUpdate.create_callback( - OperationIdentifier("cb1", None, "test_callback"), callback_options + OperationIdentifier("cb1", OperationSubType.CALLBACK, None, "test_callback"), + callback_options, ) assert update.operation_id == "cb1" assert update.operation_type is OperationType.CALLBACK @@ -700,7 +703,8 @@ def test_operation_update_create_wait_start(): """Test OperationUpdate.create_wait_start factory method.""" 
wait_options = WaitOptions(wait_seconds=30) update = OperationUpdate.create_wait_start( - OperationIdentifier("wait1", "parent1", "test_wait"), wait_options + OperationIdentifier("wait1", OperationSubType.WAIT, "parent1", "test_wait"), + wait_options, ) assert update.operation_id == "wait1" assert update.parent_id == "parent1" @@ -728,7 +732,8 @@ def test_operation_update_create_execution_succeed(mock_datetime): def test_operation_update_create_step_succeed(): """Test OperationUpdate.create_step_succeed factory method.""" update = OperationUpdate.create_step_succeed( - OperationIdentifier("step1", None, "test_step"), "step_payload" + OperationIdentifier("step1", OperationSubType.STEP, None, "test_step"), + "step_payload", ) assert update.operation_id == "step1" assert update.operation_type is OperationType.STEP @@ -746,7 +751,9 @@ def test_operation_update_factory_methods(): # Test create_context_start update = OperationUpdate.create_context_start( - OperationIdentifier("ctx1", None, "test_context"), + OperationIdentifier( + "ctx1", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_context" + ), OperationSubType.RUN_IN_CHILD_CONTEXT, ) assert update.operation_type is OperationType.CONTEXT @@ -755,7 +762,9 @@ def test_operation_update_factory_methods(): # Test create_context_succeed update = OperationUpdate.create_context_succeed( - OperationIdentifier("ctx1", None, "test_context"), + OperationIdentifier( + "ctx1", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_context" + ), "payload", OperationSubType.RUN_IN_CHILD_CONTEXT, ) @@ -765,7 +774,9 @@ def test_operation_update_factory_methods(): # Test create_context_fail update = OperationUpdate.create_context_fail( - OperationIdentifier("ctx1", None, "test_context"), + OperationIdentifier( + "ctx1", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_context" + ), error, OperationSubType.RUN_IN_CHILD_CONTEXT, ) @@ -780,7 +791,7 @@ def test_operation_update_factory_methods(): # Test create_step_fail update = 
OperationUpdate.create_step_fail( - OperationIdentifier("step1", None, "test_step"), error + OperationIdentifier("step1", OperationSubType.STEP, None, "test_step"), error ) assert update.operation_type is OperationType.STEP assert update.action is OperationAction.FAIL @@ -788,14 +799,16 @@ def test_operation_update_factory_methods(): # Test create_step_start update = OperationUpdate.create_step_start( - OperationIdentifier("step1", None, "test_step") + OperationIdentifier("step1", OperationSubType.STEP, None, "test_step") ) assert update.action is OperationAction.START assert update.sub_type is OperationSubType.STEP # Test create_step_retry update = OperationUpdate.create_step_retry( - OperationIdentifier("step1", None, "test_step"), error, 30 + OperationIdentifier("step1", OperationSubType.STEP, None, "test_step"), + error, + 30, ) assert update.action is OperationAction.RETRY assert update.step_options.next_attempt_delay_seconds == 30 @@ -953,7 +966,12 @@ def test_operation_update_complete_with_new_fields(): def test_operation_update_create_wait_for_condition_start(): """Test OperationUpdate.create_wait_for_condition_start factory method.""" - identifier = OperationIdentifier("wait_cond_1", "parent1", "test_wait_condition") + identifier = OperationIdentifier( + "wait_cond_1", + OperationSubType.WAIT_FOR_CONDITION, + "parent1", + "test_wait_condition", + ) update = OperationUpdate.create_wait_for_condition_start(identifier) assert update.operation_id == "wait_cond_1" @@ -966,7 +984,12 @@ def test_operation_update_create_wait_for_condition_start(): def test_operation_update_create_wait_for_condition_succeed(): """Test OperationUpdate.create_wait_for_condition_succeed factory method.""" - identifier = OperationIdentifier("wait_cond_1", "parent1", "test_wait_condition") + identifier = OperationIdentifier( + "wait_cond_1", + OperationSubType.WAIT_FOR_CONDITION, + "parent1", + "test_wait_condition", + ) update = OperationUpdate.create_wait_for_condition_succeed( 
identifier, "success_payload" ) @@ -982,7 +1005,12 @@ def test_operation_update_create_wait_for_condition_succeed(): def test_operation_update_create_wait_for_condition_retry(): """Test OperationUpdate.create_wait_for_condition_retry factory method.""" - identifier = OperationIdentifier("wait_cond_1", "parent1", "test_wait_condition") + identifier = OperationIdentifier( + "wait_cond_1", + OperationSubType.WAIT_FOR_CONDITION, + "parent1", + "test_wait_condition", + ) update = OperationUpdate.create_wait_for_condition_retry( identifier, "retry_payload", 45 ) @@ -999,7 +1027,12 @@ def test_operation_update_create_wait_for_condition_retry(): def test_operation_update_create_wait_for_condition_fail(): """Test OperationUpdate.create_wait_for_condition_fail factory method.""" - identifier = OperationIdentifier("wait_cond_1", "parent1", "test_wait_condition") + identifier = OperationIdentifier( + "wait_cond_1", + OperationSubType.WAIT_FOR_CONDITION, + "parent1", + "test_wait_condition", + ) error = ErrorObject( message="Condition failed", type="ConditionError", data=None, stack_trace=None ) diff --git a/packages/aws-durable-execution-sdk-python/tests/logger_test.py b/packages/aws-durable-execution-sdk-python/tests/logger_test.py index b6017fa6..9ce3719d 100644 --- a/packages/aws-durable-execution-sdk-python/tests/logger_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/logger_test.py @@ -9,8 +9,10 @@ Operation, OperationStatus, OperationType, + OperationSubType, ) from aws_durable_execution_sdk_python.logger import Logger, LoggerInterface, LogInfo +from aws_durable_execution_sdk_python.plugin import PluginExecutor from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus @@ -83,6 +85,7 @@ def exception( initial_checkpoint_token="test_token", # noqa: S106 operations={}, service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), ) @@ -135,7 +138,7 @@ def test_log_info_creation_minimal(): def 
test_log_info_from_operation_identifier(): """Test LogInfo.from_operation_identifier.""" - op_id = OperationIdentifier("op123", "parent456", "op_name") + op_id = OperationIdentifier("op123", OperationSubType.STEP, "parent456", "op_name") log_info = LogInfo.from_operation_identifier(EXECUTION_STATE, op_id, 3) assert log_info.execution_state.durable_execution_arn == "arn:aws:test" assert log_info.parent_id == "parent456" @@ -146,7 +149,7 @@ def test_log_info_from_operation_identifier(): def test_log_info_from_operation_identifier_no_attempt(): """Test LogInfo.from_operation_identifier without attempt.""" - op_id = OperationIdentifier("op123", "parent456", "op_name") + op_id = OperationIdentifier("op123", OperationSubType.STEP, "parent456", "op_name") log_info = LogInfo.from_operation_identifier(EXECUTION_STATE, op_id) assert log_info.execution_state.durable_execution_arn == "arn:aws:test" assert log_info.parent_id == "parent456" @@ -227,6 +230,7 @@ def test_logger_with_log_info(): initial_checkpoint_token="test_token", # noqa: S106 operations={}, service_client=Mock(), + plugin_executor=PluginExecutor([]), ) new_info = LogInfo(execution_state_new, "parent2", "op123", "new_name") new_logger = logger.with_log_info(new_info) @@ -377,6 +381,7 @@ def test_logger_replay_no_logging(): operations={"op1": operation}, service_client=Mock(), replay_status=ReplayStatus.REPLAY, + plugin_executor=PluginExecutor([]), ) log_info = LogInfo(replay_execution_state, "parent123", "test_name", 5) mock_logger = Mock() @@ -404,6 +409,7 @@ def test_logger_replay_then_new_logging(): operations={"op1": operation1, "op2": operation2}, service_client=Mock(), replay_status=ReplayStatus.REPLAY, + plugin_executor=PluginExecutor([]), ) log_info = LogInfo(execution_state, "parent123", "test_name", 5) mock_logger = Mock() diff --git a/packages/aws-durable-execution-sdk-python/tests/operation/callback_test.py b/packages/aws-durable-execution-sdk-python/tests/operation/callback_test.py index 
334e276e..71a194c6 100644 --- a/packages/aws-durable-execution-sdk-python/tests/operation/callback_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/operation/callback_test.py @@ -70,7 +70,9 @@ def test_create_callback_handler_new_operation_with_config(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback1", None, "test_callback"), + operation_identifier=OperationIdentifier( + "callback1", OperationSubType.CALLBACK, None, "test_callback" + ), config=config, ) @@ -110,7 +112,9 @@ def test_create_callback_handler_new_operation_without_config(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback2", None), + operation_identifier=OperationIdentifier( + "callback2", OperationSubType.CALLBACK, None + ), config=None, ) @@ -144,7 +148,9 @@ def test_create_callback_handler_existing_started_operation(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback3", None), + operation_identifier=OperationIdentifier( + "callback3", OperationSubType.CALLBACK, None + ), config=None, ) @@ -171,7 +177,9 @@ def test_create_callback_handler_existing_failed_operation(): # Should return callback_id without raising callback_id = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback4", None), + operation_identifier=OperationIdentifier( + "callback4", OperationSubType.CALLBACK, None + ), config=None, ) @@ -194,7 +202,9 @@ def test_create_callback_handler_existing_started_missing_callback_details(): with pytest.raises(CallbackError, match="Missing callback details"): create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback5", None), + operation_identifier=OperationIdentifier( + "callback5", OperationSubType.CALLBACK, None + ), config=None, ) @@ -216,7 +226,9 @@ def test_create_callback_handler_new_operation_missing_callback_details_after_ch with 
pytest.raises(CallbackError, match="Missing callback details"): create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback6", None), + operation_identifier=OperationIdentifier( + "callback6", OperationSubType.CALLBACK, None + ), config=None, ) @@ -236,7 +248,9 @@ def test_create_callback_handler_existing_timed_out_operation(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_timed_out", None), + operation_identifier=OperationIdentifier( + "callback_timed_out", OperationSubType.CALLBACK, None + ), config=None, ) @@ -260,7 +274,7 @@ def test_create_callback_handler_existing_timed_out_missing_callback_details(): create_callback_handler( state=mock_state, operation_identifier=OperationIdentifier( - "callback_timed_out_no_details", None + "callback_timed_out_no_details", OperationSubType.CALLBACK, None ), config=None, ) @@ -347,7 +361,9 @@ def test_create_callback_handler_with_none_operation_in_result(): with pytest.raises(CallbackError, match="Missing callback details"): create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("none_operation", None), + operation_identifier=OperationIdentifier( + "none_operation", OperationSubType.CALLBACK, None + ), config=None, ) @@ -470,7 +486,9 @@ def test_create_callback_handler_existing_succeeded_operation(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_succeeded", None), + operation_identifier=OperationIdentifier( + "callback_succeeded", OperationSubType.CALLBACK, None + ), config=None, ) @@ -494,7 +512,7 @@ def test_create_callback_handler_existing_succeeded_missing_callback_details(): create_callback_handler( state=mock_state, operation_identifier=OperationIdentifier( - "callback_succeeded_no_details", None + "callback_succeeded_no_details", OperationSubType.CALLBACK, None ), config=None, ) @@ -521,7 +539,9 @@ def 
test_create_callback_handler_config_with_zero_timeouts(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_zero", None), + operation_identifier=OperationIdentifier( + "callback_zero", OperationSubType.CALLBACK, None + ), config=config, ) @@ -564,7 +584,9 @@ def test_create_callback_handler_config_with_large_timeouts(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_large", None), + operation_identifier=OperationIdentifier( + "callback_large", OperationSubType.CALLBACK, None + ), config=config, ) @@ -602,7 +624,7 @@ def test_create_callback_handler_empty_operation_id(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("", None), + operation_identifier=OperationIdentifier("", OperationSubType.CALLBACK, None), config=None, ) @@ -802,7 +824,9 @@ def test_callback_lifecycle_complete_flow(): ) callback_id = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("lifecycle_callback", None), + operation_identifier=OperationIdentifier( + "lifecycle_callback", OperationSubType.CALLBACK, None + ), config=config, ) @@ -844,12 +868,16 @@ def test_callback_retry_scenario(): callback_id_1 = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("retry_callback", None), + operation_identifier=OperationIdentifier( + "retry_callback", OperationSubType.CALLBACK, None + ), config=None, ) callback_id_2 = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("retry_callback", None), + operation_identifier=OperationIdentifier( + "retry_callback", OperationSubType.CALLBACK, None + ), config=None, ) @@ -883,7 +911,7 @@ def test_callback_timeout_configuration(): callback_id = create_callback_handler( state=mock_state, operation_identifier=OperationIdentifier( - f"timeout_callback_{timeout_seconds}", None + 
f"timeout_callback_{timeout_seconds}", OperationSubType.CALLBACK, None ), config=config, ) @@ -908,7 +936,9 @@ def test_callback_error_propagation(): # Should return callback_id without raising callback_id = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("error_callback", None), + operation_identifier=OperationIdentifier( + "error_callback", OperationSubType.CALLBACK, None + ), config=None, ) assert callback_id == "failed_cb" @@ -978,7 +1008,9 @@ def test_callback_state_consistency(): callback_id_1 = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("consistent_callback", None), + operation_identifier=OperationIdentifier( + "consistent_callback", OperationSubType.CALLBACK, None + ), config=None, ) @@ -989,7 +1021,9 @@ def test_callback_state_consistency(): callback_id_2 = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("consistent_callback", None), + operation_identifier=OperationIdentifier( + "consistent_callback", OperationSubType.CALLBACK, None + ), config=None, ) @@ -1050,12 +1084,14 @@ def test_callback_operation_update_creation(mock_operation_update): create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("update_test", None), + operation_identifier=OperationIdentifier( + "update_test", OperationSubType.CALLBACK, None + ), config=config, ) mock_operation_update.create_callback.assert_called_once_with( - identifier=OperationIdentifier("update_test", None), + identifier=OperationIdentifier("update_test", OperationSubType.CALLBACK, None), callback_options=CallbackOptions( timeout_seconds=600, heartbeat_timeout_seconds=120 ), @@ -1084,7 +1120,9 @@ def test_callback_immediate_response_get_checkpoint_result_called_twice(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_immediate_1", None), + operation_identifier=OperationIdentifier( + "callback_immediate_1", 
OperationSubType.CALLBACK, None + ), config=None, ) @@ -1112,7 +1150,9 @@ def test_callback_immediate_response_create_checkpoint_with_is_sync_true(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_immediate_2", None), + operation_identifier=OperationIdentifier( + "callback_immediate_2", OperationSubType.CALLBACK, None + ), config=None, ) @@ -1146,7 +1186,9 @@ def test_callback_immediate_response_immediate_success(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_immediate_3", None), + operation_identifier=OperationIdentifier( + "callback_immediate_3", OperationSubType.CALLBACK, None + ), config=None, ) @@ -1182,7 +1224,9 @@ def test_callback_immediate_response_immediate_failure_deferred(): # CRITICAL: Should return callback_id without raising result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_immediate_4", None), + operation_identifier=OperationIdentifier( + "callback_immediate_4", OperationSubType.CALLBACK, None + ), config=None, ) @@ -1292,7 +1336,9 @@ def test_callback_immediate_response_no_immediate_response(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_immediate_5", None), + operation_identifier=OperationIdentifier( + "callback_immediate_5", OperationSubType.CALLBACK, None + ), config=None, ) @@ -1325,7 +1371,9 @@ def test_callback_immediate_response_already_completed(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_immediate_6", None), + operation_identifier=OperationIdentifier( + "callback_immediate_6", OperationSubType.CALLBACK, None + ), config=None, ) @@ -1359,7 +1407,9 @@ def test_callback_immediate_response_already_failed(): # Should return callback_id without raising result = create_callback_handler( state=mock_state, - 
operation_identifier=OperationIdentifier("callback_immediate_7", None), + operation_identifier=OperationIdentifier( + "callback_immediate_7", OperationSubType.CALLBACK, None + ), config=None, ) @@ -1399,7 +1449,9 @@ def test_callback_deferred_error_handling_code_execution_between_create_and_resu # Step 1: create_callback() returns callback_id without raising callback_id = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_deferred_error", None), + operation_identifier=OperationIdentifier( + "callback_deferred_error", OperationSubType.CALLBACK, None + ), config=None, ) assert callback_id == "cb_deferred_error" @@ -1450,7 +1502,9 @@ def test_callback_immediate_response_with_config(): result = create_callback_handler( state=mock_state, - operation_identifier=OperationIdentifier("callback_with_config", None), + operation_identifier=OperationIdentifier( + "callback_with_config", OperationSubType.CALLBACK, None + ), config=config, ) @@ -1490,7 +1544,9 @@ def test_callback_returns_id_when_second_check_returns_started(): executor = CallbackOperationExecutor( state=mock_state, - operation_identifier=OperationIdentifier("callback-1", None, "test_callback"), + operation_identifier=OperationIdentifier( + "callback-1", OperationSubType.CALLBACK, None, "test_callback" + ), config=CallbackConfig(), ) callback_id = executor.process() @@ -1522,7 +1578,9 @@ def test_callback_returns_id_when_second_check_returns_started_duplicate(): executor = CallbackOperationExecutor( state=mock_state, - operation_identifier=OperationIdentifier("callback-1", None, "test_callback"), + operation_identifier=OperationIdentifier( + "callback-1", OperationSubType.CALLBACK, None, "test_callback" + ), config=CallbackConfig(), ) callback_id = executor.process() diff --git a/packages/aws-durable-execution-sdk-python/tests/operation/child_test.py b/packages/aws-durable-execution-sdk-python/tests/operation/child_test.py index 3915fd22..1f8617ca 100644 --- 
a/packages/aws-durable-execution-sdk-python/tests/operation/child_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/operation/child_test.py @@ -58,9 +58,15 @@ def test_child_handler_not_started( mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result mock_callable = Mock(return_value="fresh_result") + mock_state.wrap_user_function.return_value = mock_callable result = child_handler( - mock_callable, mock_state, OperationIdentifier("op1", None, "test_name"), config + mock_callable, + mock_state, + OperationIdentifier( + "op1", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), + config, ) assert result == "fresh_result" @@ -114,7 +120,12 @@ def test_child_handler_already_succeeded(): mock_callable = Mock() result = child_handler( - mock_callable, mock_state, OperationIdentifier("op2", None, "test_name"), None + mock_callable, + mock_state, + OperationIdentifier( + "op2", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), + None, ) assert result == "cached_result" @@ -138,7 +149,12 @@ def test_child_handler_already_succeeded_none_result(): mock_callable = Mock() result = child_handler( - mock_callable, mock_state, OperationIdentifier("op3", None, "test_name"), None + mock_callable, + mock_state, + OperationIdentifier( + "op3", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), + None, ) assert result is None @@ -167,7 +183,9 @@ def test_child_handler_already_failed(): child_handler( mock_callable, mock_state, - OperationIdentifier("op4", None, "test_name"), + OperationIdentifier( + "op4", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), None, ) @@ -207,9 +225,15 @@ def test_child_handler_already_started( mock_result.is_replay_children.return_value = False mock_state.get_checkpoint_result.return_value = mock_result mock_callable = Mock(return_value="started_result") + mock_state.wrap_user_function.return_value = mock_callable result = child_handler( - 
mock_callable, mock_state, OperationIdentifier("op5", None, "test_name"), config + mock_callable, + mock_state, + OperationIdentifier( + "op5", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), + config, ) assert result == "started_result" @@ -261,12 +285,15 @@ def test_child_handler_callable_exception( mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result mock_callable = Mock(side_effect=ValueError("Test error")) + mock_state.wrap_user_function.return_value = mock_callable with pytest.raises(CallableRuntimeError): child_handler( mock_callable, mock_state, - OperationIdentifier("op6", None, "test_name"), + OperationIdentifier( + "op6", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), config, ) @@ -315,12 +342,15 @@ def test_child_handler_error_wrapped(): mock_state.get_checkpoint_result.return_value = mock_result test_error = RuntimeError("Test error") mock_callable = Mock(side_effect=test_error) + mock_state.wrap_user_function.return_value = mock_callable with pytest.raises(CallableRuntimeError): child_handler( mock_callable, mock_state, - OperationIdentifier("op7", None, "test_name"), + OperationIdentifier( + "op7", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), None, ) @@ -347,12 +377,15 @@ def test_child_handler_invocation_error_reraised(): mock_state.get_checkpoint_result.return_value = mock_result test_error = InvocationError("Invocation failed") mock_callable = Mock(side_effect=test_error) + mock_state.wrap_user_function.return_value = mock_callable with pytest.raises(InvocationError, match="Invocation failed"): child_handler( mock_callable, mock_state, - OperationIdentifier("op7b", None, "test_name"), + OperationIdentifier( + "op7b", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), None, ) @@ -376,10 +409,16 @@ def test_child_handler_with_config(): mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result 
mock_callable = Mock(return_value="config_result") + mock_state.wrap_user_function.return_value = mock_callable config = ChildConfig() result = child_handler( - mock_callable, mock_state, OperationIdentifier("op8", None, "test_name"), config + mock_callable, + mock_state, + OperationIdentifier( + "op8", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), + config, ) assert result == "config_result" @@ -401,9 +440,15 @@ def test_child_handler_default_serialization(): mock_state.get_checkpoint_result.return_value = mock_result complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]} mock_callable = Mock(return_value=complex_result) + mock_state.wrap_user_function.return_value = mock_callable result = child_handler( - mock_callable, mock_state, OperationIdentifier("op9", None, "test_name"), None + mock_callable, + mock_state, + OperationIdentifier( + "op9", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), + None, ) assert result == complex_result @@ -430,12 +475,15 @@ def test_child_handler_custom_serdes_not_start() -> None: mock_state.get_checkpoint_result.return_value = mock_result complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]} mock_callable = Mock(return_value=complex_result) + mock_state.wrap_user_function.return_value = mock_callable child_config: ChildConfig = ChildConfig(serdes=CustomDictSerDes()) child_handler( mock_callable, mock_state, - OperationIdentifier("op9", None, "test_name"), + OperationIdentifier( + "op9", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), child_config, ) @@ -464,7 +512,9 @@ def test_child_handler_custom_serdes_already_succeeded() -> None: actual_result = child_handler( mock_callable, mock_state, - OperationIdentifier("op9", None, "test_name"), + OperationIdentifier( + "op9", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), child_config, ) @@ -497,6 +547,7 @@ def test_child_handler_large_payload_with_summary_generator() -> None: 
mock_state.get_checkpoint_result.return_value = mock_result large_result = "large" * 256 * 1024 mock_callable = Mock(return_value=large_result) + mock_state.wrap_user_function.return_value = mock_callable def my_summary(result: str) -> str: return "summary" @@ -508,7 +559,9 @@ def my_summary(result: str) -> str: actual_result = child_handler( mock_callable, mock_state, - OperationIdentifier("op9", None, "test_name"), + OperationIdentifier( + "op9", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), child_config, ) @@ -542,12 +595,15 @@ def test_child_handler_large_payload_without_summary_generator() -> None: mock_state.get_checkpoint_result.return_value = mock_result large_result = "large" * 256 * 1024 mock_callable = Mock(return_value=large_result) + mock_state.wrap_user_function.return_value = mock_callable child_config: ChildConfig = ChildConfig() actual_result = child_handler( mock_callable, mock_state, - OperationIdentifier("op9", None, "test_name"), + OperationIdentifier( + "op9", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), child_config, ) @@ -581,12 +637,15 @@ def test_child_handler_replay_children_mode() -> None: mock_state.get_checkpoint_result.return_value = mock_result complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]} mock_callable = Mock(return_value=complex_result) + mock_state.wrap_user_function.return_value = mock_callable child_config: ChildConfig = ChildConfig() actual_result = child_handler( mock_callable, mock_state, - OperationIdentifier("op9", None, "test_name"), + OperationIdentifier( + "op9", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), child_config, ) @@ -619,6 +678,7 @@ def test_small_payload_with_summary_generator(): # Small payload (< 256KB) small_result = "small_payload" mock_callable = Mock(return_value=small_result) + mock_state.wrap_user_function.return_value = mock_callable def my_summary(result: str) -> str: return "summary_of_small_payload" @@ -628,7 +688,9 @@ def 
my_summary(result: str) -> str: actual_result = child_handler( mock_callable, mock_state, - OperationIdentifier("op1", None, "test_name"), + OperationIdentifier( + "op1", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), child_config, ) @@ -666,13 +728,16 @@ def test_small_payload_without_summary_generator(): # Small payload (< 256KB); no summary_generator provided small_result = "small_payload" mock_callable = Mock(return_value=small_result) + mock_state.wrap_user_function.return_value = mock_callable child_config: ChildConfig[str] = ChildConfig[str]() actual_result = child_handler( mock_callable, mock_state, - OperationIdentifier("op1", None, "test_name"), + OperationIdentifier( + "op1", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), child_config, ) @@ -704,13 +769,16 @@ def test_child_handler_is_virtual_no_start(): mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result mock_callable = Mock(return_value="no_checkpoint_result") + mock_state.wrap_user_function.return_value = mock_callable config = ChildConfig(is_virtual=True) result = child_handler( mock_callable, mock_state, - OperationIdentifier("op1", None, "test_name"), + OperationIdentifier( + "op1", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), config, ) @@ -742,13 +810,16 @@ def test_child_handler_is_virtual_no_succeed(): mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result mock_callable = Mock(return_value="no_checkpoint_result") + mock_state.wrap_user_function.return_value = mock_callable config = ChildConfig(is_virtual=True) result = child_handler( mock_callable, mock_state, - OperationIdentifier("op2", None, "test_name"), + OperationIdentifier( + "op2", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), config, ) @@ -772,13 +843,16 @@ def test_child_handler_not_is_virtual_finish_mode(): mock_result.is_existent.return_value = False 
mock_state.get_checkpoint_result.return_value = mock_result mock_callable = Mock(return_value="checkpoint_result") + mock_state.wrap_user_function.return_value = mock_callable config = ChildConfig(is_virtual=False) result = child_handler( mock_callable, mock_state, - OperationIdentifier("op3", None, "test_name"), + OperationIdentifier( + "op3", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), config, ) @@ -821,6 +895,7 @@ def test_child_handler_is_virtual_with_exception(): mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result mock_callable = Mock(side_effect=ValueError("Test error")) + mock_state.wrap_user_function.return_value = mock_callable config = ChildConfig(is_virtual=True) @@ -828,7 +903,9 @@ def test_child_handler_is_virtual_with_exception(): child_handler( mock_callable, mock_state, - OperationIdentifier("op4", None, "test_name"), + OperationIdentifier( + "op4", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), config, ) @@ -850,6 +927,7 @@ def test_child_handler_not_is_virtual_with_exception(): mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result mock_callable = Mock(side_effect=ValueError("Test error")) + mock_state.wrap_user_function.return_value = mock_callable config = ChildConfig(is_virtual=False) @@ -857,7 +935,9 @@ def test_child_handler_not_is_virtual_with_exception(): child_handler( mock_callable, mock_state, - OperationIdentifier("op5", None, "test_name"), + OperationIdentifier( + "op5", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), config, ) @@ -892,6 +972,7 @@ def setup_mocks(): mock_result.is_existent.return_value = False mock_state.get_checkpoint_result.return_value = mock_result mock_callable = Mock(return_value="test_result") + mock_state.wrap_user_function.return_value = mock_callable return mock_state, mock_callable # is_virtual=False: 2 checkpoints @@ -901,7 +982,9 @@ def setup_mocks(): result1 
= child_handler( mock_callable1, mock_state1, - OperationIdentifier("op1", None, "test_name"), + OperationIdentifier( + "op1", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), config1, ) @@ -915,7 +998,9 @@ def setup_mocks(): result2 = child_handler( mock_callable2, mock_state2, - OperationIdentifier("op2", None, "test_name"), + OperationIdentifier( + "op2", OperationSubType.RUN_IN_CHILD_CONTEXT, None, "test_name" + ), config2, ) diff --git a/packages/aws-durable-execution-sdk-python/tests/operation/invoke_test.py b/packages/aws-durable-execution-sdk-python/tests/operation/invoke_test.py index 5bb98da2..c6a247f1 100644 --- a/packages/aws-durable-execution-sdk-python/tests/operation/invoke_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/operation/invoke_test.py @@ -22,6 +22,7 @@ OperationAction, OperationStatus, OperationType, + OperationSubType, ) from aws_durable_execution_sdk_python.operation.invoke import InvokeOperationExecutor from aws_durable_execution_sdk_python.state import CheckpointedResult, ExecutionState @@ -62,7 +63,9 @@ def test_invoke_handler_already_succeeded(): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke1", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke1", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=None, ) @@ -88,7 +91,9 @@ def test_invoke_handler_already_succeeded_none_result(): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke2", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke2", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=None, ) @@ -113,7 +118,9 @@ def test_invoke_handler_already_succeeded_no_chained_invoke_details(): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke3", None, "test_invoke"), + 
operation_identifier=OperationIdentifier( + "invoke3", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=None, ) @@ -145,7 +152,9 @@ def test_invoke_handler_already_terminated(kind: OperationStatus): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke4", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke4", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=None, ) @@ -172,7 +181,9 @@ def test_invoke_handler_already_timed_out(): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke5", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke5", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=None, ) @@ -199,7 +210,9 @@ def test_invoke_handler_already_started(status): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke6", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke6", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=None, ) @@ -226,7 +239,9 @@ def test_invoke_handler_already_started_with_timeout(status): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke7", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke7", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=config, ) @@ -255,7 +270,9 @@ def test_invoke_handler_new_operation(): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke8", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke8", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=config, ) @@ -292,7 +309,9 @@ def test_invoke_handler_new_operation_with_timeout(): function_name="test_function", 
payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke9", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke9", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=config, ) @@ -318,7 +337,9 @@ def test_invoke_handler_new_operation_no_timeout(): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke10", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke10", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=config, ) @@ -342,7 +363,9 @@ def test_invoke_handler_no_config(): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke11", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke11", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=None, ) @@ -378,7 +401,9 @@ def test_invoke_handler_custom_serdes(): function_name="test_function", payload={"key": "value", "number": 42, "list": [1, 2, 3]}, state=mock_state, - operation_identifier=OperationIdentifier("invoke12", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke12", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=config, ) @@ -410,7 +435,9 @@ def test_invoke_handler_custom_serdes_new_operation(): function_name="test_function", payload=complex_payload, state=mock_state, - operation_identifier=OperationIdentifier("invoke13", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke13", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=config, ) @@ -472,7 +499,9 @@ def test_invoke_handler_with_operation_name(status: OperationStatus): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke14", None, "named_invoke"), + operation_identifier=OperationIdentifier( + "invoke14", 
OperationSubType.CHAINED_INVOKE, None, "named_invoke" + ), config=None, ) @@ -497,7 +526,9 @@ def test_invoke_handler_without_operation_name(status: OperationStatus): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke15", None, None), + operation_identifier=OperationIdentifier( + "invoke15", OperationSubType.CHAINED_INVOKE, None, None + ), config=None, ) @@ -521,7 +552,9 @@ def test_invoke_handler_with_none_payload(): function_name="test_function", payload=None, state=mock_state, - operation_identifier=OperationIdentifier("invoke16", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke16", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=None, ) @@ -549,7 +582,9 @@ def test_invoke_handler_already_succeeded_with_none_payload(): function_name="test_function", payload=None, state=mock_state, - operation_identifier=OperationIdentifier("invoke17", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke17", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=None, ) @@ -586,7 +621,9 @@ def test_invoke_handler_suspend_does_not_raise(mock_suspend): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke18", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke18", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=None, ) @@ -614,7 +651,9 @@ def test_invoke_handler_with_tenant_id(): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke1", None, None), + operation_identifier=OperationIdentifier( + "invoke1", OperationSubType.CHAINED_INVOKE, None, None + ), config=config, ) @@ -647,7 +686,9 @@ def test_invoke_handler_without_tenant_id(): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke1", None, 
None), + operation_identifier=OperationIdentifier( + "invoke1", OperationSubType.CHAINED_INVOKE, None, None + ), config=config, ) @@ -678,7 +719,9 @@ def test_invoke_handler_default_config_no_tenant_id(): function_name="test_function", payload="test_input", state=mock_state, - operation_identifier=OperationIdentifier("invoke1", None, None), + operation_identifier=OperationIdentifier( + "invoke1", OperationSubType.CHAINED_INVOKE, None, None + ), config=None, ) @@ -712,7 +755,9 @@ def test_invoke_handler_defaults_to_json_serdes(): function_name="test_function", payload=payload, state=mock_state, - operation_identifier=OperationIdentifier("invoke_json", None, None), + operation_identifier=OperationIdentifier( + "invoke_json", OperationSubType.CHAINED_INVOKE, None, None + ), config=config, ) @@ -742,7 +787,9 @@ def test_invoke_handler_result_defaults_to_json_serdes(): function_name="test_function", payload={"input": "data"}, state=mock_state, - operation_identifier=OperationIdentifier("invoke_result_json", None, None), + operation_identifier=OperationIdentifier( + "invoke_result_json", OperationSubType.CHAINED_INVOKE, None, None + ), config=config, ) @@ -776,7 +823,10 @@ def test_invoke_immediate_response_get_checkpoint_result_called_twice(): payload="test_input", state=mock_state, operation_identifier=OperationIdentifier( - "invoke_immediate_1", None, "test_invoke" + "invoke_immediate_1", + OperationSubType.CHAINED_INVOKE, + None, + "test_invoke", ), config=None, ) @@ -806,7 +856,10 @@ def test_invoke_immediate_response_create_checkpoint_with_is_sync_true(): payload="test_input", state=mock_state, operation_identifier=OperationIdentifier( - "invoke_immediate_2", None, "test_invoke" + "invoke_immediate_2", + OperationSubType.CHAINED_INVOKE, + None, + "test_invoke", ), config=None, ) @@ -844,7 +897,7 @@ def test_invoke_immediate_response_immediate_success(): payload="test_input", state=mock_state, operation_identifier=OperationIdentifier( - "invoke_immediate_3", None, 
"test_invoke" + "invoke_immediate_3", OperationSubType.CHAINED_INVOKE, None, "test_invoke" ), config=None, ) @@ -878,7 +931,7 @@ def test_invoke_immediate_response_immediate_success_with_none_result(): payload="test_input", state=mock_state, operation_identifier=OperationIdentifier( - "invoke_immediate_4", None, "test_invoke" + "invoke_immediate_4", OperationSubType.CHAINED_INVOKE, None, "test_invoke" ), config=None, ) @@ -922,7 +975,10 @@ def test_invoke_immediate_response_immediate_failure(status: OperationStatus): payload="test_input", state=mock_state, operation_identifier=OperationIdentifier( - "invoke_immediate_5", None, "test_invoke" + "invoke_immediate_5", + OperationSubType.CHAINED_INVOKE, + None, + "test_invoke", ), config=None, ) @@ -958,7 +1014,10 @@ def test_invoke_immediate_response_no_immediate_response(): payload="test_input", state=mock_state, operation_identifier=OperationIdentifier( - "invoke_immediate_6", None, "test_invoke" + "invoke_immediate_6", + OperationSubType.CHAINED_INVOKE, + None, + "test_invoke", ), config=None, ) @@ -995,7 +1054,7 @@ def test_invoke_immediate_response_already_completed(): payload="test_input", state=mock_state, operation_identifier=OperationIdentifier( - "invoke_immediate_7", None, "test_invoke" + "invoke_immediate_7", OperationSubType.CHAINED_INVOKE, None, "test_invoke" ), config=None, ) @@ -1033,7 +1092,7 @@ def test_invoke_immediate_response_with_timeout_immediate_success(): payload="test_input", state=mock_state, operation_identifier=OperationIdentifier( - "invoke_immediate_8", None, "test_invoke" + "invoke_immediate_8", OperationSubType.CHAINED_INVOKE, None, "test_invoke" ), config=config, ) @@ -1070,7 +1129,10 @@ def test_invoke_immediate_response_with_timeout_no_immediate_response(): payload="test_input", state=mock_state, operation_identifier=OperationIdentifier( - "invoke_immediate_9", None, "test_invoke" + "invoke_immediate_9", + OperationSubType.CHAINED_INVOKE, + None, + "test_invoke", ), config=config, ) 
@@ -1105,7 +1167,7 @@ def test_invoke_immediate_response_with_custom_serdes(): payload={"key": "value", "number": 42, "list": [1, 2, 3]}, state=mock_state, operation_identifier=OperationIdentifier( - "invoke_immediate_10", None, "test_invoke" + "invoke_immediate_10", OperationSubType.CHAINED_INVOKE, None, "test_invoke" ), config=config, ) @@ -1139,7 +1201,9 @@ def test_invoke_suspends_when_second_check_returns_started(): executor = InvokeOperationExecutor( state=mock_state, - operation_identifier=OperationIdentifier("invoke-1", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke-1", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), function_name="my-function", payload={"data": "test"}, config=InvokeConfig(), @@ -1175,7 +1239,9 @@ def test_invoke_suspends_when_second_check_returns_started_duplicate(): function_name="my-function", payload={"data": "test"}, state=mock_state, - operation_identifier=OperationIdentifier("invoke-1", None, "test_invoke"), + operation_identifier=OperationIdentifier( + "invoke-1", OperationSubType.CHAINED_INVOKE, None, "test_invoke" + ), config=InvokeConfig(), ) diff --git a/packages/aws-durable-execution-sdk-python/tests/operation/map_test.py b/packages/aws-durable-execution-sdk-python/tests/operation/map_test.py index e856cda2..8f4e0c5f 100644 --- a/packages/aws-durable-execution-sdk-python/tests/operation/map_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/operation/map_test.py @@ -167,7 +167,9 @@ def get_checkpoint_result(self, operation_id): execution_state = MockExecutionState() config = MapConfig() - operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) result = map_handler( items, @@ -198,7 +200,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", 
"parent", "test_map") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) # Since MapConfig() is called in map_handler when config is None, # we need to provide a valid config to avoid the NameError @@ -321,7 +325,9 @@ def get_checkpoint_result(self, operation_id): execution_state = MockExecutionState() config = MapConfig() - operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) result = map_handler( items, @@ -367,7 +373,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) result = map_handler( items, @@ -404,7 +412,9 @@ def callable_func(ctx, item, idx, items): executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context class MockExecutionState: def get_checkpoint_result(self, operation_id): @@ -414,7 +424,9 @@ def get_checkpoint_result(self, operation_id): execution_state = MockExecutionState() config = MapConfig(serdes=CustomStrSerDes()) - operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) result = map_handler( items, @@ -452,7 +464,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_map") 
+ operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) # Call map_handler map_handler( @@ -510,7 +524,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) # Call map_handler with None config (should use default) map_handler( @@ -571,7 +587,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) executor_context = Mock() executor_context._create_step_id_for_logical_step = Mock( # noqa: SLF001 @@ -614,7 +632,9 @@ def get_checkpoint_result(self, operation_id): execution_state = MockExecutionState() config = MapConfig() - operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) # Mock map context map_context = Mock() @@ -675,7 +695,9 @@ def get_checkpoint_result(self, operation_id): execution_state = MockExecutionState() config = MapConfig() - operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) # Mock map context map_context = Mock() @@ -749,7 +771,9 @@ def test_func(ctx, item, idx, items): items = ["a", "b"] config = MapConfig() - operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) # Track whether we're in first or second execution execution_count = 0 @@ -845,6 +869,7 @@ 
def get_checkpoint(op_id): mock_state.durable_execution_arn = "arn:test" mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() + mock_state.wrap_user_function = lambda func, *args, **kwargs: func context_map = {} @@ -907,6 +932,7 @@ def get_checkpoint(op_id): mock_state.durable_execution_arn = "arn:test" mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() + mock_state.wrap_user_function = lambda func, *args, **kwargs: func context_map = {} @@ -955,8 +981,12 @@ def get_checkpoint_result(self, operation_id): execution_state = MockExecutionState() map_context = Mock() map_context._create_step_id_for_logical_step = Mock(side_effect=["1", "2", "3"]) # noqa SLF001 - map_context.create_child_context = Mock(return_value=Mock()) - operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + map_context.create_child_context = Mock(return_value=child_context) + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.MAP, "parent", "test_map" + ) # Execute map result = map_handler( @@ -1010,6 +1040,7 @@ def get_checkpoint(op_id): mock_state.durable_execution_arn = "arn:test" mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() + mock_state.wrap_user_function = lambda func, *args, **kwargs: func context_map = {} @@ -1068,6 +1099,7 @@ def get_checkpoint(op_id): mock_state.durable_execution_arn = "arn:test" mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() + mock_state.wrap_user_function = lambda func, *args, **kwargs: func context_map = {} @@ -1133,6 +1165,7 @@ def get_checkpoint(op_id): mock_state.durable_execution_arn = "arn:test" mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = 
Mock() + mock_state.wrap_user_function = lambda func, *args, **kwargs: func context_map = {} @@ -1184,6 +1217,7 @@ def map_func(ctx, item, idx, items): mock_state.get_checkpoint_result = Mock(return_value=parent_checkpoint) mock_state.create_checkpoint = Mock() + mock_state.wrap_user_function = lambda func, *args, **kwargs: func context = create_test_context(state=mock_state) diff --git a/packages/aws-durable-execution-sdk-python/tests/operation/parallel_test.py b/packages/aws-durable-execution-sdk-python/tests/operation/parallel_test.py index 23dd978d..0a7da40a 100644 --- a/packages/aws-durable-execution-sdk-python/tests/operation/parallel_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/operation/parallel_test.py @@ -193,7 +193,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) # Mock the run_in_child_context function def mock_run_in_child_context(callable_func, name, child_config): @@ -231,7 +233,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) def mock_run_in_child_context(callable_func, name, child_config): return callable_func("mock-context") @@ -269,7 +273,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) executor_context = Mock() executor_context._create_step_id_for_logical_step = 
lambda *args: "1" # noqa SLF001 @@ -307,7 +313,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 @@ -406,11 +414,15 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 - executor_context.create_child_context = lambda *args, **kwargs: Mock() + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + executor_context.create_child_context = lambda *args, **kwargs: child_context result = parallel_handler( callables, @@ -442,7 +454,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) executor_context = Mock() executor_context._create_step_id_for_logical_step = Mock(return_value="1") # noqa SLF001 @@ -496,7 +510,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) executor_context = Mock() 
executor_context._create_step_id_for_logical_step = Mock(side_effect=["1", "2"]) # noqa SLF001 @@ -540,7 +556,9 @@ def get_checkpoint_result(self, operation_id): return mock_result execution_state = MockExecutionState() - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) executor_context = Mock() executor_context._create_step_id_for_logical_step = Mock( # noqa: SLF001 @@ -586,7 +604,9 @@ def get_checkpoint_result(self, operation_id): execution_state = MockExecutionState() config = ParallelConfig() - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) # Mock parallel context parallel_context = Mock() @@ -643,7 +663,9 @@ def get_checkpoint_result(self, operation_id): execution_state = MockExecutionState() config = ParallelConfig() - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) # Mock parallel context parallel_context = Mock() @@ -714,7 +736,9 @@ def task2(ctx): callables = [task1, task2] config = ParallelConfig() - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) # Track whether we're in first or second execution execution_count = 0 @@ -810,6 +834,7 @@ def get_checkpoint(op_id): mock_state.durable_execution_arn = "arn:test" mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() + mock_state.wrap_user_function = lambda func, *args, **kwargs: func context_map = {} @@ -871,6 +896,7 @@ def get_checkpoint(op_id): 
mock_state.durable_execution_arn = "arn:test" mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() + mock_state.wrap_user_function = lambda func, *args, **kwargs: func context_map = {} @@ -927,8 +953,12 @@ def get_checkpoint_result(self, operation_id): parallel_context._create_step_id_for_logical_step = Mock( # noqa SLF001 side_effect=["1", "2", "3"] ) - parallel_context.create_child_context = Mock(return_value=Mock()) - operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + child_context = Mock() + child_context.state.wrap_user_function = lambda func, *args, **kwargs: func + parallel_context.create_child_context = Mock(return_value=child_context) + operation_identifier = OperationIdentifier( + "test_op", OperationSubType.PARALLEL, "parent", "test_parallel" + ) # Execute parallel result = parallel_handler( @@ -986,6 +1016,7 @@ def get_checkpoint(op_id): mock_state.durable_execution_arn = "arn:test" mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() + mock_state.wrap_user_function = lambda func, *args, **kwargs: func context_map = {} @@ -1044,6 +1075,7 @@ def get_checkpoint(op_id): mock_state.durable_execution_arn = "arn:test" mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() + mock_state.wrap_user_function = lambda func, *args, **kwargs: func context_map = {} @@ -1109,6 +1141,7 @@ def get_checkpoint(op_id): mock_state.durable_execution_arn = "arn:test" mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) mock_state.create_checkpoint = Mock() + mock_state.wrap_user_function = lambda func, *args, **kwargs: func context_map = {} diff --git a/packages/aws-durable-execution-sdk-python/tests/operation/step_test.py b/packages/aws-durable-execution-sdk-python/tests/operation/step_test.py index 75ed7685..e7195512 100644 --- 
a/packages/aws-durable-execution-sdk-python/tests/operation/step_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/operation/step_test.py @@ -68,7 +68,7 @@ def test_step_handler_already_succeeded(): result = step_handler( mock_callable, mock_state, - OperationIdentifier("step1", None, "test_step"), + OperationIdentifier("step1", OperationSubType.STEP, None, "test_step"), None, mock_logger, ) @@ -97,7 +97,7 @@ def test_step_handler_already_succeeded_none_result(): result = step_handler( mock_callable, mock_state, - OperationIdentifier("step2", None, "test_step"), + OperationIdentifier("step2", OperationSubType.STEP, None, "test_step"), None, mock_logger, ) @@ -129,7 +129,7 @@ def test_step_handler_already_failed(): step_handler( mock_callable, mock_state, - OperationIdentifier("step3", None, "test_step"), + OperationIdentifier("step3", OperationSubType.STEP, None, "test_step"), None, mock_logger, ) @@ -158,7 +158,7 @@ def test_step_handler_started_at_most_once(): step_handler( mock_callable, mock_state, - OperationIdentifier("step4", None, "test_step"), + OperationIdentifier("step4", OperationSubType.STEP, None, "test_step"), config, mock_logger, ) @@ -182,12 +182,13 @@ def test_step_handler_started_at_least_once(): config = StepConfig(step_semantics=StepSemantics.AT_LEAST_ONCE_PER_RETRY) mock_callable = Mock(return_value="success_result") + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) step_handler( mock_callable, mock_state, - OperationIdentifier("step5", None, "test_step"), + OperationIdentifier("step5", OperationSubType.STEP, None, "test_step"), config, mock_logger, ) @@ -202,13 +203,14 @@ def test_step_handler_success_at_least_once(): config = StepConfig(step_semantics=StepSemantics.AT_LEAST_ONCE_PER_RETRY) mock_callable = Mock(return_value="success_result") + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger 
result = step_handler( mock_callable, mock_state, - OperationIdentifier("step6", None, "test_step"), + OperationIdentifier("step6", OperationSubType.STEP, None, "test_step"), config, mock_logger, ) @@ -253,13 +255,14 @@ def test_step_handler_success_at_most_once(): config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) mock_callable = Mock(return_value="success_result") + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger result = step_handler( mock_callable, mock_state, - OperationIdentifier("step7", None, "test_step"), + OperationIdentifier("step7", OperationSubType.STEP, None, "test_step"), config, mock_logger, ) @@ -294,6 +297,7 @@ def test_step_handler_non_retriable_execution_error(): mock_state.durable_execution_arn = "test_arn" mock_callable = Mock(side_effect=ExecutionError("Do Not Retry")) + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger @@ -301,7 +305,7 @@ def test_step_handler_non_retriable_execution_error(): step_handler( mock_callable, mock_state, - OperationIdentifier("step8", None, "test_step"), + OperationIdentifier("step8", OperationSubType.STEP, None, "test_step"), None, mock_logger, ) @@ -319,6 +323,7 @@ def test_step_handler_retry_success(): ) config = StepConfig(retry_strategy=mock_retry_strategy) mock_callable = Mock(side_effect=RuntimeError("Test error")) + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger @@ -326,7 +331,7 @@ def test_step_handler_retry_success(): step_handler( mock_callable, mock_state, - OperationIdentifier("step9", None, "test_step"), + OperationIdentifier("step9", OperationSubType.STEP, None, "test_step"), config, mock_logger, ) @@ -362,6 +367,7 @@ def test_step_handler_retry_exhausted(): ) config = 
StepConfig(retry_strategy=mock_retry_strategy) mock_callable = Mock(side_effect=RuntimeError("Test error")) + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger @@ -369,7 +375,7 @@ def test_step_handler_retry_exhausted(): step_handler( mock_callable, mock_state, - OperationIdentifier("step10", None, "test_step"), + OperationIdentifier("step10", OperationSubType.STEP, None, "test_step"), config, mock_logger, ) @@ -406,6 +412,7 @@ def test_step_handler_retry_interrupted_error(): config = StepConfig(retry_strategy=mock_retry_strategy) interrupted_error = StepInterruptedError("Step interrupted") mock_callable = Mock(side_effect=interrupted_error) + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger @@ -413,7 +420,7 @@ def test_step_handler_retry_interrupted_error(): step_handler( mock_callable, mock_state, - OperationIdentifier("step11", None, "test_step"), + OperationIdentifier("step11", OperationSubType.STEP, None, "test_step"), config, mock_logger, ) @@ -451,7 +458,7 @@ def test_step_handler_retry_with_existing_attempts(): step_handler( mock_callable, mock_state, - OperationIdentifier("step12", None, "test_step"), + OperationIdentifier("step12", OperationSubType.STEP, None, "test_step"), config, mock_logger, ) @@ -487,7 +494,7 @@ def test_step_handler_pending_without_existing_attempts(): step_handler( mock_callable, mock_state, - OperationIdentifier("step12", None, "test_step"), + OperationIdentifier("step12", OperationSubType.STEP, None, "test_step"), config, mock_logger, ) @@ -519,6 +526,7 @@ def test_step_handler_retry_handler_no_exception(mock_retry_handler): mock_retry_handler.return_value = None mock_callable = Mock(side_effect=RuntimeError("Test error")) + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) 
mock_logger.with_log_info.return_value = mock_logger @@ -529,7 +537,7 @@ def test_step_handler_retry_handler_no_exception(mock_retry_handler): step_handler( mock_callable, mock_state, - OperationIdentifier("step13", None, "test_step"), + OperationIdentifier("step13", OperationSubType.STEP, None, "test_step"), None, mock_logger, ) @@ -548,13 +556,14 @@ def test_step_handler_custom_serdes_success(): ) complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]} mock_callable = Mock(return_value=complex_result) + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger step_handler( mock_callable, mock_state, - OperationIdentifier("step6", None, "test_step"), + OperationIdentifier("step6", OperationSubType.STEP, None, "test_step"), config, mock_logger, ) @@ -588,7 +597,7 @@ def test_step_handler_custom_serdes_already_succeeded(): result = step_handler( mock_callable, mock_state, - OperationIdentifier("step1", None, "test_step"), + OperationIdentifier("step1", OperationSubType.STEP, None, "test_step"), StepConfig(serdes=CustomDictSerDes()), mock_logger, ) @@ -618,13 +627,16 @@ def test_step_immediate_response_get_checkpoint_called_twice(): config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) mock_callable = Mock(return_value="success_result") + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger result = step_handler( mock_callable, mock_state, - OperationIdentifier("step_immediate_1", None, "test_step"), + OperationIdentifier( + "step_immediate_1", OperationSubType.STEP, None, "test_step" + ), config, mock_logger, ) @@ -652,13 +664,16 @@ def test_step_immediate_response_create_checkpoint_sync_at_most_once(): config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) mock_callable = Mock(return_value="success_result") + 
mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger step_handler( mock_callable, mock_state, - OperationIdentifier("step_immediate_2", None, "test_step"), + OperationIdentifier( + "step_immediate_2", OperationSubType.STEP, None, "test_step" + ), config, mock_logger, ) @@ -679,13 +694,16 @@ def test_step_immediate_response_create_checkpoint_async_at_least_once(): config = StepConfig(step_semantics=StepSemantics.AT_LEAST_ONCE_PER_RETRY) mock_callable = Mock(return_value="success_result") + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger step_handler( mock_callable, mock_state, - OperationIdentifier("step_immediate_3", None, "test_step"), + OperationIdentifier( + "step_immediate_3", OperationSubType.STEP, None, "test_step" + ), config, mock_logger, ) @@ -718,13 +736,16 @@ def test_step_immediate_response_immediate_success(): config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) mock_callable = Mock(return_value="immediate_success_result") + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger result = step_handler( mock_callable, mock_state, - OperationIdentifier("step_immediate_4", None, "test_step"), + OperationIdentifier( + "step_immediate_4", OperationSubType.STEP, None, "test_step" + ), config, mock_logger, ) @@ -756,6 +777,7 @@ def test_step_immediate_response_immediate_failure(): config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) # Make the step function raise an error mock_callable = Mock(side_effect=RuntimeError("Step execution error")) + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger @@ -773,7 +795,9 @@ def 
test_step_immediate_response_immediate_failure(): step_handler( mock_callable, mock_state, - OperationIdentifier("step_immediate_5", None, "test_step"), + OperationIdentifier( + "step_immediate_5", OperationSubType.STEP, None, "test_step" + ), config, mock_logger, ) @@ -802,13 +826,16 @@ def test_step_immediate_response_no_immediate_response(): config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) mock_callable = Mock(return_value="normal_execution_result") + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger result = step_handler( mock_callable, mock_state, - OperationIdentifier("step_immediate_6", None, "test_step"), + OperationIdentifier( + "step_immediate_6", OperationSubType.STEP, None, "test_step" + ), config, mock_logger, ) @@ -842,7 +869,9 @@ def test_step_immediate_response_already_completed(): result = step_handler( mock_callable, mock_state, - OperationIdentifier("step_immediate_7", None, "test_step"), + OperationIdentifier( + "step_immediate_7", OperationSubType.STEP, None, "test_step" + ), config, mock_logger, ) @@ -875,6 +904,7 @@ def test_step_executes_function_when_second_check_returns_started(): mock_state.get_checkpoint_result.side_effect = [not_found, started] mock_step_function = Mock(return_value="result") + mock_state.wrap_user_function.return_value = mock_step_function mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger @@ -882,7 +912,9 @@ def test_step_executes_function_when_second_check_returns_started(): func=mock_step_function, config=StepConfig(step_semantics=StepSemantics.AT_LEAST_ONCE_PER_RETRY), state=mock_state, - operation_identifier=OperationIdentifier("step-1", None, "test_step"), + operation_identifier=OperationIdentifier( + "step-1", OperationSubType.STEP, None, "test_step" + ), context_logger=mock_logger, ) result = executor.process() @@ -922,13 +954,14 @@ def 
test_step_creates_start_checkpoint_when_status_is_ready(): config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) mock_callable = Mock(return_value="ready_step_result") + mock_state.wrap_user_function.return_value = mock_callable mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger result = step_handler( mock_callable, mock_state, - OperationIdentifier("step_ready_1", None, "test_step"), + OperationIdentifier("step_ready_1", OperationSubType.STEP, None, "test_step"), config, mock_logger, ) diff --git a/packages/aws-durable-execution-sdk-python/tests/operation/wait_for_condition_test.py b/packages/aws-durable-execution-sdk-python/tests/operation/wait_for_condition_test.py index 7ed8dd18..da35a6fe 100644 --- a/packages/aws-durable-execution-sdk-python/tests/operation/wait_for_condition_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/operation/wait_for_condition_test.py @@ -19,6 +19,7 @@ OperationStatus, OperationType, StepDetails, + OperationSubType, ) from aws_durable_execution_sdk_python.logger import Logger, LogInfo from aws_durable_execution_sdk_python.operation.wait_for_condition import ( @@ -59,11 +60,15 @@ def test_wait_for_condition_first_execution_condition_met(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + def wait_strategy(state, attempt): return WaitForConditionDecision.stop_polling() @@ -92,11 +97,15 @@ def test_wait_for_condition_first_execution_condition_not_met(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, 
"test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + def wait_strategy(state, attempt): return WaitForConditionDecision.continue_waiting(Duration.from_seconds(30)) @@ -128,11 +137,15 @@ def test_wait_for_condition_already_succeeded(): mock_state.get_checkpoint_result.return_value = mock_result mock_logger = Mock(spec=Logger) - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -164,11 +177,15 @@ def test_wait_for_condition_already_succeeded_none_result(): mock_state.get_checkpoint_result.return_value = mock_result mock_logger = Mock(spec=Logger) - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -201,11 +218,15 @@ def test_wait_for_condition_already_failed(): mock_state.get_checkpoint_result.return_value = mock_result mock_logger = Mock(spec=Logger) - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -237,11 +258,15 @@ def test_wait_for_condition_retry_with_state(): mock_logger = Mock(spec=Logger) 
mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -275,11 +300,15 @@ def test_wait_for_condition_retry_without_state(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -312,11 +341,15 @@ def test_wait_for_condition_retry_invalid_json_state(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -344,12 +377,16 @@ def test_wait_for_condition_check_function_exception(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): msg = "Test error" raise ValueError(msg) + mock_state.wrap_user_function.return_value = check_func + config = 
WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -378,7 +415,9 @@ def test_wait_for_condition_check_context(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) captured_context = None @@ -387,6 +426,8 @@ def check_func(state, context): captured_context = context return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -415,11 +456,15 @@ def test_wait_for_condition_delay_seconds_none(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + def wait_strategy(state, attempt): return WaitForConditionDecision(should_continue=True, delay=Duration()) @@ -455,11 +500,15 @@ def test_wait_for_condition_no_operation_in_checkpoint(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -504,11 +553,15 @@ def test_wait_for_condition_operation_no_step_details(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, 
"test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -536,11 +589,15 @@ def test_wait_for_condition_custom_delay_seconds(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + def wait_strategy(state, attempt): return WaitForConditionDecision( should_continue=True, delay=Duration.from_minutes(1) @@ -574,11 +631,15 @@ def test_wait_for_condition_attempt_number_passed_to_strategy(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + captured_attempt = None def wait_strategy(state, attempt): @@ -616,11 +677,15 @@ def test_wait_for_condition_attempt_sequence_is_monotonic(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + captured_attempts = [] def wait_strategy(state, attempt): @@ -730,11 +795,15 @@ def test_wait_for_condition_state_passed_to_strategy(): mock_logger = Mock(spec=Logger) 
mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state * 2 + mock_state.wrap_user_function.return_value = check_func + captured_state = None def wait_strategy(state, attempt): @@ -766,11 +835,15 @@ def test_wait_for_condition_logger_with_log_info(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -801,11 +874,15 @@ def test_wait_for_condition_zero_delay_seconds(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + def wait_strategy(state, attempt): return WaitForConditionDecision( should_continue=True, delay=Duration.from_seconds(0) @@ -833,12 +910,16 @@ def test_wait_for_condition_custom_serdes_first_execution_condition_met(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]} def check_func(state, context): return complex_result + mock_state.wrap_user_function.return_value = check_func + def wait_strategy(state, 
attempt): return WaitForConditionDecision.stop_polling() @@ -877,7 +958,9 @@ def test_wait_for_condition_custom_serdes_already_succeeded(): mock_state.get_checkpoint_result.return_value = mock_result mock_logger = Mock(spec=Logger) - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 @@ -919,7 +1002,9 @@ def test_wait_for_condition_pending(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): msg = "Should not be called" @@ -960,7 +1045,9 @@ def test_wait_for_condition_pending_without_next_attempt(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): msg = "Should not be called" @@ -999,11 +1086,15 @@ def test_wait_for_condition_checkpoint_called_once_with_is_sync_false(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -1040,7 +1131,9 @@ def test_wait_for_condition_immediate_success_without_executing_check(): mock_state.get_checkpoint_result.return_value = mock_result mock_logger = Mock(spec=Logger) - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = 
OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) # Check function should NOT be called def check_func(state, context): @@ -1082,7 +1175,9 @@ def test_wait_for_condition_immediate_failure_without_executing_check(): mock_state.get_checkpoint_result.return_value = mock_result mock_logger = Mock(spec=Logger) - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) # Check function should NOT be called def check_func(state, context): @@ -1129,7 +1224,9 @@ def test_wait_for_condition_pending_suspends_without_executing_check(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) # Check function should NOT be called def check_func(state, context): @@ -1168,7 +1265,9 @@ def test_wait_for_condition_no_checkpoint_executes_check_function(): mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) check_called = False @@ -1177,6 +1276,8 @@ def check_func(state, context): check_called = True return state + 1 + mock_state.wrap_user_function.return_value = check_func + config = WaitForConditionConfig( initial_state=5, wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(), @@ -1212,7 +1313,9 @@ def test_wait_for_condition_already_completed_no_checkpoint_created(): mock_state.get_checkpoint_result.return_value = mock_result mock_logger = Mock(spec=Logger) - op_id = OperationIdentifier("op1", None, "test_wait") + op_id = OperationIdentifier( + "op1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wait" + ) def check_func(state, context): return state + 1 @@ -1255,6 
+1358,7 @@ def test_wait_for_condition_executes_check_when_checkpoint_not_terminal(): mock_check_function = Mock(return_value="final_state") mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger + mock_state.wrap_user_function.return_value = mock_check_function def mock_wait_strategy(state, attempt): return WaitForConditionDecision( @@ -1268,7 +1372,9 @@ def mock_wait_strategy(state, attempt): wait_strategy=mock_wait_strategy, ), state=mock_state, - operation_identifier=OperationIdentifier("wfc-1", None, "test_wfc"), + operation_identifier=OperationIdentifier( + "wfc-1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wfc" + ), context_logger=mock_logger, ) result = executor.process() @@ -1296,6 +1402,7 @@ def test_wait_for_condition_executes_check_when_checkpoint_not_terminal_duplicat ) mock_check_function = Mock(return_value="final_state") + mock_state.wrap_user_function.return_value = mock_check_function mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger @@ -1309,7 +1416,9 @@ def mock_wait_strategy(state, attempt): wait_strategy=mock_wait_strategy, ), state=mock_state, - operation_identifier=OperationIdentifier("wfc-1", None, "test_wfc"), + operation_identifier=OperationIdentifier( + "wfc-1", OperationSubType.WAIT_FOR_CONDITION, None, "test_wfc" + ), context_logger=mock_logger, ) result = executor.process() diff --git a/packages/aws-durable-execution-sdk-python/tests/operation/wait_test.py b/packages/aws-durable-execution-sdk-python/tests/operation/wait_test.py index ca3083e5..07274826 100644 --- a/packages/aws-durable-execution-sdk-python/tests/operation/wait_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/operation/wait_test.py @@ -40,7 +40,7 @@ def test_wait_handler_already_completed(): wait_handler( seconds=10, state=mock_state, - operation_identifier=OperationIdentifier("wait1", None), + operation_identifier=OperationIdentifier("wait1", OperationSubType.WAIT, None), ) 
mock_state.get_checkpoint_result.assert_called_once_with("wait1") @@ -67,7 +67,9 @@ def test_wait_handler_not_completed(): wait_handler( seconds=30, state=mock_state, - operation_identifier=OperationIdentifier("wait2", None), + operation_identifier=OperationIdentifier( + "wait2", OperationSubType.WAIT, None + ), ) # Should be called twice: once before checkpoint, once after to check for immediate response @@ -105,7 +107,9 @@ def test_wait_handler_with_none_name(): with pytest.raises(SuspendExecution, match="Wait for 5 seconds"): wait_handler( state=mock_state, - operation_identifier=OperationIdentifier("wait3", None), + operation_identifier=OperationIdentifier( + "wait3", OperationSubType.WAIT, None + ), seconds=5, ) @@ -136,7 +140,9 @@ def test_wait_handler_with_existent(): with pytest.raises(SuspendExecution, match="Wait for 5 seconds"): wait_handler( state=mock_state, - operation_identifier=OperationIdentifier("wait4", None), + operation_identifier=OperationIdentifier( + "wait4", OperationSubType.WAIT, None + ), seconds=5, ) @@ -173,7 +179,9 @@ def test_wait_status_evaluation_after_checkpoint(): executor = WaitOperationExecutor( seconds=30, state=mock_state, - operation_identifier=OperationIdentifier("wait_eval", None, "test_wait"), + operation_identifier=OperationIdentifier( + "wait_eval", OperationSubType.WAIT, None, "test_wait" + ), ) # Act @@ -223,7 +231,7 @@ def test_wait_immediate_success_handling(): seconds=5, state=mock_state, operation_identifier=OperationIdentifier( - "wait_immediate", None, "immediate_wait" + "wait_immediate", OperationSubType.WAIT, None, "immediate_wait" ), ) @@ -264,7 +272,9 @@ def test_wait_no_immediate_response_suspends(): executor = WaitOperationExecutor( seconds=60, state=mock_state, - operation_identifier=OperationIdentifier("wait_suspend", None), + operation_identifier=OperationIdentifier( + "wait_suspend", OperationSubType.WAIT, None + ), ) # Act & Assert - verify suspend occurs @@ -299,7 +309,9 @@ def 
test_wait_already_completed_no_checkpoint(): executor = WaitOperationExecutor( seconds=10, state=mock_state, - operation_identifier=OperationIdentifier("wait_replay", None, "completed_wait"), + operation_identifier=OperationIdentifier( + "wait_replay", OperationSubType.WAIT, None, "completed_wait" + ), ) # Act @@ -338,7 +350,9 @@ def test_wait_with_various_durations(): executor = WaitOperationExecutor( seconds=seconds, state=mock_state, - operation_identifier=OperationIdentifier(f"wait_duration_{seconds}", None), + operation_identifier=OperationIdentifier( + f"wait_duration_{seconds}", OperationSubType.WAIT, None + ), ) # Act @@ -376,7 +390,9 @@ def test_wait_suspends_when_second_check_returns_started(): executor = WaitOperationExecutor( seconds=5, state=mock_state, - operation_identifier=OperationIdentifier("wait-1", None, "test_wait"), + operation_identifier=OperationIdentifier( + "wait-1", OperationSubType.WAIT, None, "test_wait" + ), ) with pytest.raises(SuspendExecution): @@ -408,7 +424,9 @@ def test_wait_suspends_when_second_check_returns_started_duplicate(): executor = WaitOperationExecutor( seconds=5, state=mock_state, - operation_identifier=OperationIdentifier("wait-1", None, "test_wait"), + operation_identifier=OperationIdentifier( + "wait-1", OperationSubType.WAIT, None, "test_wait" + ), ) with pytest.raises(SuspendExecution): diff --git a/packages/aws-durable-execution-sdk-python/tests/plugin_test.py b/packages/aws-durable-execution-sdk-python/tests/plugin_test.py new file mode 100644 index 00000000..b26365c0 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python/tests/plugin_test.py @@ -0,0 +1,787 @@ +import datetime +import logging +import unittest +from unittest.mock import MagicMock + +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + InvocationStatus, + OperationAction, + OperationStatus, + OperationSubType, + OperationType, + DurableExecutionInvocationOutput, +) +from aws_durable_execution_sdk_python.plugin 
import ( + DurableInstrumentationPlugin, + InvocationEndInfo, + InvocationStartInfo, + OperationEndInfo, + OperationStartInfo, + PluginExecutor, + UserFunctionOutcome, + UserFunctionStartInfo, + UserFunctionEndInfo, +) + + +# region Dataclass Tests + +ERROR = ErrorObject(message="boom", type="Error", data=None, stack_trace=None) +START_TS = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) +END_TS = datetime.datetime(2025, 1, 2, tzinfo=datetime.UTC) +LAMBDA_CTX = MagicMock() +LAMBDA_CTX.aws_request_id = "req-1" + +OPERATION_START_INFO = OperationStartInfo( + operation_id="op-2", + operation_type=OperationType.CALLBACK, + sub_type=OperationSubType.CALLBACK, + name="my-op", + parent_id="parent-1", + start_time=START_TS, +) +OPERATION_END_INFO = OperationEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + sub_type=OperationSubType.STEP, + name="my-op", + parent_id="parent-1", + start_time=START_TS, + status=OperationStatus.FAILED, + end_time=END_TS, + error=ERROR, +) + +INVOCATION_START_INFO = InvocationStartInfo( + request_id="req-1", + execution_arn="arn:aws:lambda:us-east-1:123:durable:abc", + start_time=START_TS, + is_first_invocation=True, +) +INVOCATION_END_INFO = InvocationEndInfo( + request_id="req-1", + execution_arn="arn:test", + start_time=START_TS, + status=InvocationStatus.FAILED, + error=ERROR, + is_first_invocation=False, + end_time=END_TS, +) + +USER_FUNCTION_START_INFO = UserFunctionStartInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + sub_type=OperationSubType.STEP, + name="func", + parent_id="parent-1", + start_time=START_TS, +) + +USER_FUNCTION_END_INFO = UserFunctionEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + sub_type=OperationSubType.STEP, + name="func", + parent_id="parent-1", + start_time=START_TS, + is_replay_children=False, + attempt=1, + outcome=UserFunctionOutcome.FAILED, + end_time=END_TS, + error=ERROR, +) + + +class TestDataClasses(unittest.TestCase): + def 
test_operation_start_info(self): + self.assertEqual(OPERATION_START_INFO.sub_type, OperationSubType.CALLBACK) + self.assertEqual(OPERATION_START_INFO.name, "my-op") + self.assertEqual(OPERATION_START_INFO.parent_id, "parent-1") + self.assertEqual(OPERATION_START_INFO.start_time, START_TS) + + def test_operation_end_info(self): + self.assertEqual(OPERATION_END_INFO.status, OperationStatus.FAILED) + self.assertEqual(OPERATION_END_INFO.end_time, END_TS) + self.assertEqual(OPERATION_END_INFO.error, ERROR) + self.assertEqual(OPERATION_END_INFO.operation_type, OperationType.STEP) + self.assertEqual(OPERATION_END_INFO.sub_type, OperationSubType.STEP) + self.assertEqual(OPERATION_END_INFO.name, "my-op") + self.assertEqual(OPERATION_END_INFO.parent_id, "parent-1") + self.assertEqual(OPERATION_END_INFO.operation_id, "op-1") + self.assertEqual(OPERATION_END_INFO.status, OperationStatus.FAILED) + self.assertEqual(OPERATION_END_INFO.operation_id, "op-1") + + def test_invocation_start_info(self): + self.assertEqual(INVOCATION_START_INFO.request_id, "req-1") + self.assertEqual( + INVOCATION_START_INFO.execution_arn, + "arn:aws:lambda:us-east-1:123:durable:abc", + ) + self.assertEqual(INVOCATION_START_INFO.start_time, START_TS) + self.assertTrue(INVOCATION_START_INFO.is_first_invocation) + + def test_invocation_end_info(self): + self.assertEqual(INVOCATION_END_INFO.request_id, "req-1") + self.assertEqual(INVOCATION_END_INFO.execution_arn, "arn:test") + self.assertEqual(INVOCATION_END_INFO.start_time, START_TS) + self.assertFalse(INVOCATION_END_INFO.is_first_invocation) + self.assertEqual(INVOCATION_END_INFO.status, InvocationStatus.FAILED) + self.assertEqual(INVOCATION_END_INFO.error.message, "boom") + self.assertEqual(INVOCATION_END_INFO.end_time, END_TS) + + def test_user_function_start_info(self): + self.assertEqual(USER_FUNCTION_START_INFO.operation_id, "op-1") + self.assertEqual(USER_FUNCTION_START_INFO.operation_type, OperationType.STEP) + 
self.assertEqual(USER_FUNCTION_START_INFO.sub_type, OperationSubType.STEP) + self.assertEqual(USER_FUNCTION_START_INFO.name, "func") + self.assertEqual(USER_FUNCTION_START_INFO.parent_id, "parent-1") + self.assertEqual(USER_FUNCTION_START_INFO.start_time, START_TS) + + def test_user_function_end_info(self): + self.assertEqual(USER_FUNCTION_END_INFO.operation_id, "op-1") + self.assertEqual(USER_FUNCTION_END_INFO.operation_type, OperationType.STEP) + self.assertEqual(USER_FUNCTION_END_INFO.sub_type, OperationSubType.STEP) + self.assertEqual(USER_FUNCTION_END_INFO.name, "func") + self.assertEqual(USER_FUNCTION_END_INFO.parent_id, "parent-1") + self.assertEqual(USER_FUNCTION_END_INFO.start_time, START_TS) + self.assertFalse(USER_FUNCTION_END_INFO.is_replay_children) + self.assertEqual(USER_FUNCTION_END_INFO.attempt, 1) + self.assertEqual(USER_FUNCTION_END_INFO.outcome, UserFunctionOutcome.FAILED) + self.assertEqual(USER_FUNCTION_END_INFO.end_time, END_TS) + self.assertEqual(USER_FUNCTION_END_INFO.error.message, "boom") + + +# endregion Dataclass Tests + + +# region DurableInstrumentationPlugin Tests +class TestDurableInstrumentationPlugin(unittest.TestCase): + def test_default_methods_are_noop(self): + """All default hook methods should be callable and return None.""" + plugin = _NoOpPlugin() + self.assertIsNone(plugin.on_invocation_start(INVOCATION_START_INFO)) + self.assertIsNone(plugin.on_invocation_end(INVOCATION_END_INFO)) + self.assertIsNone(plugin.on_operation_start(OPERATION_START_INFO)) + self.assertIsNone(plugin.on_operation_end(OPERATION_END_INFO)) + self.assertIsNone(plugin.on_user_function_start(USER_FUNCTION_START_INFO)) + self.assertIsNone(plugin.on_user_function_end(USER_FUNCTION_END_INFO)) + + def test_subclass_override(self): + """A subclass can override specific hooks.""" + plugin = _TrackingPlugin() + + plugin.on_invocation_start(INVOCATION_START_INFO) + plugin.on_operation_start(OPERATION_START_INFO) + + self.assertEqual( + 
["invocation_start:req-1", "operation_start:op-2"], plugin.calls + ) + + +# endregion DurableInstrumentationPlugin Tests + + +# region PluginExecutor Tests + + +class TestPluginExecutorInit(unittest.TestCase): + def test_init_with_none(self): + executor = PluginExecutor(plugins=None) + self.assertEqual(executor._plugins, []) + + def test_init_with_empty_list(self): + executor = PluginExecutor(plugins=[]) + self.assertEqual(executor._plugins, []) + + def test_init_with_plugins(self): + p1 = _NoOpPlugin() + p2 = _TrackingPlugin() + executor = PluginExecutor(plugins=[p1, p2]) + self.assertEqual(len(executor._plugins), 2) + + +class TestPluginExecutor(unittest.TestCase): + def test_no_thread_pool_when_plugins_is_none(self): + """Tests that PluginExecutor does not create a thread pool when plugins is empty.""" + executor = PluginExecutor(plugins=None) + self.assertIsNone(executor._executor) + + def test_no_thread_pool_when_plugins_is_empty_list(self): + executor = PluginExecutor(plugins=[]) + self.assertIsNone(executor._executor) + + def test_thread_pool_created_when_plugins_provided(self): + executor = PluginExecutor(plugins=[_NoOpPlugin()]) + with executor.run(): + self.assertIsNotNone(executor._executor) + + def test_start_is_noop_when_empty(self): + executor = PluginExecutor(plugins=[]) + # Should not raise + with executor.run(): + pass + + def test_on_invocation_start_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + # Should not raise + executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_first_invocation=False, + ) + + def test_on_invocation_end_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_first_invocation=False, + ) + output = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result=None, error=None + ) + + 
# Should not raise + executor.on_invocation_end( + output=output, + ) + + def test_on_operation_action_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + update = MagicMock() + update.action = OperationAction.START + update.operation_id = "op-1" + update.operation_type = OperationType.STEP + update.sub_type = OperationSubType.STEP + update.name = "my-step" + update.parent_id = None + + # Should not raise + executor.on_operation_action(update) + + def test_on_operation_update_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + op = MagicMock() + op.operation_id = "op-1" + op.operation_type = OperationType.STEP + op.sub_type = OperationSubType.STEP + op.name = "my-step" + op.parent_id = None + op.start_time = START_TS + op.end_time = END_TS + op.status = OperationStatus.SUCCEEDED + op.step_details = MagicMock() + op.step_details.attempt = 1 + op.step_details.error = None + op.callback_details = None + op.chained_invoke_details = None + op.context_details = None + + # Should not raise + executor.on_operation_update(op) + + +class TestPluginExecutorExecutePlugins(unittest.TestCase): + """Tests for the execute_plugins dispatch method.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + + def test_dispatch_invocation_start_info(self): + with self.executor.run(): + self.executor.execute_plugins(INVOCATION_START_INFO, sync=True) + self.assertIn("invocation_start:req-1", self.plugin.calls) + + def test_dispatch_invocation_end_info(self): + with self.executor.run(): + self.executor.execute_plugins(INVOCATION_END_INFO, sync=True) + self.assertIn("invocation_end:req-1", self.plugin.calls) + + def test_dispatch_operation_end_info(self): + with self.executor.run(): + self.executor.execute_plugins(OPERATION_END_INFO, sync=False) + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_dispatch_operation_start_info(self): + with self.executor.run(): + 
self.executor.execute_plugins(OPERATION_START_INFO, sync=False) + self.assertIn("operation_start:op-2", self.plugin.calls) + + def test_dispatch_user_function_start_info(self): + with self.executor.run(): + self.executor.execute_plugins(USER_FUNCTION_START_INFO, sync=True) + self.assertIn("user_function_start:op-1", self.plugin.calls) + + def test_dispatch_user_function_end_info(self): + with self.executor.run(): + self.executor.execute_plugins(USER_FUNCTION_END_INFO, sync=True) + self.assertIn("user_function_end:op-1", self.plugin.calls) + + def test_dispatch_unknown_type_logs_exception(self): + """Unknown info types should be caught and logged.""" + with self.assertLogs( + "aws_durable_execution_sdk_python.plugin", level=logging.ERROR + ): + with self.executor.run(): + self.executor.execute_plugins("not a valid info type", sync=True) + + def test_plugin_exception_is_swallowed(self): + """If a plugin raises, the exception is logged and execution continues.""" + failing_plugin = _FailingPlugin() + tracking_plugin = _TrackingPlugin() + executor = PluginExecutor(plugins=[failing_plugin, tracking_plugin]) + + with self.assertLogs( + "aws_durable_execution_sdk_python.plugin", level=logging.ERROR + ): + with executor.run(): + executor.execute_plugins(OPERATION_START_INFO, sync=True) + + # The second plugin should still have been called + self.assertIn("operation_start:op-2", tracking_plugin.calls) + + def test_multiple_plugins_all_called(self): + p1 = _TrackingPlugin() + p2 = _TrackingPlugin() + executor = PluginExecutor(plugins=[p1, p2]) + + with executor.run(): + executor.execute_plugins(OPERATION_START_INFO, sync=True) + + self.assertIn("operation_start:op-2", p1.calls) + self.assertIn("operation_start:op-2", p2.calls) + + +class TestPluginExecutorOnInvocationStart(unittest.TestCase): + """Tests for PluginExecutor.on_invocation_start.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + self.ts = 
datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + def _make_operation(self, start_time=None): + op = MagicMock() + op.start_time = start_time or self.ts + return op + + def test_first_invocation_fires_invocation_start(self): + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_first_invocation=False, + ) + + self.assertEqual("arn:exec", self.executor._invocation_status.execution_arn) + self.assertEqual( + LAMBDA_CTX.aws_request_id, self.executor._invocation_status.request_id + ) + self.assertEqual(START_TS, self.executor._invocation_status.start_time) + self.assertFalse(self.executor._invocation_status.is_first_invocation) + + self.assertIsNone(self.executor._invocation_status) + + # Only InvocationStartInfo is dispatched to on_invocation_start + # (no separate ExecutionStartInfo dispatch is counted here), + # so we expect exactly one invocation_start call + invocation_calls = [ + c for c in self.plugin.calls if c.startswith("invocation_start") + ] + self.assertEqual(1, len(invocation_calls)) + + def test_replay_invocation_fires_invocation_start(self): + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_first_invocation=True, + ) + + # Only InvocationStartInfo should be dispatched (not ExecutionStartInfo) + invocation_calls = [ + c for c in self.plugin.calls if c.startswith("invocation_start") + ] + self.assertEqual(1, len(invocation_calls)) + + def test_none_context_uses_none_request_id(self): + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=None, + execution_start_time=START_TS, + is_first_invocation=False, + ) + + invocation_calls = [ + c for c in self.plugin.calls if c.startswith("invocation_start") + ] + # Only InvocationStartInfo is dispatched, so exactly one call is expected + 
self.assertEqual(len(invocation_calls), 1) + # request_id should be None + self.assertIn("invocation_start:None", self.plugin.calls) + + +class TestPluginExecutorOnInvocationEnd(unittest.TestCase): + """Tests for PluginExecutor.on_invocation_end.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + self.ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + def _make_operation(self, start_ts=None, end_ts=None): + op = MagicMock() + op.start_time = start_ts or self.ts + op.end_time = end_ts + return op + + def test_succeeded_fires_invocation_end(self): + output = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result=None, error=None + ) + + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_first_invocation=False, + ) + self.executor.on_invocation_end( + output=output, + ) + + self.assertIn("invocation_end:req-1", self.plugin.calls) + + def test_failed_fires_invocation_end(self): + output = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, result=None, error=ERROR + ) + + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_first_invocation=False, + ) + self.executor.on_invocation_end( + output=output, + ) + + self.assertIn("invocation_end:req-1", self.plugin.calls) + + def test_pending_fires_invocation_end(self): + output = DurableExecutionInvocationOutput( + status=InvocationStatus.PENDING, result=None, error=None + ) + + with self.executor.run(): + self.executor.on_invocation_start( + execution_arn="arn:exec", + lambda_context=LAMBDA_CTX, + execution_start_time=START_TS, + is_first_invocation=False, + ) + self.executor.on_invocation_end( + output=output, + ) + + self.assertIn("invocation_end:req-1", self.plugin.calls) + + +class 
TestPluginExecutorOnOperationAction(unittest.TestCase): + """Tests for PluginExecutor.on_operation_action.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + + def test_start_action_fires_operation_start(self): + update = MagicMock() + update.action = OperationAction.START + update.operation_id = "op-1" + update.operation_type = OperationType.STEP + update.sub_type = OperationSubType.STEP + update.name = "my-step" + update.parent_id = "parent-1" + + with self.executor.run(): + self.executor.on_operation_action(update) + + self.assertIn("operation_start:op-1", self.plugin.calls) + + def test_non_start_action_does_not_fire(self): + update = MagicMock() + update.action = OperationAction.SUCCEED + update.operation_id = "op-1" + + self.executor.on_operation_action(update) + + self.assertEqual(self.plugin.calls, []) + + def test_fail_action_does_not_fire(self): + update = MagicMock() + update.action = OperationAction.FAIL + update.operation_id = "op-1" + + self.executor.on_operation_action(update) + + self.assertEqual(self.plugin.calls, []) + + +class TestPluginExecutorOnOperationUpdate(unittest.TestCase): + """Tests for PluginExecutor.on_operation_update.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + + def _make_operation( + self, + status=OperationStatus.SUCCEEDED, + step_details=None, + callback_details=None, + chained_invoke_details=None, + context_details=None, + ): + op = MagicMock() + op.operation_id = "op-1" + op.operation_type = OperationType.STEP + op.sub_type = OperationSubType.STEP + op.name = "my-step" + op.parent_id = "parent-1" + op.start_time = START_TS + op.end_time = END_TS + op.status = status + op.step_details = step_details + op.callback_details = callback_details + op.chained_invoke_details = chained_invoke_details + op.context_details = context_details + return op + + def 
test_terminal_status_without_step_details_fires_operation_only(self): + op = self._make_operation(status=OperationStatus.FAILED, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_non_terminal_status_without_step_details_fires_nothing(self): + op = self._make_operation(status=OperationStatus.STARTED, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertEqual(self.plugin.calls, []) + + def test_ready_status_fires_nothing(self): + op = self._make_operation(status=OperationStatus.READY, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertEqual(self.plugin.calls, []) + + def test_timed_out_is_terminal(self): + op = self._make_operation(status=OperationStatus.TIMED_OUT, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_cancelled_is_terminal(self): + op = self._make_operation(status=OperationStatus.CANCELLED, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_stopped_is_terminal(self): + op = self._make_operation(status=OperationStatus.STOPPED, step_details=None) + + with self.executor.run(): + self.executor.on_operation_update(op) + + self.assertIn("operation_end:op-1", self.plugin.calls) + + +class TestPluginExecutorExtractError(unittest.TestCase): + """Tests for PluginExecutor._extract_error static method.""" + + def test_extract_error_from_step_details(self): + op = MagicMock() + op.step_details = MagicMock() + op.step_details.error = ERROR + op.callback_details = None + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "boom") + + def 
test_extract_error_from_callback_details(self): + op = MagicMock() + op.step_details = None + op.callback_details = MagicMock() + op.callback_details.error = ERROR + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "boom") + + def test_extract_error_from_chained_invoke_details(self): + op = MagicMock() + op.step_details = None + op.callback_details = None + op.chained_invoke_details = MagicMock() + op.chained_invoke_details.error = ERROR + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "boom") + + def test_extract_error_from_context_details(self): + op = MagicMock() + op.step_details = None + op.callback_details = None + op.chained_invoke_details = None + op.context_details = MagicMock() + op.context_details.error = ERROR + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "boom") + + def test_extract_error_returns_none_when_no_error(self): + op = MagicMock() + op.step_details = None + op.callback_details = None + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertIsNone(result) + + def test_extract_error_step_details_no_error(self): + """step_details exists but has no error - falls through to callback.""" + op = MagicMock() + op.step_details = MagicMock() + op.step_details.error = None + op.callback_details = MagicMock() + op.callback_details.error = ERROR + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "boom") + + +class TestPluginExecutorIsTerminalStatus(unittest.TestCase): + """Tests for PluginExecutor._is_terminal_status static method.""" + + def test_succeeded_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.SUCCEEDED)) + + def test_failed_is_terminal(self): + 
self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.FAILED)) + + def test_timed_out_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.TIMED_OUT)) + + def test_cancelled_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.CANCELLED)) + + def test_stopped_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.STOPPED)) + + def test_started_is_not_terminal(self): + self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.STARTED)) + + def test_pending_is_not_terminal(self): + self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.PENDING)) + + def test_ready_is_not_terminal(self): + self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.READY)) + + +# endregion PluginExecutor Tests + + +# region Helper Classes + + +class _NoOpPlugin(DurableInstrumentationPlugin): + """Concrete subclass that inherits all default no-op methods.""" + + pass + + +class _TrackingPlugin(DurableInstrumentationPlugin): + """Concrete subclass that tracks calls to all hooks.""" + + def __init__(self) -> None: + self.calls: list[str] = [] + + def on_invocation_start(self, info: InvocationStartInfo) -> None: + self.calls.append(f"invocation_start:{info.request_id}") + + def on_invocation_end(self, info: InvocationEndInfo) -> None: + self.calls.append(f"invocation_end:{info.request_id}") + + def on_operation_start(self, info: OperationStartInfo) -> None: + self.calls.append(f"operation_start:{info.operation_id}") + + def on_operation_end(self, info: OperationEndInfo) -> None: + self.calls.append(f"operation_end:{info.operation_id}") + + def on_user_function_start(self, info: UserFunctionStartInfo) -> None: + self.calls.append(f"user_function_start:{info.operation_id}") + + def on_user_function_end(self, info: UserFunctionEndInfo) -> None: + self.calls.append(f"user_function_end:{info.operation_id}") + + +class 
_FailingPlugin(DurableInstrumentationPlugin): + """Plugin that raises from every hook it overrides.""" + + def on_execution_start(self, info): + raise RuntimeError("boom") + + def on_execution_end(self, info): + raise RuntimeError("boom") + + def on_invocation_start(self, info): + raise RuntimeError("boom") + + def on_invocation_end(self, info): + raise RuntimeError("boom") + + def on_operation_start(self, info): + raise RuntimeError("boom") + + def on_operation_end(self, info): + raise RuntimeError("boom") + + def on_operation_attempt_start(self, info): + raise RuntimeError("boom") + + def on_operation_attempt_end(self, info): + raise RuntimeError("boom") + + +# endregion Helper Classes + + +if __name__ == "__main__": + unittest.main() diff --git a/packages/aws-durable-execution-sdk-python/tests/state_test.py b/packages/aws-durable-execution-sdk-python/tests/state_test.py index 0152ca6c..5e7e7fb8 100644 --- a/packages/aws-durable-execution-sdk-python/tests/state_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/state_test.py @@ -9,7 +9,7 @@ import time import unittest.mock from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, call, patch +from unittest.mock import Mock, call, patch, create_autospec import pytest @@ -36,6 +36,11 @@ OperationUpdate, StateOutput, StepDetails, + OperationSubType, +) +from aws_durable_execution_sdk_python.plugin import ( + DurableInstrumentationPlugin, + PluginExecutor, ) from aws_durable_execution_sdk_python.state import ( CheckpointBatcherConfig, @@ -332,7 +337,7 @@ def test_checkpointerd_result_is_pending(): assert result_no_op.is_pending() is False -def test_checkpointerd_result_is_ready(): +def test_checkpointed_result_is_ready(): """Test CheckpointedResult.is_ready method.""" operation = Operation( operation_id="op1", @@ -405,6 +410,7 @@ def test_execution_state_creation(): initial_checkpoint_token="test_token", # noqa: S106 operations={}, service_client=mock_lambda_client, + 
plugin_executor=PluginExecutor(plugins=None), ) assert state.durable_execution_arn == "test_arn" assert state.operations == {} @@ -425,6 +431,7 @@ def test_get_checkpoint_result_success_with_result(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -446,6 +453,7 @@ def test_get_checkpoint_result_success_without_step_details(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -467,6 +475,7 @@ def test_get_checkpoint_result_operation_not_succeeded(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -483,6 +492,7 @@ def test_get_checkpoint_result_operation_not_found(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("nonexistent") @@ -500,6 +510,7 @@ def test_create_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -530,6 +541,7 @@ def test_create_checkpoint_with_none(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # create_checkpoint with None and is_sync=False enqueues an empty checkpoint @@ -554,6 +566,7 @@ def test_create_checkpoint_with_no_args(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + 
plugin_executor=PluginExecutor(plugins=None), ) # create_checkpoint with no args and is_sync=False enqueues an empty checkpoint @@ -582,6 +595,7 @@ def test_get_checkpoint_result_started(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -675,6 +689,7 @@ def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marke initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) state.fetch_paginated_operations( @@ -773,6 +788,7 @@ def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marke initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) with pytest.raises(GetExecutionStateError): @@ -811,6 +827,7 @@ def test_fetch_paginated_operations_logs_error(caplog): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) with pytest.raises(GetExecutionStateError): @@ -920,6 +937,7 @@ def test_checkpoint_batch_respects_default_max_items_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -988,6 +1006,7 @@ def test_collect_checkpoint_batch_respects_size_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1021,6 +1040,7 @@ def test_collect_checkpoint_batch_uses_overflow_queue(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Put 
operations in overflow queue @@ -1072,6 +1092,7 @@ def test_collect_checkpoint_batch_handles_empty_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Enqueue empty checkpoint @@ -1107,6 +1128,7 @@ def test_collect_checkpoint_batch_returns_empty_when_stopped(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Signal stop before collecting @@ -1128,6 +1150,7 @@ def test_parent_child_relationship_building(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create parent operation @@ -1169,6 +1192,7 @@ def test_descendant_cancellation_when_parent_completes(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child hierarchy @@ -1208,6 +1232,7 @@ def test_rejection_of_operations_from_completed_parents(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child hierarchy @@ -1257,6 +1282,7 @@ def test_nested_parallel_operations_deep_hierarchy(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build deep hierarchy: grandparent -> parent -> child @@ -1313,6 +1339,7 @@ def test_synchronous_checkpoint_blocks_until_complete(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -1361,6 +1388,7 @@ def test_concurrent_access_to_operations_dictionary(): initial_checkpoint_token="token123", # noqa: S106 
operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Add initial operation @@ -1430,6 +1458,7 @@ def test_stop_checkpointing_signals_background_thread(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Verify event is not set initially @@ -1523,6 +1552,7 @@ def test_create_checkpoint_sync_with_parent_id(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create parent operation @@ -1574,6 +1604,7 @@ def test_create_checkpoint_sync_rejects_orphaned_operation(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child relationship @@ -1638,6 +1669,7 @@ def test_mark_orphans_handles_cycles(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Manually create a cycle (shouldn't happen in practice, but test defensive code) @@ -1668,6 +1700,7 @@ def test_checkpoint_batches_forever_exception_handling(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create synchronous operation @@ -1715,6 +1748,7 @@ def test_collect_checkpoint_batch_shutdown_path(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Add operation to queue (would be a non-essential async checkpoint in practice) @@ -1744,6 +1778,7 @@ def test_collect_checkpoint_batch_shutdown_empty_queue(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # 
Signal shutdown with empty queue @@ -1771,6 +1806,7 @@ def test_collect_checkpoint_batch_overflow_put_back(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1816,6 +1852,7 @@ def test_create_checkpoint_sync_with_none_operation_update(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Simulate background processor @@ -1848,6 +1885,7 @@ def test_checkpoint_batches_forever_exception_with_no_sync_operations(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create async operation (no completion event) @@ -1887,6 +1925,7 @@ def test_collect_checkpoint_batch_size_limit_during_time_window(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1940,6 +1979,7 @@ def test_collect_checkpoint_batch_respects_max_operations_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1983,6 +2023,7 @@ def test_collect_checkpoint_batch_time_window_expires(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2030,6 +2071,7 @@ def test_collect_checkpoint_batch_empty_overflow_queue_path(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Ensure overflow queue is empty (it should be by default) @@ -2067,6 +2109,7 @@ def 
test_collect_checkpoint_batch_overflow_queue_hits_operation_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2106,6 +2149,7 @@ def test_collect_checkpoint_batch_overflow_queue_size_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2155,6 +2199,7 @@ def test_checkpoint_error_signals_completion_events_with_error(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create synchronous operation with completion event @@ -2211,6 +2256,7 @@ def test_synchronous_caller_receives_error_on_background_thread_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2288,6 +2334,7 @@ def test_exception_propagates_through_threadpoolexecutor(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Enqueue an operation @@ -2321,6 +2368,7 @@ def test_multiple_sync_operations_all_remain_blocked_on_error(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create multiple synchronous operations @@ -2372,6 +2420,7 @@ def test_async_operations_not_affected_by_error_handling(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create async operation (no completion event) @@ -2409,6 +2458,7 @@ def test_mixed_sync_async_operations_only_sync_blocked_on_error(): 
initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create sync operation with completion event @@ -2469,6 +2519,7 @@ def test_create_checkpoint_accepts_is_sync_parameter(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2503,6 +2554,7 @@ def test_create_checkpoint_default_is_sync_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2549,6 +2601,7 @@ def test_create_checkpoint_explicit_is_sync_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2590,6 +2643,7 @@ def test_create_checkpoint_is_sync_false_no_completion_event(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2620,6 +2674,7 @@ def test_create_checkpoint_is_sync_false_returns_immediately(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2658,6 +2713,7 @@ def test_create_checkpoint_with_none_defaults_to_sync(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Use a thread to call with None (will block) @@ -2694,6 +2750,7 @@ def test_create_checkpoint_no_args_defaults_to_sync(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + 
plugin_executor=PluginExecutor(plugins=None), ) # Use a thread to call with no arguments (will block) @@ -2733,6 +2790,7 @@ def test_collect_checkpoint_batch_overflow_queue_size_limit_final(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2788,6 +2846,7 @@ def test_create_checkpoint_blocks_until_completion_default(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2859,6 +2918,7 @@ def test_create_checkpoint_blocks_until_completion_explicit_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2930,6 +2990,7 @@ def test_create_checkpoint_completion_event_created_and_signaled(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2994,6 +3055,7 @@ def test_create_checkpoint_completion_event_not_signaled_on_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -3080,6 +3142,7 @@ def test_create_checkpoint_caller_remains_blocked_on_background_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -3162,6 +3225,7 @@ def test_create_checkpoint_multiple_sync_calls_all_block(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) num_callers = 3 
@@ -3238,6 +3302,7 @@ def test_create_checkpoint_sync_with_empty_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Track timing and completion @@ -3296,6 +3361,7 @@ def test_create_checkpoint_sync_success(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3304,7 +3370,7 @@ def test_create_checkpoint_sync_success(): try: operation_update = OperationUpdate.create_step_start( - OperationIdentifier("test-op", None, "test-step") + OperationIdentifier("test-op", OperationSubType.STEP, None, "test-step") ) # Should work normally without error @@ -3330,6 +3396,7 @@ def test_create_checkpoint_sync_unwraps_background_thread_error(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3338,7 +3405,7 @@ def test_create_checkpoint_sync_unwraps_background_thread_error(): try: operation_update = OperationUpdate.create_step_start( - OperationIdentifier("test-op", None, "test-step") + OperationIdentifier("test-op", OperationSubType.STEP, None, "test-step") ) # Should raise the original RuntimeError, not BackgroundThreadError @@ -3363,6 +3430,7 @@ def test_create_checkpoint_sync_always_synchronous(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3371,7 +3439,7 @@ def test_create_checkpoint_sync_always_synchronous(): try: operation_update = OperationUpdate.create_step_start( - OperationIdentifier("test-op", None, "test-step") + OperationIdentifier("test-op", OperationSubType.STEP, None, "test-step") ) # Should block until completion (synchronous behavior) @@ -3400,6 +3468,7 @@ def 
test_state_replay_mode(): initial_checkpoint_token="test_token", # noqa: S106 operations={"op1": operation1, "op2": operation2}, service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), replay_status=ReplayStatus.REPLAY, ) assert execution_state.is_replaying() is True @@ -3433,6 +3502,7 @@ def test_state_replay_mode_with_timed_out(): initial_checkpoint_token="test_token", # noqa: S106 operations={"op1": operation1, "op2": operation2}, service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), replay_status=ReplayStatus.REPLAY, ) assert execution_state.is_replaying() is True @@ -3464,6 +3534,7 @@ def test_collect_checkpoint_batch_coalesces_many_empty_checkpoints(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3497,6 +3568,7 @@ def test_collect_checkpoint_batch_empty_checkpoints_with_real_ops_respects_limit initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3536,6 +3608,7 @@ def test_collect_checkpoint_batch_overflow_coalesces_empty_checkpoints(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3576,6 +3649,7 @@ def test_checkpoint_batches_forever_single_api_call_for_many_empty_checkpoints() initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3624,6 +3698,7 @@ def test_collect_checkpoint_batch_first_empty_counts_toward_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3676,6 +3751,7 @@ def 
test_execution_state_get_execution_operation_no_operations(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3707,6 +3783,7 @@ def test_initial_execution_state_get_execution_operation_wrong_type(): initial_checkpoint_token="token123", # noqa: S106 operations={"step1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3743,8 +3820,443 @@ def test_initial_execution_state_get_input_payload_none(): initial_checkpoint_token="token123", # noqa: S106 operations={"step1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) result = state.get_input_payload() assert result is None + + +# region Plugin Executor Integration Tests + + +class _RecordingPlugin(DurableInstrumentationPlugin): + """Plugin that records all hook calls for assertion.""" + + def __init__(self) -> None: + self.calls: list[str] = [] + + def on_execution_start(self, info): + self.calls.append("execution_start") + + def on_execution_end(self, info): + self.calls.append("execution_end") + + def on_invocation_start(self, info): + self.calls.append("invocation_start") + + def on_invocation_end(self, info): + self.calls.append("invocation_end") + + def on_operation_start(self, info): + self.calls.append(f"operation_start:{info.operation_id}") + + def on_operation_end(self, info): + self.calls.append(f"operation_end:{info.operation_id}") + + def on_user_function_start(self, info): + self.calls.append(f"user_function_start:{info.operation_id}") + + def on_user_function_end(self, info): + self.calls.append(f"user_function_end:{info.operation_id}") + + +def test_execution_state_accepts_plugin_executor_parameter(): + """Test that ExecutionState can be created with a plugin_executor parameter.""" + mock_client = Mock(spec=LambdaClient) + plugin = 
_RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + assert state._plugin_executor is plugin_executor + + +def test_plugin_executor_on_operation_action_called_on_checkpoint(): + """Test that plugin_executor.on_operation_action is called for each update after checkpoint.""" + mock_client = create_autospec(LambdaClient) + + # Return a succeeded step operation from checkpoint + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + # Start background thread + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, is_sync=True) + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # on_operation_action is called for START updates + assert "operation_start:step-1" in plugin.calls + + +def test_plugin_executor_on_operation_update_called_for_terminal_operations(): + """Test that plugin_executor.on_operation_update is called for terminal 
operations.""" + mock_client = create_autospec(LambdaClient) + + # Return a succeeded step operation from checkpoint + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + name="my-step", + payload='"done"', + ) + state.create_checkpoint(operation_update, is_sync=True) + + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + assert "operation_end:step-1" in plugin.calls + + +def test_plugin_executor_not_called_for_non_terminal_operations(): + """Test that plugin_executor.on_operation_update does not fire for non-terminal operations.""" + mock_client = create_autospec(spec=LambdaClient) + + # Return a STARTED step operation from checkpoint + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=None, + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = 
PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, is_sync=True) + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # on_operation_action fires for START + assert "operation_start:step-1" in plugin.calls + # But on_operation_update should NOT fire operation_end for STARTED status + operation_end_calls = [c for c in plugin.calls if c.startswith("operation_end")] + assert len(operation_end_calls) == 0 + + +def test_plugin_executor_called_for_multiple_updates_in_batch(): + """Test that plugin_executor is called for each update in a batch.""" + mock_client = create_autospec(spec=LambdaClient) + + # Return multiple operations from checkpoint + step_op1 = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"result1"'), + ) + step_op2 = Operation( + operation_id="step-2", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"result2"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op1, step_op2], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + config = CheckpointBatcherConfig( + max_batch_time_seconds=0.2, + max_batch_operations=10, + ) + + 
state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + batcher_config=config, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + op1 = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="step-1", + ) + op2 = OperationUpdate( + operation_id="step-2", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="step-2", + ) + # Enqueue op1 without blocking, then block on op2, so they batch together + state.create_checkpoint(op1, is_sync=False) + state.create_checkpoint(op2, is_sync=True) + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # Both operations should have triggered on_operation_action + assert "operation_start:step-1" in plugin.calls + assert "operation_start:step-2" in plugin.calls + # Both terminal operations should have triggered on_operation_update + assert "operation_end:step-1" in plugin.calls + assert "operation_end:step-2" in plugin.calls + + +def test_plugin_executor_not_called_on_checkpoint_failure(): + """Test that plugin_executor is NOT called when checkpoint API fails.""" + mock_client = create_autospec(spec=LambdaClient) + mock_client.checkpoint.side_effect = RuntimeError("API error") + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + )
+ + with pytest.raises(BackgroundThreadError): + state.create_checkpoint(operation_update, is_sync=True) + + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # Plugin should NOT have been called since checkpoint failed + assert "operation_start:step-1" not in plugin.calls + assert "operation_end:step-1" not in plugin.calls + + +def test_plugin_executor_exception_does_not_break_checkpointing(): + """Test that a plugin exception does not break the checkpoint processing loop.""" + mock_client = create_autospec(spec=LambdaClient) + + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + class _ExplodingPlugin(DurableInstrumentationPlugin): + def on_operation_start(self, info): + raise RuntimeError("plugin exploded") + + def on_operation_end(self, info): + raise RuntimeError("plugin exploded") + + exploding_plugin = _ExplodingPlugin() + plugin_executor = PluginExecutor(plugins=[exploding_plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + # Should not raise even though plugin explodes + state.create_checkpoint(operation_update, is_sync=True) + + # Checkpoint should still have been called successfully + assert mock_client.checkpoint.call_count == 1 + finally: + 
state.stop_checkpointing() + executor.shutdown(wait=True) + + +def test_plugin_executor_not_called_for_pending_operations(): + """Test that plugin_executor.on_operation_update does not fire on_operation_end for PENDING operations.""" + mock_client = create_autospec(spec=LambdaClient) + + # Return a PENDING step operation from checkpoint (simulates a retry scenario) + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=StepDetails( + attempt=1, + result=None, + error=ErrorObject( + message="transient failure", + type="RetryableError", + data=None, + stack_trace=None, + ), + ), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, is_sync=True) + + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # operation_end should NOT fire for PENDING (only for terminal statuses) + operation_end_calls = [c for c in plugin.calls if c.startswith("operation_end")] + assert len(operation_end_calls) == 0 + + +# endregion Plugin Executor Integration Tests